diff --git a/libsolidity/ast/AST.cpp b/libsolidity/ast/AST.cpp index a76e6e274..b711758d6 100644 --- a/libsolidity/ast/AST.cpp +++ b/libsolidity/ast/AST.cpp @@ -319,6 +319,37 @@ FunctionDefinitionAnnotation& FunctionDefinition::annotation() const return initAnnotation(); } +FunctionDefinition const& FunctionDefinition::resolveVirtual( + ContractDefinition const& _mostDerivedContract, + ContractDefinition const* _searchStart +) const +{ + solAssert(!isConstructor(), ""); + // If we are not doing super-lookup and the function is not virtual, we can stop here. + if (_searchStart == nullptr && !virtualSemantics()) + return *this; + + solAssert(!dynamic_cast(*scope()).isLibrary(), ""); + + FunctionType const* functionType = TypeProvider::function(*this)->asCallableFunction(false); + + for (ContractDefinition const* c: _mostDerivedContract.annotation().linearizedBaseContracts) + { + if (_searchStart != nullptr && c != _searchStart) + continue; + _searchStart = nullptr; + for (FunctionDefinition const* function: c->definedFunctions()) + if ( + function->name() == name() && + !function->isConstructor() && + FunctionType(*function).asCallableFunction(false)->hasEqualParameterTypes(*functionType) + ) + return *function; + } + solAssert(false, "Virtual function " + name() + " not found."); + return *this; // not reached +} + TypePointer ModifierDefinition::type() const { return TypeProvider::modifier(*this); @@ -329,6 +360,33 @@ ModifierDefinitionAnnotation& ModifierDefinition::annotation() const return initAnnotation(); } +ModifierDefinition const& ModifierDefinition::resolveVirtual( + ContractDefinition const& _mostDerivedContract, + ContractDefinition const* _searchStart +) const +{ + solAssert(_searchStart == nullptr, "Used super in connection with modifiers."); + + // If we are not doing super-lookup and the modifier is not virtual, we can stop here. + if (_searchStart == nullptr && !virtualSemantics()) + return *this; + + solAssert(!dynamic_cast(*scope()).isLibrary(), ""); + + for (ContractDefinition const* c: _mostDerivedContract.annotation().linearizedBaseContracts) + { + if (_searchStart != nullptr && c != _searchStart) + continue; + _searchStart = nullptr; + for (ModifierDefinition const* modifier: c->functionModifiers()) + if (modifier->name() == name()) + return *modifier; + } + solAssert(false, "Virtual modifier " + name() + " not found."); + return *this; // not reached +} + + TypePointer EventDefinition::type() const { return TypeProvider::function(*this); diff --git a/libsolidity/ast/AST.h b/libsolidity/ast/AST.h index 052970bc8..4bbacbdf3 100644 --- a/libsolidity/ast/AST.h +++ b/libsolidity/ast/AST.h @@ -689,6 +689,18 @@ public: CallableDeclarationAnnotation& annotation() const override = 0; + /// Performs virtual or super function/modifier lookup: + /// If @a _searchStart is nullptr, performs virtual function lookup, i.e. + /// searches the inheritance hierarchy of @a _mostDerivedContract towards the base + /// and returns the first function/modifier definition that + /// is overwritten by this callable. + /// If @a _searchStart is non-null, starts searching only from that contract, but + /// still in the hierarchy of @a _mostDerivedContract. + virtual CallableDeclaration const& resolveVirtual( + ContractDefinition const& _mostDerivedContract, + ContractDefinition const* _searchStart = nullptr + ) const = 0; + protected: ASTPointer m_parameters; ASTPointer m_overrides; @@ -799,6 +811,12 @@ public: CallableDeclaration::virtualSemantics() || (annotation().contract && annotation().contract->isInterface()); } + + FunctionDefinition const& resolveVirtual( + ContractDefinition const& _mostDerivedContract, + ContractDefinition const* _searchStart = nullptr + ) const override; + private: StateMutability m_stateMutability; Token const m_kind; @@ -945,6 +963,12 @@ public: ModifierDefinitionAnnotation& annotation() const override; + ModifierDefinition const& resolveVirtual( + ContractDefinition const& _mostDerivedContract, + ContractDefinition const* _searchStart = nullptr + ) const override; + + private: ASTPointer m_body; }; @@ -1010,6 +1034,14 @@ public: EventDefinitionAnnotation& annotation() const override; + CallableDeclaration const& resolveVirtual( + ContractDefinition const&, + ContractDefinition const* + ) const override + { + solAssert(false, "Tried to resolve virtual event."); + } + private: bool m_anonymous = false; }; diff --git a/libsolidity/codegen/CompilerContext.cpp b/libsolidity/codegen/CompilerContext.cpp index 26cff59ec..6984a030b 100644 --- a/libsolidity/codegen/CompilerContext.cpp +++ b/libsolidity/codegen/CompilerContext.cpp @@ -272,57 +272,43 @@ evmasm::AssemblyItem CompilerContext::functionEntryLabelIfExists(Declaration con return m_functionCompilationQueue.entryLabelIfExists(_declaration); } -FunctionDefinition const& CompilerContext::resolveVirtualFunction(FunctionDefinition const& _function) -{ - // Libraries do not allow inheritance and their functions can be inlined, so we should not - // search the inheritance hierarchy (which will be the wrong one in case the function - // is inlined). - if (auto scope = dynamic_cast(_function.scope())) - if (scope->isLibrary()) - return _function; - solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); - return resolveVirtualFunction(_function, m_inheritanceHierarchy.begin()); -} - FunctionDefinition const& CompilerContext::superFunction(FunctionDefinition const& _function, ContractDefinition const& _base) { - solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); - return resolveVirtualFunction(_function, superContract(_base)); + solAssert(m_mostDerivedContract, "No most derived contract set."); + ContractDefinition const* super = superContract(_base); + solAssert(super, "Super contract not available."); + return _function.resolveVirtual(mostDerivedContract(), super); } FunctionDefinition const* CompilerContext::nextConstructor(ContractDefinition const& _contract) const { - vector::const_iterator it = superContract(_contract); - for (; it != m_inheritanceHierarchy.end(); ++it) - if ((*it)->constructor()) - return (*it)->constructor(); + ContractDefinition const* next = superContract(_contract); + if (next == nullptr) + return nullptr; + for (ContractDefinition const* c: m_mostDerivedContract->annotation().linearizedBaseContracts) + if (next != nullptr && next != c) + continue; + else + { + next = nullptr; + if (c->constructor()) + return c->constructor(); + } return nullptr; } +ContractDefinition const& CompilerContext::mostDerivedContract() const +{ + solAssert(m_mostDerivedContract, "Most derived contract not set."); + return *m_mostDerivedContract; +} + Declaration const* CompilerContext::nextFunctionToCompile() const { return m_functionCompilationQueue.nextFunctionToCompile(); } -ModifierDefinition const& CompilerContext::resolveVirtualFunctionModifier( - ModifierDefinition const& _modifier -) const -{ - // Libraries do not allow inheritance and their functions can be inlined, so we should not - // search the inheritance hierarchy (which will be the wrong one in case the function - // is inlined). - if (auto scope = dynamic_cast(_modifier.scope())) - if (scope->isLibrary()) - return _modifier; - solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); - for (ContractDefinition const* contract: m_inheritanceHierarchy) - for (ModifierDefinition const* modifier: contract->functionModifiers()) - if (modifier->name() == _modifier.name()) - return *modifier; - solAssert(false, "Function modifier " + _modifier.name() + " not found in inheritance hierarchy."); -} - unsigned CompilerContext::baseStackOffsetOfVariable(Declaration const& _declaration) const { auto res = m_localVariables.find(&_declaration); @@ -556,32 +542,19 @@ LinkerObject const& CompilerContext::assembledObject() const return object; } -FunctionDefinition const& CompilerContext::resolveVirtualFunction( - FunctionDefinition const& _function, - vector::const_iterator _searchStart -) +ContractDefinition const* CompilerContext::superContract(ContractDefinition const& _contract) const { - string name = _function.name(); - FunctionType functionType(_function); - auto it = _searchStart; - for (; it != m_inheritanceHierarchy.end(); ++it) - for (FunctionDefinition const* function: (*it)->definedFunctions()) - if ( - function->name() == name && - !function->isConstructor() && - FunctionType(*function).asCallableFunction(false)->hasEqualParameterTypes(functionType) - ) - return *function; - solAssert(false, "Super function " + name + " not found."); - return _function; // not reached -} - -vector::const_iterator CompilerContext::superContract(ContractDefinition const& _contract) const -{ - solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); - auto it = find(m_inheritanceHierarchy.begin(), m_inheritanceHierarchy.end(), &_contract); - solAssert(it != m_inheritanceHierarchy.end(), "Base not found in inheritance hierarchy."); - return ++it; + auto const& hierarchy = mostDerivedContract().annotation().linearizedBaseContracts; + auto it = find(hierarchy.begin(), hierarchy.end(), &_contract); + solAssert(it != hierarchy.end(), "Base not found in inheritance hierarchy."); + ++it; + if (it == hierarchy.end()) + return nullptr; + else + { + solAssert(*it != &_contract, ""); + return *it; + } } string CompilerContext::revertReasonIfDebug(string const& _message) diff --git a/libsolidity/codegen/CompilerContext.h b/libsolidity/codegen/CompilerContext.h index ec846af83..a8e494435 100644 --- a/libsolidity/codegen/CompilerContext.h +++ b/libsolidity/codegen/CompilerContext.h @@ -114,15 +114,14 @@ public: /// @returns the entry label of the given function. Might return an AssemblyItem of type /// UndefinedItem if it does not exist yet. evmasm::AssemblyItem functionEntryLabelIfExists(Declaration const& _declaration) const; - /// @returns the entry label of the given function and takes overrides into account. - FunctionDefinition const& resolveVirtualFunction(FunctionDefinition const& _function); /// @returns the function that overrides the given declaration from the most derived class just /// above _base in the current inheritance hierarchy. FunctionDefinition const& superFunction(FunctionDefinition const& _function, ContractDefinition const& _base); /// @returns the next constructor in the inheritance hierarchy. FunctionDefinition const* nextConstructor(ContractDefinition const& _contract) const; - /// Sets the current inheritance hierarchy from derived to base. - void setInheritanceHierarchy(std::vector const& _hierarchy) { m_inheritanceHierarchy = _hierarchy; } + /// Sets the contract currently being compiled - the most derived one. + void setMostDerivedContract(ContractDefinition const& _contract) { m_mostDerivedContract = &_contract; } + ContractDefinition const& mostDerivedContract() const; /// @returns the next function in the queue of functions that are still to be compiled /// (i.e. that were referenced during compilation but where we did not yet generate code for). @@ -171,7 +170,6 @@ public: /// empty return value. std::pair> requestedYulFunctions(); - ModifierDefinition const& resolveVirtualFunctionModifier(ModifierDefinition const& _modifier) const; /// Returns the distance of the given local variable from the bottom of the stack (of the current function). unsigned baseStackOffsetOfVariable(Declaration const& _declaration) const; /// If supplied by a value returned by @ref baseStackOffsetOfVariable(variable), returns @@ -315,14 +313,8 @@ public: RevertStrings revertStrings() const { return m_revertStrings; } private: - /// Searches the inheritance hierarchy towards the base starting from @a _searchStart and returns - /// the first function definition that is overwritten by _function. - FunctionDefinition const& resolveVirtualFunction( - FunctionDefinition const& _function, - std::vector::const_iterator _searchStart - ); - /// @returns an iterator to the contract directly above the given contract. - std::vector::const_iterator superContract(ContractDefinition const& _contract) const; + /// @returns a pointer to the contract directly above the given contract. + ContractDefinition const* superContract(ContractDefinition const& _contract) const; /// Updates source location set in the assembly. void updateSourceLocation(); @@ -381,8 +373,8 @@ private: /// modifier is applied twice, the position of the variable needs to be restored /// after the nested modifier is left. std::map> m_localVariables; - /// List of current inheritance hierarchy from derived to base. - std::vector m_inheritanceHierarchy; + /// The contract currently being compiled. Virtual function lookup starts from this contarct. + ContractDefinition const* m_mostDerivedContract = nullptr; /// Stack of current visited AST nodes, used for location attachment std::stack m_visitedNodes; /// The runtime context if in Creation mode, this is used for generating tags that would be stored into the storage and then used at runtime. diff --git a/libsolidity/codegen/ContractCompiler.cpp b/libsolidity/codegen/ContractCompiler.cpp index d5f5ffe33..50fb030d2 100644 --- a/libsolidity/codegen/ContractCompiler.cpp +++ b/libsolidity/codegen/ContractCompiler.cpp @@ -129,7 +129,7 @@ void ContractCompiler::initializeContext( { m_context.setExperimentalFeatures(_contract.sourceUnit().annotation().experimentalFeatures); m_context.setOtherCompilers(_otherCompilers); - m_context.setInheritanceHierarchy(_contract.annotation().linearizedBaseContracts); + m_context.setMostDerivedContract(_contract); if (m_runtimeCompiler) registerImmutableVariables(_contract); CompilerUtils(m_context).initialiseFreeMemoryPointer(); @@ -689,7 +689,7 @@ bool ContractCompiler::visit(InlineAssembly const& _inlineAssembly) if (FunctionDefinition const* functionDef = dynamic_cast(decl)) { solAssert(!ref->second.isOffset && !ref->second.isSlot, ""); - functionDef = &m_context.resolveVirtualFunction(*functionDef); + functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract()); auto functionEntryLabel = m_context.functionEntryLabel(*functionDef).pushTag(); solAssert(functionEntryLabel.data() <= std::numeric_limits::max(), ""); _assembly.appendLabelReference(size_t(functionEntryLabel.data())); @@ -1335,10 +1335,9 @@ void ContractCompiler::appendModifierOrFunctionCode() appendModifierOrFunctionCode(); else { - ModifierDefinition const& nonVirtualModifier = dynamic_cast( + ModifierDefinition const& modifier = dynamic_cast( *modifierInvocation->name()->annotation().referencedDeclaration - ); - ModifierDefinition const& modifier = m_context.resolveVirtualFunctionModifier(nonVirtualModifier); + ).resolveVirtual(m_context.mostDerivedContract()); CompilerContext::LocationSetter locationSetter(m_context, modifier); std::vector> const& modifierArguments = modifierInvocation->arguments() ? *modifierInvocation->arguments() : std::vector>(); diff --git a/libsolidity/codegen/ExpressionCompiler.cpp b/libsolidity/codegen/ExpressionCompiler.cpp index c71d067c4..58992426b 100644 --- a/libsolidity/codegen/ExpressionCompiler.cpp +++ b/libsolidity/codegen/ExpressionCompiler.cpp @@ -572,7 +572,10 @@ bool ExpressionCompiler::visit(FunctionCall const& _functionCall) // Do not directly visit the identifier, because this way, we can avoid // the runtime entry label to be created at the creation time context. CompilerContext::LocationSetter locationSetter2(m_context, *identifier); - utils().pushCombinedFunctionEntryLabel(m_context.resolveVirtualFunction(*functionDef), false); + utils().pushCombinedFunctionEntryLabel( + functionDef->resolveVirtual(m_context.mostDerivedContract()), + false + ); shortcutTaken = true; } } @@ -1861,7 +1864,7 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier) // we want to avoid having a reference to the runtime function entry point in the // constructor context, since this would force the compiler to include unreferenced // internal functions in the runtime contex. - utils().pushCombinedFunctionEntryLabel(m_context.resolveVirtualFunction(*functionDef)); + utils().pushCombinedFunctionEntryLabel(functionDef->resolveVirtual(m_context.mostDerivedContract())); else if (auto variable = dynamic_cast(declaration)) appendVariable(*variable, static_cast(_identifier)); else if (auto contract = dynamic_cast(declaration)) diff --git a/libsolidity/codegen/ir/IRGenerationContext.cpp b/libsolidity/codegen/ir/IRGenerationContext.cpp index 7184247f4..288da5937 100644 --- a/libsolidity/codegen/ir/IRGenerationContext.cpp +++ b/libsolidity/codegen/ir/IRGenerationContext.cpp @@ -31,6 +31,12 @@ using namespace solidity; using namespace solidity::util; using namespace solidity::frontend; +ContractDefinition const& IRGenerationContext::mostDerivedContract() const +{ + solAssert(m_mostDerivedContract, "Most derived contract requested but not set."); + return *m_mostDerivedContract; +} + IRVariable const& IRGenerationContext::addLocalVariable(VariableDeclaration const& _varDecl) { auto const& [it, didInsert] = m_localVariables.emplace( @@ -70,26 +76,9 @@ string IRGenerationContext::functionName(VariableDeclaration const& _varDecl) return "getter_fun_" + _varDecl.name() + "_" + to_string(_varDecl.id()); } -FunctionDefinition const& IRGenerationContext::virtualFunction(FunctionDefinition const& _function) -{ - // @TODO previously, we had to distinguish creation context and runtime context, - // but since we do not work with jump positions anymore, this should not be a problem, right? - string name = _function.name(); - FunctionType functionType(_function); - for (auto const& contract: m_inheritanceHierarchy) - for (FunctionDefinition const* function: contract->definedFunctions()) - if ( - function->name() == name && - !function->isConstructor() && - FunctionType(*function).asCallableFunction(false)->hasEqualParameterTypes(functionType) - ) - return *function; - solAssert(false, "Super function " + name + " not found."); -} - string IRGenerationContext::virtualFunctionName(FunctionDefinition const& _functionDeclaration) { - return functionName(virtualFunction(_functionDeclaration)); + return functionName(_functionDeclaration.resolveVirtual(mostDerivedContract())); } string IRGenerationContext::newYulVariable() @@ -120,7 +109,7 @@ string IRGenerationContext::internalDispatch(size_t _in, size_t _out) templ("arrow", _out > 0 ? "->" : ""); templ("out", suffixedVariableNameList("out_", 0, _out)); vector> functions; - for (auto const& contract: m_inheritanceHierarchy) + for (auto const& contract: mostDerivedContract().annotation().linearizedBaseContracts) for (FunctionDefinition const* function: contract->definedFunctions()) if ( !function->isConstructor() && diff --git a/libsolidity/codegen/ir/IRGenerationContext.h b/libsolidity/codegen/ir/IRGenerationContext.h index 473b62482..b0fc92cdb 100644 --- a/libsolidity/codegen/ir/IRGenerationContext.h +++ b/libsolidity/codegen/ir/IRGenerationContext.h @@ -61,11 +61,12 @@ public: MultiUseYulFunctionCollector& functionCollector() { return m_functions; } - /// Sets the current inheritance hierarchy from derived to base. - void setInheritanceHierarchy(std::vector _hierarchy) + /// Sets the most derived contract (the one currently being compiled)> + void setMostDerivedContract(ContractDefinition const& _mostDerivedContract) { - m_inheritanceHierarchy = std::move(_hierarchy); + m_mostDerivedContract = &_mostDerivedContract; } + ContractDefinition const& mostDerivedContract() const; IRVariable const& addLocalVariable(VariableDeclaration const& _varDecl); @@ -81,7 +82,6 @@ public: std::string functionName(FunctionDefinition const& _function); std::string functionName(VariableDeclaration const& _varDecl); - FunctionDefinition const& virtualFunction(FunctionDefinition const& _functionDeclaration); std::string virtualFunctionName(FunctionDefinition const& _functionDeclaration); std::string newYulVariable(); @@ -103,7 +103,7 @@ private: langutil::EVMVersion m_evmVersion; RevertStrings m_revertStrings; OptimiserSettings m_optimiserSettings; - std::vector m_inheritanceHierarchy; + ContractDefinition const* m_mostDerivedContract = nullptr; std::map m_localVariables; /// Storage offsets of state variables std::map> m_stateVariables; diff --git a/libsolidity/codegen/ir/IRGenerator.cpp b/libsolidity/codegen/ir/IRGenerator.cpp index 9c7d9c8ac..4ca00befb 100644 --- a/libsolidity/codegen/ir/IRGenerator.cpp +++ b/libsolidity/codegen/ir/IRGenerator.cpp @@ -110,7 +110,7 @@ string IRGenerator::generate(ContractDefinition const& _contract) t("functions", m_context.functionCollector().requestedFunctions()); resetContext(_contract); - m_context.setInheritanceHierarchy(_contract.annotation().linearizedBaseContracts); + m_context.setMostDerivedContract(_contract); t("RuntimeObject", runtimeObjectName(_contract)); t("dispatch", dispatchRoutine(_contract)); for (auto const* contract: _contract.annotation().linearizedBaseContracts) @@ -389,7 +389,7 @@ void IRGenerator::resetContext(ContractDefinition const& _contract) ); m_context = IRGenerationContext(m_evmVersion, m_context.revertStrings(), m_optimiserSettings); - m_context.setInheritanceHierarchy(_contract.annotation().linearizedBaseContracts); + m_context.setMostDerivedContract(_contract); for (auto const& var: ContractType(_contract).stateVariables()) m_context.addStateVariable(*get<0>(var), get<1>(var), get<2>(var)); } diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index 1d40c35b1..b5d9589c0 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -1165,7 +1165,7 @@ void IRGeneratorForStatements::endVisit(Identifier const& _identifier) return; } else if (FunctionDefinition const* functionDef = dynamic_cast(declaration)) - define(_identifier) << to_string(m_context.virtualFunction(*functionDef).id()) << "\n"; + define(_identifier) << to_string(functionDef->resolveVirtual(m_context.mostDerivedContract()).id()) << "\n"; else if (VariableDeclaration const* varDecl = dynamic_cast(declaration)) { // TODO for the constant case, we have to be careful: diff --git a/test/libsolidity/SolidityExpressionCompiler.cpp b/test/libsolidity/SolidityExpressionCompiler.cpp index cb0066bee..cdaee104b 100644 --- a/test/libsolidity/SolidityExpressionCompiler.cpp +++ b/test/libsolidity/SolidityExpressionCompiler.cpp @@ -119,13 +119,9 @@ bytes compileFirstExpression( NameAndTypeResolver resolver(globalContext, solidity::test::CommonOptions::get().evmVersion(), scopes, errorReporter); resolver.registerDeclarations(*sourceUnit); - vector inheritanceHierarchy; for (ASTPointer const& node: sourceUnit->nodes()) if (ContractDefinition* contract = dynamic_cast(node.get())) - { BOOST_REQUIRE_MESSAGE(resolver.resolveNamesAndTypes(*contract), "Resolving names failed"); - inheritanceHierarchy = vector(1, contract); - } for (ASTPointer const& node: sourceUnit->nodes()) if (ContractDefinition* contract = dynamic_cast(node.get())) { @@ -144,7 +140,7 @@ bytes compileFirstExpression( RevertStrings::Default ); context.resetVisitedNodes(contract); - context.setInheritanceHierarchy(inheritanceHierarchy); + context.setMostDerivedContract(*contract); unsigned parametersSize = _localVariables.size(); // assume they are all one slot on the stack context.adjustStackOffset(parametersSize); for (vector const& variable: _localVariables)