Merge pull request #8533 from ethereum/refactorVirtualResolution

Refactor virtual resolution
This commit is contained in:
chriseth 2020-03-26 18:28:03 +01:00 committed by GitHub
commit 514eef92be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 157 additions and 115 deletions

View File

@ -319,6 +319,37 @@ FunctionDefinitionAnnotation& FunctionDefinition::annotation() const
return initAnnotation<FunctionDefinitionAnnotation>();
}
FunctionDefinition const& FunctionDefinition::resolveVirtual(
ContractDefinition const& _mostDerivedContract,
ContractDefinition const* _searchStart
) const
{
solAssert(!isConstructor(), "");
// If we are not doing super-lookup and the function is not virtual, we can stop here.
if (_searchStart == nullptr && !virtualSemantics())
return *this;
solAssert(!dynamic_cast<ContractDefinition const&>(*scope()).isLibrary(), "");
FunctionType const* functionType = TypeProvider::function(*this)->asCallableFunction(false);
for (ContractDefinition const* c: _mostDerivedContract.annotation().linearizedBaseContracts)
{
if (_searchStart != nullptr && c != _searchStart)
continue;
_searchStart = nullptr;
for (FunctionDefinition const* function: c->definedFunctions())
if (
function->name() == name() &&
!function->isConstructor() &&
FunctionType(*function).asCallableFunction(false)->hasEqualParameterTypes(*functionType)
)
return *function;
}
solAssert(false, "Virtual function " + name() + " not found.");
return *this; // not reached
}
TypePointer ModifierDefinition::type() const
{
return TypeProvider::modifier(*this);
@ -329,6 +360,33 @@ ModifierDefinitionAnnotation& ModifierDefinition::annotation() const
return initAnnotation<ModifierDefinitionAnnotation>();
}
ModifierDefinition const& ModifierDefinition::resolveVirtual(
ContractDefinition const& _mostDerivedContract,
ContractDefinition const* _searchStart
) const
{
solAssert(_searchStart == nullptr, "Used super in connection with modifiers.");
// If we are not doing super-lookup and the modifier is not virtual, we can stop here.
if (_searchStart == nullptr && !virtualSemantics())
return *this;
solAssert(!dynamic_cast<ContractDefinition const&>(*scope()).isLibrary(), "");
for (ContractDefinition const* c: _mostDerivedContract.annotation().linearizedBaseContracts)
{
if (_searchStart != nullptr && c != _searchStart)
continue;
_searchStart = nullptr;
for (ModifierDefinition const* modifier: c->functionModifiers())
if (modifier->name() == name())
return *modifier;
}
solAssert(false, "Virtual modifier " + name() + " not found.");
return *this; // not reached
}
TypePointer EventDefinition::type() const
{
return TypeProvider::function(*this);

View File

@ -689,6 +689,18 @@ public:
CallableDeclarationAnnotation& annotation() const override = 0;
/// Performs virtual or super function/modifier lookup:
/// If @a _searchStart is nullptr, performs virtual function lookup, i.e.
/// searches the inheritance hierarchy of @a _mostDerivedContract towards the base
/// and returns the first function/modifier definition that
/// is overwritten by this callable.
/// If @a _searchStart is non-null, starts searching only from that contract, but
/// still in the hierarchy of @a _mostDerivedContract.
virtual CallableDeclaration const& resolveVirtual(
ContractDefinition const& _mostDerivedContract,
ContractDefinition const* _searchStart = nullptr
) const = 0;
protected:
ASTPointer<ParameterList> m_parameters;
ASTPointer<OverrideSpecifier> m_overrides;
@ -799,6 +811,12 @@ public:
CallableDeclaration::virtualSemantics() ||
(annotation().contract && annotation().contract->isInterface());
}
FunctionDefinition const& resolveVirtual(
ContractDefinition const& _mostDerivedContract,
ContractDefinition const* _searchStart = nullptr
) const override;
private:
StateMutability m_stateMutability;
Token const m_kind;
@ -945,6 +963,12 @@ public:
ModifierDefinitionAnnotation& annotation() const override;
ModifierDefinition const& resolveVirtual(
ContractDefinition const& _mostDerivedContract,
ContractDefinition const* _searchStart = nullptr
) const override;
private:
ASTPointer<Block> m_body;
};
@ -1010,6 +1034,14 @@ public:
EventDefinitionAnnotation& annotation() const override;
CallableDeclaration const& resolveVirtual(
ContractDefinition const&,
ContractDefinition const*
) const override
{
solAssert(false, "Tried to resolve virtual event.");
}
private:
bool m_anonymous = false;
};

