Modifier overrides and callgraph analysis.

This commit is contained in:
Christian 2015-01-23 02:46:31 +01:00
parent 7ded95c776
commit fd5899d038
6 changed files with 79 additions and 43 deletions

View File

@ -33,7 +33,11 @@ namespace solidity
void CallGraph::addNode(ASTNode const& _node) void CallGraph::addNode(ASTNode const& _node)
{ {
_node.accept(*this); if (!m_nodesSeen.count(&_node))
{
m_workQueue.push(&_node);
m_nodesSeen.insert(&_node);
}
} }
set<FunctionDefinition const*> const& CallGraph::getCalls() set<FunctionDefinition const*> const& CallGraph::getCalls()
@ -53,20 +57,26 @@ void CallGraph::computeCallGraph()
bool CallGraph::visit(Identifier const& _identifier) bool CallGraph::visit(Identifier const& _identifier)
{ {
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()); if (auto fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()))
if (fun)
{ {
if (m_overrideResolver) if (m_functionOverrideResolver)
fun = (*m_overrideResolver)(fun->getName()); fun = (*m_functionOverrideResolver)(fun->getName());
solAssert(fun, "Error finding override for function " + fun->getName()); solAssert(fun, "Error finding override for function " + fun->getName());
addFunction(*fun); addNode(*fun);
}
if (auto modifier = dynamic_cast<ModifierDefinition const*>(_identifier.getReferencedDeclaration()))
{
if (m_modifierOverrideResolver)
modifier = (*m_modifierOverrideResolver)(modifier->getName());
solAssert(modifier, "Error finding override for modifier " + modifier->getName());
addNode(*modifier);
} }
return true; return true;
} }
bool CallGraph::visit(FunctionDefinition const& _function) bool CallGraph::visit(FunctionDefinition const& _function)
{ {
addFunction(_function); m_functionsSeen.insert(&_function);
return true; return true;
} }
@ -83,7 +93,7 @@ bool CallGraph::visit(MemberAccess const& _memberAccess)
for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions())
if (function->getName() == _memberAccess.getMemberName()) if (function->getName() == _memberAccess.getMemberName())
{ {
addFunction(*function); addNode(*function);
return true; return true;
} }
} }
@ -91,14 +101,5 @@ bool CallGraph::visit(MemberAccess const& _memberAccess)
return true; return true;
} }
void CallGraph::addFunction(FunctionDefinition const& _function)
{
if (!m_functionsSeen.count(&_function))
{
m_functionsSeen.insert(&_function);
m_workQueue.push(&_function);
}
}
} }
} }

View File

@ -39,9 +39,13 @@ namespace solidity
class CallGraph: private ASTConstVisitor class CallGraph: private ASTConstVisitor
{ {
public: public:
using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; using FunctionOverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
using ModifierOverrideResolver = std::function<ModifierDefinition const*(std::string const&)>;
CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {} CallGraph(FunctionOverrideResolver const& _functionOverrideResolver,
ModifierOverrideResolver const& _modifierOverrideResolver):
m_functionOverrideResolver(&_functionOverrideResolver),
m_modifierOverrideResolver(&_modifierOverrideResolver) {}
void addNode(ASTNode const& _node); void addNode(ASTNode const& _node);
@ -53,11 +57,12 @@ private:
virtual bool visit(MemberAccess const& _memberAccess) override; virtual bool visit(MemberAccess const& _memberAccess) override;
void computeCallGraph(); void computeCallGraph();
void addFunction(FunctionDefinition const& _function);
OverrideResolver const* m_overrideResolver; FunctionOverrideResolver const* m_functionOverrideResolver;
ModifierOverrideResolver const* m_modifierOverrideResolver;
std::set<ASTNode const*> m_nodesSeen;
std::set<FunctionDefinition const*> m_functionsSeen; std::set<FunctionDefinition const*> m_functionsSeen;
std::queue<FunctionDefinition const*> m_workQueue; std::queue<ASTNode const*> m_workQueue;
}; };
} }

View File

