From 3b9b71e0ae86cc20c6a0201b027bd45bee4257e5 Mon Sep 17 00:00:00 2001 From: Lu Guanqun Date: Sun, 1 Mar 2015 11:34:39 +0800 Subject: [PATCH] implement overload resolution --- AST.cpp | 159 ++++++++++++++++++++++++++++++++++++++-- AST.h | 15 +++- CompilerContext.cpp | 4 +- ExpressionCompiler.cpp | 6 +- NameAndTypeResolver.cpp | 39 +++++----- NameAndTypeResolver.h | 4 +- Parser.cpp | 7 +- Types.h | 15 +++- 8 files changed, 210 insertions(+), 39 deletions(-) diff --git a/AST.cpp b/AST.cpp index 79b755e97..428c82f21 100644 --- a/AST.cpp +++ b/AST.cpp @@ -76,6 +76,15 @@ void ContractDefinition::checkTypeRequirements() for (ASTPointer const& function: getDefinedFunctions()) function->checkTypeRequirements(); + // check for duplicate declaration + set functions; + for (ASTPointer const& function: getDefinedFunctions()) + { + string signature = function->getCanonicalSignature(); + if (functions.count(signature)) + BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_comment("Duplicate functions are not allowed.")); + functions.insert(signature); + } for (ASTPointer const& variable: m_stateVariables) variable->checkTypeRequirements(); @@ -129,6 +138,7 @@ void ContractDefinition::checkIllegalOverrides() const // TODO unify this at a later point. for this we need to put the constness and the access specifier // into the types map functions; + set functionNames; map modifiers; // We search from derived to base, so the stored item causes the error. @@ -141,7 +151,8 @@ void ContractDefinition::checkIllegalOverrides() const string const& name = function->getName(); if (modifiers.count(name)) BOOST_THROW_EXCEPTION(modifiers[name]->createTypeError("Override changes function to modifier.")); - FunctionDefinition const*& override = functions[name]; + FunctionDefinition const*& override = functions[function->getCanonicalSignature()]; + functionNames.insert(name); if (!override) override = function.get(); else if (override->getVisibility() != function->getVisibility() || @@ -152,13 +163,13 @@ void ContractDefinition::checkIllegalOverrides() const for (ASTPointer const& modifier: contract->getFunctionModifiers()) { string const& name = modifier->getName(); - if (functions.count(name)) - BOOST_THROW_EXCEPTION(functions[name]->createTypeError("Override changes modifier to function.")); ModifierDefinition const*& override = modifiers[name]; if (!override) override = modifier.get(); else if (ModifierType(*override) != ModifierType(*modifier)) BOOST_THROW_EXCEPTION(override->createTypeError("Override changes modifier signature.")); + if (functionNames.count(name)) + BOOST_THROW_EXCEPTION(override->createTypeError("Override changes modifier to function.")); } } } @@ -185,16 +196,21 @@ vector, FunctionTypePointer>> const& ContractDefinition::getIn if (!m_interfaceFunctionList) { set functionsSeen; + set signaturesSeen; m_interfaceFunctionList.reset(new vector, FunctionTypePointer>>()); for (ContractDefinition const* contract: getLinearizedBaseContracts()) { for (ASTPointer const& f: contract->getDefinedFunctions()) - if (f->isPublic() && !f->isConstructor() && !f->getName().empty() && functionsSeen.count(f->getName()) == 0) + { + string functionSignature = f->getCanonicalSignature(); + if (f->isPublic() && !f->isConstructor() && !f->getName().empty() && signaturesSeen.count(functionSignature) == 0) { functionsSeen.insert(f->getName()); - FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); + signaturesSeen.insert(functionSignature); + FixedHash<4> hash(dev::sha3(functionSignature)); m_interfaceFunctionList->push_back(make_pair(hash, make_shared(*f, false))); } + } for (ASTPointer const& v: contract->getStateVariables()) if (v->isPublic() && functionsSeen.count(v->getName()) == 0) @@ -467,7 +483,43 @@ void Return::checkTypeRequirements() void VariableDeclarationStatement::checkTypeRequirements() { +<<<<<<< HEAD m_variable->checkTypeRequirements(); +======= + // Variables can be declared without type (with "var"), in which case the first assignment + // sets the type. + // Note that assignments before the first declaration are legal because of the special scoping + // rules inherited from JavaScript. + if (m_variable->getValue()) + { + if (m_variable->getType()) + { + std::cout << "getType() ok" << std::endl; + m_variable->getValue()->expectType(*m_variable->getType()); + } + else + { + // no type declared and no previous assignment, infer the type + std::cout << "here's where called...." << std::endl; + Identifier* identifier = dynamic_cast(m_variable->getValue().get()); + if (identifier) + identifier->checkTypeRequirementsFromVariableDeclaration(); + else + m_variable->getValue()->checkTypeRequirements(); + TypePointer type = m_variable->getValue()->getType(); + if (type->getCategory() == Type::Category::IntegerConstant) + { + auto intType = dynamic_pointer_cast(type)->getIntegerType(); + if (!intType) + BOOST_THROW_EXCEPTION(m_variable->getValue()->createTypeError("Invalid integer constant " + type->toString())); + type = intType; + } + else if (type->getCategory() == Type::Category::Void) + BOOST_THROW_EXCEPTION(m_variable->createTypeError("var cannot be void type")); + m_variable->setType(type); + } + } +>>>>>>> implement overload resolution } void Assignment::checkTypeRequirements() @@ -544,10 +596,16 @@ void BinaryOperation::checkTypeRequirements() void FunctionCall::checkTypeRequirements() { - m_expression->checkTypeRequirements(); + // we need to check arguments' type first as their info will be used by m_express(Identifier). for (ASTPointer const& argument: m_arguments) argument->checkTypeRequirements(); + auto identifier = dynamic_cast(m_expression.get()); + if (identifier) + identifier->checkTypeRequirementsWithFunctionCall(*this); + else + m_expression->checkTypeRequirements(); + Type const* expressionType = m_expression->getType().get(); if (isTypeConversion()) { @@ -617,6 +675,19 @@ void FunctionCall::checkTypeRequirements() else m_type = functionType->getReturnParameterTypes().front(); } + else if (OverloadedFunctionType const* overloadedTypes = dynamic_cast(expressionType)) + { + // this only applies to "x(3)" where x is assigned by "var x = f;" where f is an overloaded functions. + overloadedTypes->m_identifier->overloadResolution(*this); + FunctionType const* functionType = dynamic_cast(overloadedTypes->m_identifier->getType().get()); + + // @todo actually the return type should be an anonymous struct, + // but we change it to the type of the first return value until we have structs + if (functionType->getReturnParameterTypes().empty()) + m_type = make_shared(); + else + m_type = functionType->getReturnParameterTypes().front(); + } else BOOST_THROW_EXCEPTION(createTypeError("Type is not callable.")); } @@ -709,16 +780,92 @@ void IndexAccess::checkTypeRequirements() } } +void Identifier::checkTypeRequirementsWithFunctionCall(FunctionCall const& _functionCall) +{ + solAssert(m_referencedDeclaration || !m_overloadedDeclarations.empty(), "Identifier not resolved."); + + if (!m_referencedDeclaration) + overloadResolution(_functionCall); + + checkTypeRequirements(); +} + +void Identifier::checkTypeRequirementsFromVariableDeclaration() +{ + solAssert(m_referencedDeclaration || !m_overloadedDeclarations.empty(), "Identifier not resolved."); + + if (!m_referencedDeclaration) + m_type = make_shared(m_overloadedDeclarations, this); + else + checkTypeRequirements(); + + m_isLValue = true; +} + void Identifier::checkTypeRequirements() { + // var x = f; TODO! solAssert(m_referencedDeclaration, "Identifier not resolved."); m_isLValue = m_referencedDeclaration->isLValue(); + if (m_isLValue) + std::cout << "Identifier: " << string(getName()) << " -> true" << std::endl; + else + std::cout << "Identifier: " << string(getName()) << " -> true" << std::endl; m_type = m_referencedDeclaration->getType(m_currentContract); if (!m_type) BOOST_THROW_EXCEPTION(createTypeError("Declaration referenced before type could be determined.")); } +void Identifier::overloadResolution(FunctionCall const& _functionCall) +{ + solAssert(m_overloadedDeclarations.size() > 1, "FunctionIdentifier not resolved."); + solAssert(!m_referencedDeclaration, "Referenced declaration should be null before overload resolution."); + + bool resolved = false; + + std::vector> arguments = _functionCall.getArguments(); + std::vector> const& argumentNames = _functionCall.getNames(); + + if (argumentNames.empty()) + { + // positional arguments + std::vector possibles; + for (Declaration const* declaration: m_overloadedDeclarations) + { + TypePointer const& function = declaration->getType(); + auto const& functionType = dynamic_cast(*function); + TypePointers const& parameterTypes = functionType.getParameterTypes(); + + if (functionType.takesArbitraryParameters() || + (arguments.size() == parameterTypes.size() && + std::equal(arguments.cbegin(), arguments.cend(), parameterTypes.cbegin(), + [](ASTPointer const& argument, TypePointer const& parameterType) + { + return argument->getType()->isImplicitlyConvertibleTo(*parameterType); + }))) + possibles.push_back(declaration); + } + std::cout << "possibles: " << possibles.size() << std::endl; + if (possibles.empty()) + BOOST_THROW_EXCEPTION(createTypeError("Can't resolve identifier")); + else if (std::none_of(possibles.cbegin() + 1, possibles.cend(), + [&possibles](Declaration const* declaration) + { + return declaration->getScope() == possibles.front()->getScope(); + })) + setReferencedDeclaration(*possibles.front()); + else + BOOST_THROW_EXCEPTION(createTypeError("Can't resolve identifier")); + } + else + { + // named arguments + // TODO: don't support right now + // BOOST_THROW_EXCEPTION(createTypeError("Named arguments with overloaded functions are not supported yet.")); + } +} + void ElementaryTypeNameExpression::checkTypeRequirements() { m_type = make_shared(Type::fromElementaryTypeName(m_typeToken)); diff --git a/AST.h b/AST.h index b21e505e9..fa1d4a92e 100644 --- a/AST.h +++ b/AST.h @@ -1134,8 +1134,8 @@ public: class Identifier: public PrimaryExpression { public: - Identifier(SourceLocation const& _location, ASTPointer const& _name, bool _isCallable): - PrimaryExpression(_location), m_name(_name), m_isCallable(_isCallable) {} + Identifier(SourceLocation const& _location, ASTPointer const& _name): + PrimaryExpression(_location), m_name(_name) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; virtual void checkTypeRequirements() override; @@ -1151,9 +1151,15 @@ public: Declaration const* getReferencedDeclaration() const { return m_referencedDeclaration; } ContractDefinition const* getCurrentContract() const { return m_currentContract; } - bool isCallable() const { return m_isCallable; } + void setOverloadedDeclarations(std::set const& _declarations) { m_overloadedDeclarations = _declarations; } + std::set getOverloadedDeclarations() const { return m_overloadedDeclarations; } + void checkTypeRequirementsWithFunctionCall(FunctionCall const& _functionCall); + void checkTypeRequirementsFromVariableDeclaration(); + + void overloadResolution(FunctionCall const& _functionCall); private: + ASTPointer m_name; /// Declaration the name refers to. @@ -1161,7 +1167,8 @@ private: /// Stores a reference to the current contract. This is needed because types of base contracts /// change depending on the context. ContractDefinition const* m_currentContract = nullptr; - bool m_isCallable = false; + /// A set of overloaded declarations, right now only FunctionDefinition has overloaded declarations. + std::set m_overloadedDeclarations; }; /** diff --git a/CompilerContext.cpp b/CompilerContext.cpp index f787db7fc..b12e01923 100644 --- a/CompilerContext.cpp +++ b/CompilerContext.cpp @@ -108,8 +108,8 @@ eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefiniti for (ASTPointer const& function: contract->getDefinedFunctions()) { if (!function->isConstructor() && - dynamic_cast(*function->getType()).getCanonicalSignature() == - dynamic_cast(*_function.getType()).getCanonicalSignature()) + dynamic_cast(*function->getType(contract)).getCanonicalSignature() == + dynamic_cast(*_function.getType(contract)).getCanonicalSignature()) return getFunctionEntryLabel(*function); } solAssert(false, "Virtual function " + _function.getName() + " not found."); diff --git a/ExpressionCompiler.cpp b/ExpressionCompiler.cpp index 3d7a25311..5e5442ba3 100644 --- a/ExpressionCompiler.cpp +++ b/ExpressionCompiler.cpp @@ -822,7 +822,11 @@ bool ExpressionCompiler::visit(IndexAccess const& _indexAccess) void ExpressionCompiler::endVisit(Identifier const& _identifier) { Declaration const* declaration = _identifier.getReferencedDeclaration(); - if (MagicVariableDeclaration const* magicVar = dynamic_cast(declaration)) + if (declaration == nullptr) + { + // no-op + } + else if (MagicVariableDeclaration const* magicVar = dynamic_cast(declaration)) { if (magicVar->getType()->getCategory() == Type::Category::Contract) // "this" or "super" diff --git a/NameAndTypeResolver.cpp b/NameAndTypeResolver.cpp index f6ee2f1d0..c787ae6b0 100644 --- a/NameAndTypeResolver.cpp +++ b/NameAndTypeResolver.cpp @@ -90,15 +90,15 @@ void NameAndTypeResolver::updateDeclaration(Declaration const& _declaration) solAssert(_declaration.getScope() == nullptr, "Updated declaration outside global scope."); } -Declaration const* NameAndTypeResolver::resolveName(ASTString const& _name, Declaration const* _scope) const +std::set NameAndTypeResolver::resolveName(ASTString const& _name, Declaration const* _scope) const { auto iterator = m_scopes.find(_scope); if (iterator == end(m_scopes)) - return nullptr; + return std::set({}); return iterator->second.resolveName(_name, false); } -Declaration const* NameAndTypeResolver::getNameFromCurrentScope(ASTString const& _name, bool _recursive) +std::set NameAndTypeResolver::getNameFromCurrentScope(ASTString const& _name, bool _recursive) { return m_currentScope->resolveName(_name, _recursive); } @@ -108,13 +108,11 @@ void NameAndTypeResolver::importInheritedScope(ContractDefinition const& _base) auto iterator = m_scopes.find(&_base); solAssert(iterator != end(m_scopes), ""); for (auto const& nameAndDeclaration: iterator->second.getDeclarations()) - { - Declaration const* declaration = nameAndDeclaration.second; - // Import if it was declared in the base, is not the constructor and is visible in derived classes - if (declaration->getScope() == &_base && declaration->getName() != _base.getName() && - declaration->isVisibleInDerivedContracts()) - m_currentScope->registerDeclaration(*declaration); - } + for (auto const& declaration: nameAndDeclaration.second) + // Import if it was declared in the base, is not the constructor and is visible in derived classes + if (declaration->getScope() == &_base && declaration->getName() != _base.getName() && + declaration->isVisibleInDerivedContracts()) + m_currentScope->registerDeclaration(*declaration); } void NameAndTypeResolver::linearizeBaseContracts(ContractDefinition& _contract) const @@ -361,24 +359,31 @@ bool ReferencesResolver::visit(Mapping&) bool ReferencesResolver::visit(UserDefinedTypeName& _typeName) { - Declaration const* declaration = m_resolver.getNameFromCurrentScope(_typeName.getName()); - if (!declaration) + auto declarations = m_resolver.getNameFromCurrentScope(_typeName.getName()); + if (declarations.empty()) BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_typeName.getLocation()) << errinfo_comment("Undeclared identifier.")); - _typeName.setReferencedDeclaration(*declaration); + else if (declarations.size() > 1) + BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_typeName.getLocation()) + << errinfo_comment("Duplicate identifier.")); + else + _typeName.setReferencedDeclaration(**declarations.begin()); return false; } bool ReferencesResolver::visit(Identifier& _identifier) { - Declaration const* declaration = m_resolver.getNameFromCurrentScope(_identifier.getName()); - if (!declaration) + auto declarations = m_resolver.getNameFromCurrentScope(_identifier.getName()); + if (declarations.empty()) BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_identifier.getLocation()) << errinfo_comment("Undeclared identifier.")); - _identifier.setReferencedDeclaration(*declaration, m_currentContract); + else if (declarations.size() == 1) + _identifier.setReferencedDeclaration(**declarations.begin(), m_currentContract); + else + // Duplicate declaration will be checked in checkTypeRequirements() + _identifier.setOverloadedDeclarations(declarations); return false; } - } } diff --git a/NameAndTypeResolver.h b/NameAndTypeResolver.h index 63b8ab637..828776179 100644 --- a/NameAndTypeResolver.h +++ b/NameAndTypeResolver.h @@ -56,11 +56,11 @@ public: /// Resolves the given @a _name inside the scope @a _scope. If @a _scope is omitted, /// the global scope is used (i.e. the one containing only the contract). /// @returns a pointer to the declaration on success or nullptr on failure. - Declaration const* resolveName(ASTString const& _name, Declaration const* _scope = nullptr) const; + std::set resolveName(ASTString const& _name, Declaration const* _scope = nullptr) const; /// Resolves a name in the "current" scope. Should only be called during the initial /// resolving phase. - Declaration const* getNameFromCurrentScope(ASTString const& _name, bool _recursive = true); + std::set getNameFromCurrentScope(ASTString const& _name, bool _recursive = true); private: void reset(); diff --git a/Parser.cpp b/Parser.cpp index cecf772da..44d111591 100644 --- a/Parser.cpp +++ b/Parser.cpp @@ -837,14 +837,9 @@ ASTPointer Parser::parsePrimaryExpression() expression = nodeFactory.createNode(token, getLiteralAndAdvance()); break; case Token::Identifier: - { nodeFactory.markEndPosition(); - // if the next token is '(', this identifier looks like function call, - // it could be a contract, event etc. - bool isCallable = m_scanner->peekNextToken() == Token::LParen; - expression = nodeFactory.createNode(getLiteralAndAdvance(), isCallable); + expression = nodeFactory.createNode(getLiteralAndAdvance()); break; - } case Token::LParen: { m_scanner->next(); diff --git a/Types.h b/Types.h index 6cef8d64a..5dd742ada 100644 --- a/Types.h +++ b/Types.h @@ -77,7 +77,7 @@ public: enum class Category { Integer, IntegerConstant, Bool, Real, Array, - String, Contract, Struct, Function, Enum, + String, Contract, Struct, Function, OverloadedFunctions, Enum, Mapping, Void, TypeType, Modifier, Magic }; @@ -524,6 +524,19 @@ private: Declaration const* m_declaration = nullptr; }; +class OverloadedFunctionType: public Type +{ +public: + explicit OverloadedFunctionType(std::set const& _overloadedDeclarations, Identifier* _identifier): + m_overloadedDeclarations(_overloadedDeclarations), m_identifier(_identifier) {} + virtual Category getCategory() const override { return Category::OverloadedFunctions; } + virtual std::string toString() const override { return "OverloadedFunctions"; } + +// private: + std::set m_overloadedDeclarations; + Identifier * m_identifier; +}; + /** * The type of a mapping, there is one distinct type per key/value type pair. */