View File

@ -272,57 +272,43 @@ evmasm::AssemblyItem CompilerContext::functionEntryLabelIfExists(Declaration con
return m_functionCompilationQueue.entryLabelIfExists(_declaration);
}
FunctionDefinition const& CompilerContext::resolveVirtualFunction(FunctionDefinition const& _function)
{
// Libraries do not allow inheritance and their functions can be inlined, so we should not
// search the inheritance hierarchy (which will be the wrong one in case the function
// is inlined).
if (auto scope = dynamic_cast<ContractDefinition const*>(_function.scope()))
if (scope->isLibrary())
return _function;
solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set.");
return resolveVirtualFunction(_function, m_inheritanceHierarchy.begin());
}
FunctionDefinition const& CompilerContext::superFunction(FunctionDefinition const& _function, ContractDefinition const& _base)
{
solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set.");
return resolveVirtualFunction(_function, superContract(_base));
solAssert(m_mostDerivedContract, "No most derived contract set.");
ContractDefinition const* super = superContract(_base);
solAssert(super, "Super contract not available.");
return _function.resolveVirtual(mostDerivedContract(), super);
}
FunctionDefinition const* CompilerContext::nextConstructor(ContractDefinition const& _contract) const
{
vector<ContractDefinition const*>::const_iterator it = superContract(_contract);
for (; it != m_inheritanceHierarchy.end(); ++it)
if ((*it)->constructor())
return (*it)->constructor();
ContractDefinition const* next = superContract(_contract);
if (next == nullptr)
return nullptr;
for (ContractDefinition const* c: m_mostDerivedContract->annotation().linearizedBaseContracts)
if (next != nullptr && next != c)
continue;
else
{
next = nullptr;
if (c->constructor())
return c->constructor();
}
return nullptr;
}
ContractDefinition const& CompilerContext::mostDerivedContract() const
{
solAssert(m_mostDerivedContract, "Most derived contract not set.");
return *m_mostDerivedContract;
}
Declaration const* CompilerContext::nextFunctionToCompile() const
{
return m_functionCompilationQueue.nextFunctionToCompile();
}
ModifierDefinition const& CompilerContext::resolveVirtualFunctionModifier(
ModifierDefinition const& _modifier
) const
{
// Libraries do not allow inheritance and their functions can be inlined, so we should not
// search the inheritance hierarchy (which will be the wrong one in case the function
// is inlined).
if (auto scope = dynamic_cast<ContractDefinition const*>(_modifier.scope()))
if (scope->isLibrary())
return _modifier;
solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set.");
for (ContractDefinition const* contract: m_inheritanceHierarchy)
for (ModifierDefinition const* modifier: contract->functionModifiers())
if (modifier->name() == _modifier.name())
return *modifier;
solAssert(false, "Function modifier " + _modifier.name() + " not found in inheritance hierarchy.");
}
unsigned CompilerContext::baseStackOffsetOfVariable(Declaration const& _declaration) const
{
auto res = m_localVariables.find(&_declaration);
@ -556,32 +542,19 @@ LinkerObject const& CompilerContext::assembledObject() const
return object;
}
FunctionDefinition const& CompilerContext::resolveVirtualFunction(
FunctionDefinition const& _function,
vector<ContractDefinition const*>::const_iterator _searchStart
)
ContractDefinition const* CompilerContext::superContract(ContractDefinition const& _contract) const
{
string name = _function.name();
FunctionType functionType(_function);
auto it = _searchStart;
for (; it != m_inheritanceHierarchy.end(); ++it)
for (FunctionDefinition const* function: (*it)->definedFunctions())
if (
function->name() == name &&
!function->isConstructor() &&
FunctionType(*function).asCallableFunction(false)->hasEqualParameterTypes(functionType)
)
return *function;
solAssert(false, "Super function " + name + " not found.");
return _function; // not reached
}
vector<ContractDefinition const*>::const_iterator CompilerContext::superContract(ContractDefinition const& _contract) const
{
solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set.");
auto it = find(m_inheritanceHierarchy.begin(), m_inheritanceHierarchy.end(), &_contract);
solAssert(it != m_inheritanceHierarchy.end(), "Base not found in inheritance hierarchy.");
return ++it;
auto const& hierarchy = mostDerivedContract().annotation().linearizedBaseContracts;
auto it = find(hierarchy.begin(), hierarchy.end(), &_contract);
solAssert(it != hierarchy.end(), "Base not found in inheritance hierarchy.");
++it;
if (it == hierarchy.end())
return nullptr;
else
{
solAssert(*it != &_contract, "");
return *it;
}
}
string CompilerContext::revertReasonIfDebug(string const& _message)