@ -42,9 +42,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
initializeContext(_contract, _contracts); initializeContext(_contract, _contracts);
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
{
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (!function->isConstructor()) if (!function->isConstructor())
m_context.addFunction(*function); m_context.addFunction(*function);
for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers())
m_context.addModifier(*modifier);
}
appendFunctionSelector(_contract); appendFunctionSelector(_contract);
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
@ -67,6 +71,13 @@ void Compiler::initializeContext(ContractDefinition const& _contract,
void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext)
{ {
std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts();
// Make all modifiers known to the context.
for (ContractDefinition const* contract: bases)
for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers())
m_context.addModifier(*modifier);
// arguments for base constructors, filled in derived-to-base order // arguments for base constructors, filled in derived-to-base order
map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments; map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments;
set<FunctionDefinition const*> neededFunctions; set<FunctionDefinition const*> neededFunctions;
@ -74,10 +85,8 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
// Determine the arguments that are used for the base constructors and also which functions // Determine the arguments that are used for the base constructors and also which functions
// are needed at compile time. // are needed at compile time.
std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts();
for (ContractDefinition const* contract: bases) for (ContractDefinition const* contract: bases)
{ {
//TODO include modifiers
if (FunctionDefinition const* constructor = contract->getConstructor()) if (FunctionDefinition const* constructor = contract->getConstructor())
nodesUsedInConstructors.insert(constructor); nodesUsedInConstructors.insert(constructor);
for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts()) for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts())
@ -94,7 +103,7 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
} }
} }
auto overrideResolver = [&](string const& _name) -> FunctionDefinition const* auto functionOverrideResolver = [&](string const& _name) -> FunctionDefinition const*
{ {
for (ContractDefinition const* contract: bases) for (ContractDefinition const* contract: bases)
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
@ -102,21 +111,26 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
return function.get(); return function.get();
return nullptr; return nullptr;
}; };
auto modifierOverrideResolver = [&](string const& _name) -> ModifierDefinition const*
{
return &m_context.getFunctionModifier(_name);
};
neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver); neededFunctions = getFunctionsCalled(nodesUsedInConstructors, functionOverrideResolver,
modifierOverrideResolver);
// First add all overrides (or the functions themselves if there is no override) // First add all overrides (or the functions themselves if there is no override)
for (FunctionDefinition const* fun: neededFunctions) for (FunctionDefinition const* fun: neededFunctions)
{ {
FunctionDefinition const* override = nullptr; FunctionDefinition const* override = nullptr;
if (!fun->isConstructor()) if (!fun->isConstructor())
override = overrideResolver(fun->getName()); override = functionOverrideResolver(fun->getName());
if (!!override && neededFunctions.count(override)) if (!!override && neededFunctions.count(override))
m_context.addFunction(*override); m_context.addFunction(*override);
} }
// now add the rest // now add the rest
for (FunctionDefinition const* fun: neededFunctions) for (FunctionDefinition const* fun: neededFunctions)
if (fun->isConstructor() || overrideResolver(fun->getName()) != fun) if (fun->isConstructor() || functionOverrideResolver(fun->getName()) != fun)
m_context.addFunction(*fun); m_context.addFunction(*fun);
// Call constructors in base-to-derived order. // Call constructors in base-to-derived order.
@ -176,9 +190,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
} }
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes, set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
function<FunctionDefinition const*(string const&)> const& _resolveOverrides) function<FunctionDefinition const*(string const&)> const& _resolveFunctionOverrides,
function<ModifierDefinition const*(string const&)> const& _resolveModifierOverrides)
{ {
CallGraph callgraph(_resolveOverrides); CallGraph callgraph(_resolveFunctionOverrides, _resolveModifierOverrides);
for (ASTNode const* node: _nodes) for (ASTNode const* node: _nodes)
callgraph.addNode(*node); callgraph.addNode(*node);
return callgraph.getCalls(); return callgraph.getCalls();
@ -471,25 +486,22 @@ void Compiler::appendModifierOrFunctionCode()
{ {
ASTPointer<ModifierInvocation> const& modifierInvocation = m_currentFunction->getModifiers()[m_modifierDepth]; ASTPointer<ModifierInvocation> const& modifierInvocation = m_currentFunction->getModifiers()[m_modifierDepth];
// TODO get the most derived override of the modifier ModifierDefinition const& modifier = m_context.getFunctionModifier(modifierInvocation->getName()->getName());
ModifierDefinition const* modifier = dynamic_cast<ModifierDefinition const*>( solAssert(modifier.getParameters().size() == modifierInvocation->getArguments().size(), "");
modifierInvocation->getName()->getReferencedDeclaration()); for (unsigned i = 0; i < modifier.getParameters().size(); ++i)
solAssert(!!modifier, "Modifier not found.");
solAssert(modifier->getParameters().size() == modifierInvocation->getArguments().size(), "");
for (unsigned i = 0; i < modifier->getParameters().size(); ++i)
{ {
m_context.addVariable(*modifier->getParameters()[i]); m_context.addVariable(*modifier.getParameters()[i]);
compileExpression(*modifierInvocation->getArguments()[i], compileExpression(*modifierInvocation->getArguments()[i],
modifier->getParameters()[i]->getType()); modifier.getParameters()[i]->getType());
} }
for (VariableDeclaration const* localVariable: modifier->getLocalVariables()) for (VariableDeclaration const* localVariable: modifier.getLocalVariables())
m_context.addAndInitializeVariable(*localVariable); m_context.addAndInitializeVariable(*localVariable);
unsigned const c_stackSurplus = CompilerUtils::getSizeOnStack(modifier->getParameters()) + unsigned const c_stackSurplus = CompilerUtils::getSizeOnStack(modifier.getParameters()) +
CompilerUtils::getSizeOnStack(modifier->getLocalVariables()); CompilerUtils::getSizeOnStack(modifier.getLocalVariables());
m_stackCleanupForReturn += c_stackSurplus; m_stackCleanupForReturn += c_stackSurplus;
modifier->getBody().accept(*this); modifier.getBody().accept(*this);
for (unsigned i = 0; i < c_stackSurplus; ++i) for (unsigned i = 0; i < c_stackSurplus; ++i)
m_context << eth::Instruction::POP; m_context << eth::Instruction::POP;

View File

@ -53,7 +53,8 @@ private:
/// Recursively searches the call graph and returns all functions referenced inside _nodes. /// Recursively searches the call graph and returns all functions referenced inside _nodes.
/// _resolveOverride is called to resolve virtual function overrides. /// _resolveOverride is called to resolve virtual function overrides.
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes, std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride); std::function<FunctionDefinition const*(std::string const&)> const& _resolveFunctionOverride,
std::function<ModifierDefinition const*(std::string const&)> const& _resolveModifierOverride);
void appendFunctionSelector(ContractDefinition const& _contract); void appendFunctionSelector(ContractDefinition const& _contract);
/// Creates code that unpacks the arguments for the given function, from memory if /// Creates code that unpacks the arguments for the given function, from memory if
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.

View File

@ -66,6 +66,11 @@ void CompilerContext::addFunction(FunctionDefinition const& _function)
m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag)); m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag));
} }
void CompilerContext::addModifier(ModifierDefinition const& _modifier)
{
m_functionModifiers.insert(make_pair(_modifier.getName(), &_modifier));
}
bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const
{ {
auto ret = m_compiledContracts.find(&_contract); auto ret = m_compiledContracts.find(&_contract);
@ -92,6 +97,13 @@ eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefiniti
return res->second.tag(); return res->second.tag();
} }
ModifierDefinition const& CompilerContext::getFunctionModifier(string const& _name) const
{
auto res = m_functionModifiers.find(_name);
solAssert(res != m_functionModifiers.end(), "Function modifier override not found.");
return *res->second;
}
unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const
{ {
auto res = m_localVariables.find(&_declaration); auto res = m_localVariables.find(&_declaration);

View File

@ -45,6 +45,8 @@ public:
void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0); void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0);
void addAndInitializeVariable(VariableDeclaration const& _declaration); void addAndInitializeVariable(VariableDeclaration const& _declaration);
void addFunction(FunctionDefinition const& _function); void addFunction(FunctionDefinition const& _function);
/// Adds the given modifier to the list by name if the name is not present already.
void addModifier(ModifierDefinition const& _modifier);
void setCompiledContracts(std::map<ContractDefinition const*, bytes const*> const& _contracts) { m_compiledContracts = _contracts; } void setCompiledContracts(std::map<ContractDefinition const*, bytes const*> const& _contracts) { m_compiledContracts = _contracts; }
bytes const& getCompiledContract(ContractDefinition const& _contract) const; bytes const& getCompiledContract(ContractDefinition const& _contract) const;
@ -59,6 +61,7 @@ public:
eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const; eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const;
/// @returns the entry label of the given function and takes overrides into account. /// @returns the entry label of the given function and takes overrides into account.
eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const; eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const;
ModifierDefinition const& getFunctionModifier(std::string const& _name) const;
/// Returns the distance of the given local variable from the bottom of the stack (of the current function). /// Returns the distance of the given local variable from the bottom of the stack (of the current function).
unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const;
/// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns /// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns
@ -118,6 +121,8 @@ private:
std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels; std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels;
/// Labels pointing to the entry points of function overrides. /// Labels pointing to the entry points of function overrides.
std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels; std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels;
/// Mapping to obtain function modifiers by name. Should be filled from derived to base.
std::map<std::string, ModifierDefinition const*> m_functionModifiers;
}; };
} }