View File

@ -114,15 +114,14 @@ public:
/// @returns the entry label of the given function. Might return an AssemblyItem of type
/// UndefinedItem if it does not exist yet.
evmasm::AssemblyItem functionEntryLabelIfExists(Declaration const& _declaration) const;
/// @returns the entry label of the given function and takes overrides into account.
FunctionDefinition const& resolveVirtualFunction(FunctionDefinition const& _function);
/// @returns the function that overrides the given declaration from the most derived class just
/// above _base in the current inheritance hierarchy.
FunctionDefinition const& superFunction(FunctionDefinition const& _function, ContractDefinition const& _base);
/// @returns the next constructor in the inheritance hierarchy.
FunctionDefinition const* nextConstructor(ContractDefinition const& _contract) const;
/// Sets the current inheritance hierarchy from derived to base.
void setInheritanceHierarchy(std::vector<ContractDefinition const*> const& _hierarchy) { m_inheritanceHierarchy = _hierarchy; }
/// Sets the contract currently being compiled - the most derived one.
void setMostDerivedContract(ContractDefinition const& _contract) { m_mostDerivedContract = &_contract; }
ContractDefinition const& mostDerivedContract() const;
/// @returns the next function in the queue of functions that are still to be compiled
/// (i.e. that were referenced during compilation but where we did not yet generate code for).
@ -171,7 +170,6 @@ public:
/// empty return value.
std::pair<std::string, std::set<std::string>> requestedYulFunctions();
ModifierDefinition const& resolveVirtualFunctionModifier(ModifierDefinition const& _modifier) const;
/// Returns the distance of the given local variable from the bottom of the stack (of the current function).
unsigned baseStackOffsetOfVariable(Declaration const& _declaration) const;
/// If supplied by a value returned by @ref baseStackOffsetOfVariable(variable), returns
@ -315,14 +313,8 @@ public:
RevertStrings revertStrings() const { return m_revertStrings; }
private:
/// Searches the inheritance hierarchy towards the base starting from @a _searchStart and returns
/// the first function definition that is overwritten by _function.
FunctionDefinition const& resolveVirtualFunction(
FunctionDefinition const& _function,
std::vector<ContractDefinition const*>::const_iterator _searchStart
);
/// @returns an iterator to the contract directly above the given contract.
std::vector<ContractDefinition const*>::const_iterator superContract(ContractDefinition const& _contract) const;
/// @returns a pointer to the contract directly above the given contract.
ContractDefinition const* superContract(ContractDefinition const& _contract) const;
/// Updates source location set in the assembly.
void updateSourceLocation();
@ -381,8 +373,8 @@ private:
/// modifier is applied twice, the position of the variable needs to be restored
/// after the nested modifier is left.
std::map<Declaration const*, std::vector<unsigned>> m_localVariables;
/// List of current inheritance hierarchy from derived to base.
std::vector<ContractDefinition const*> m_inheritanceHierarchy;
/// The contract currently being compiled. Virtual function lookup starts from this contarct.
ContractDefinition const* m_mostDerivedContract = nullptr;
/// Stack of current visited AST nodes, used for location attachment
std::stack<ASTNode const*> m_visitedNodes;
/// The runtime context if in Creation mode, this is used for generating tags that would be stored into the storage and then used at runtime.

View File

@ -129,7 +129,7 @@ void ContractCompiler::initializeContext(
{
m_context.setExperimentalFeatures(_contract.sourceUnit().annotation().experimentalFeatures);
m_context.setOtherCompilers(_otherCompilers);
m_context.setInheritanceHierarchy(_contract.annotation().linearizedBaseContracts);
m_context.setMostDerivedContract(_contract);
if (m_runtimeCompiler)
registerImmutableVariables(_contract);
CompilerUtils(m_context).initialiseFreeMemoryPointer();
@ -689,7 +689,7 @@ bool ContractCompiler::visit(InlineAssembly const& _inlineAssembly)
if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(decl))
{
solAssert(!ref->second.isOffset && !ref->second.isSlot, "");
functionDef = &m_context.resolveVirtualFunction(*functionDef);
functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract());
auto functionEntryLabel = m_context.functionEntryLabel(*functionDef).pushTag();
solAssert(functionEntryLabel.data() <= std::numeric_limits<size_t>::max(), "");
_assembly.appendLabelReference(size_t(functionEntryLabel.data()));
@ -1335,10 +1335,9 @@ void ContractCompiler::appendModifierOrFunctionCode()
appendModifierOrFunctionCode();
else
{
ModifierDefinition const& nonVirtualModifier = dynamic_cast<ModifierDefinition const&>(
ModifierDefinition const& modifier = dynamic_cast<ModifierDefinition const&>(
*modifierInvocation->name()->annotation().referencedDeclaration
);
ModifierDefinition const& modifier = m_context.resolveVirtualFunctionModifier(nonVirtualModifier);
).resolveVirtual(m_context.mostDerivedContract());
CompilerContext::LocationSetter locationSetter(m_context, modifier);
std::vector<ASTPointer<Expression>> const& modifierArguments =
modifierInvocation->arguments() ? *modifierInvocation->arguments() : std::vector<ASTPointer<Expression>>();

View File

@ -572,7 +572,10 @@ bool ExpressionCompiler::visit(FunctionCall const& _functionCall)
// Do not directly visit the identifier, because this way, we can avoid
// the runtime entry label to be created at the creation time context.
CompilerContext::LocationSetter locationSetter2(m_context, *identifier);
utils().pushCombinedFunctionEntryLabel(m_context.resolveVirtualFunction(*functionDef), false);
utils().pushCombinedFunctionEntryLabel(
functionDef->resolveVirtual(m_context.mostDerivedContract()),
false
);
shortcutTaken = true;
}
}
@ -1861,7 +1864,7 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier)
// we want to avoid having a reference to the runtime function entry point in the
// constructor context, since this would force the compiler to include unreferenced
// internal functions in the runtime contex.
utils().pushCombinedFunctionEntryLabel(m_context.resolveVirtualFunction(*functionDef));
utils().pushCombinedFunctionEntryLabel(functionDef->resolveVirtual(m_context.mostDerivedContract()));
else if (auto variable = dynamic_cast<VariableDeclaration const*>(declaration))
appendVariable(*variable, static_cast<Expression const&>(_identifier));
else if (auto contract = dynamic_cast<ContractDefinition const*>(declaration))

View File

@ -31,6 +31,12 @@ using namespace solidity;
using namespace solidity::util;
using namespace solidity::frontend;
ContractDefinition const& IRGenerationContext::mostDerivedContract() const
{
solAssert(m_mostDerivedContract, "Most derived contract requested but not set.");
return *m_mostDerivedContract;
}
IRVariable const& IRGenerationContext::addLocalVariable(VariableDeclaration const& _varDecl)
{
auto const& [it, didInsert] = m_localVariables.emplace(
@ -70,26 +76,9 @@ string IRGenerationContext::functionName(VariableDeclaration const& _varDecl)
return "getter_fun_" + _varDecl.name() + "_" + to_string(_varDecl.id());
}
FunctionDefinition const& IRGenerationContext::virtualFunction(FunctionDefinition const& _function)
{
// @TODO previously, we had to distinguish creation context and runtime context,
// but since we do not work with jump positions anymore, this should not be a problem, right?
string name = _function.name();
FunctionType functionType(_function);
for (auto const& contract: m_inheritanceHierarchy)
for (FunctionDefinition const* function: contract->definedFunctions())
if (
function->name() == name &&
!function->isConstructor() &&
FunctionType(*function).asCallableFunction(false)->hasEqualParameterTypes(functionType)
)
return *function;
solAssert(false, "Super function " + name + " not found.");
}
string IRGenerationContext::virtualFunctionName(FunctionDefinition const& _functionDeclaration)
{
return functionName(virtualFunction(_functionDeclaration));
return functionName(_functionDeclaration.resolveVirtual(mostDerivedContract()));
}
string IRGenerationContext::newYulVariable()
@ -120,7 +109,7 @@ string IRGenerationContext::internalDispatch(size_t _in, size_t _out)
templ("arrow", _out > 0 ? "->" : "");
templ("out", suffixedVariableNameList("out_", 0, _out));
vector<map<string, string>> functions;
for (auto const& contract: m_inheritanceHierarchy)
for (auto const& contract: mostDerivedContract().annotation().linearizedBaseContracts)
for (FunctionDefinition const* function: contract->definedFunctions())
if (
!function->isConstructor() &&

View File

@ -61,11 +61,12 @@ public:
MultiUseYulFunctionCollector& functionCollector() { return m_functions; }
/// Sets the current inheritance hierarchy from derived to base.
void setInheritanceHierarchy(std::vector<ContractDefinition const*> _hierarchy)
/// Sets the most derived contract (the one currently being compiled)>
void setMostDerivedContract(ContractDefinition const& _mostDerivedContract)
{
m_inheritanceHierarchy = std::move(_hierarchy);
m_mostDerivedContract = &_mostDerivedContract;
}
ContractDefinition const& mostDerivedContract() const;
IRVariable const& addLocalVariable(VariableDeclaration const& _varDecl);
@ -81,7 +82,6 @@ public:
std::string functionName(FunctionDefinition const& _function);
std::string functionName(VariableDeclaration const& _varDecl);
FunctionDefinition const& virtualFunction(FunctionDefinition const& _functionDeclaration);
std::string virtualFunctionName(FunctionDefinition const& _functionDeclaration);
std::string newYulVariable();
@ -103,7 +103,7 @@ private:
langutil::EVMVersion m_evmVersion;
RevertStrings m_revertStrings;
OptimiserSettings m_optimiserSettings;
std::vector<ContractDefinition const*> m_inheritanceHierarchy;
ContractDefinition const* m_mostDerivedContract = nullptr;
std::map<VariableDeclaration const*, IRVariable> m_localVariables;
/// Storage offsets of state variables
std::map<VariableDeclaration const*, std::pair<u256, unsigned>> m_stateVariables;

View File

@ -110,7 +110,7 @@ string IRGenerator::generate(ContractDefinition const& _contract)
t("functions", m_context.functionCollector().requestedFunctions());
resetContext(_contract);
m_context.setInheritanceHierarchy(_contract.annotation().linearizedBaseContracts);
m_context.setMostDerivedContract(_contract);
t("RuntimeObject", runtimeObjectName(_contract));
t("dispatch", dispatchRoutine(_contract));
for (auto const* contract: _contract.annotation().linearizedBaseContracts)
@ -389,7 +389,7 @@ void IRGenerator::resetContext(ContractDefinition const& _contract)
);
m_context = IRGenerationContext(m_evmVersion, m_context.revertStrings(), m_optimiserSettings);
m_context.setInheritanceHierarchy(_contract.annotation().linearizedBaseContracts);
m_context.setMostDerivedContract(_contract);
for (auto const& var: ContractType(_contract).stateVariables())
m_context.addStateVariable(*get<0>(var), get<1>(var), get<2>(var));
}

View File

@ -1165,7 +1165,7 @@ void IRGeneratorForStatements::endVisit(Identifier const& _identifier)
return;
}
else if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration))
define(_identifier) << to_string(m_context.virtualFunction(*functionDef).id()) << "\n";
define(_identifier) << to_string(functionDef->resolveVirtual(m_context.mostDerivedContract()).id()) << "\n";
else if (VariableDeclaration const* varDecl = dynamic_cast<VariableDeclaration const*>(declaration))
{
// TODO for the constant case, we have to be careful:

View File

@ -119,13 +119,9 @@ bytes compileFirstExpression(
NameAndTypeResolver resolver(globalContext, solidity::test::CommonOptions::get().evmVersion(), scopes, errorReporter);
resolver.registerDeclarations(*sourceUnit);
vector<ContractDefinition const*> inheritanceHierarchy;
for (ASTPointer<ASTNode> const& node: sourceUnit->nodes())
if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get()))
{
BOOST_REQUIRE_MESSAGE(resolver.resolveNamesAndTypes(*contract), "Resolving names failed");
inheritanceHierarchy = vector<ContractDefinition const*>(1, contract);
}
for (ASTPointer<ASTNode> const& node: sourceUnit->nodes())
if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get()))
{
@ -144,7 +140,7 @@ bytes compileFirstExpression(
RevertStrings::Default
);
context.resetVisitedNodes(contract);
context.setInheritanceHierarchy(inheritanceHierarchy);
context.setMostDerivedContract(*contract);
unsigned parametersSize = _localVariables.size(); // assume they are all one slot on the stack
context.adjustStackOffset(parametersSize);
for (vector<string> const& variable: _localVariables)