mirror of
https://github.com/ethereum/solidity
synced 2023-10-03 13:03:40 +00:00
Modifier overrides and callgraph analysis.
This commit is contained in:
parent
7ded95c776
commit
fd5899d038
@ -33,7 +33,11 @@ namespace solidity
|
|||||||
|
|
||||||
void CallGraph::addNode(ASTNode const& _node)
|
void CallGraph::addNode(ASTNode const& _node)
|
||||||
{
|
{
|
||||||
_node.accept(*this);
|
if (!m_nodesSeen.count(&_node))
|
||||||
|
{
|
||||||
|
m_workQueue.push(&_node);
|
||||||
|
m_nodesSeen.insert(&_node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
set<FunctionDefinition const*> const& CallGraph::getCalls()
|
set<FunctionDefinition const*> const& CallGraph::getCalls()
|
||||||
@ -53,20 +57,26 @@ void CallGraph::computeCallGraph()
|
|||||||
|
|
||||||
bool CallGraph::visit(Identifier const& _identifier)
|
bool CallGraph::visit(Identifier const& _identifier)
|
||||||
{
|
{
|
||||||
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration());
|
if (auto fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()))
|
||||||
if (fun)
|
|
||||||
{
|
{
|
||||||
if (m_overrideResolver)
|
if (m_functionOverrideResolver)
|
||||||
fun = (*m_overrideResolver)(fun->getName());
|
fun = (*m_functionOverrideResolver)(fun->getName());
|
||||||
solAssert(fun, "Error finding override for function " + fun->getName());
|
solAssert(fun, "Error finding override for function " + fun->getName());
|
||||||
addFunction(*fun);
|
addNode(*fun);
|
||||||
|
}
|
||||||
|
if (auto modifier = dynamic_cast<ModifierDefinition const*>(_identifier.getReferencedDeclaration()))
|
||||||
|
{
|
||||||
|
if (m_modifierOverrideResolver)
|
||||||
|
modifier = (*m_modifierOverrideResolver)(modifier->getName());
|
||||||
|
solAssert(modifier, "Error finding override for modifier " + modifier->getName());
|
||||||
|
addNode(*modifier);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CallGraph::visit(FunctionDefinition const& _function)
|
bool CallGraph::visit(FunctionDefinition const& _function)
|
||||||
{
|
{
|
||||||
addFunction(_function);
|
m_functionsSeen.insert(&_function);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,7 +93,7 @@ bool CallGraph::visit(MemberAccess const& _memberAccess)
|
|||||||
for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions())
|
||||||
if (function->getName() == _memberAccess.getMemberName())
|
if (function->getName() == _memberAccess.getMemberName())
|
||||||
{
|
{
|
||||||
addFunction(*function);
|
addNode(*function);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -91,14 +101,5 @@ bool CallGraph::visit(MemberAccess const& _memberAccess)
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CallGraph::addFunction(FunctionDefinition const& _function)
|
|
||||||
{
|
|
||||||
if (!m_functionsSeen.count(&_function))
|
|
||||||
{
|
|
||||||
m_functionsSeen.insert(&_function);
|
|
||||||
m_workQueue.push(&_function);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
15
CallGraph.h
15
CallGraph.h
@ -39,9 +39,13 @@ namespace solidity
|
|||||||
class CallGraph: private ASTConstVisitor
|
class CallGraph: private ASTConstVisitor
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
|
using FunctionOverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
|
||||||
|
using ModifierOverrideResolver = std::function<ModifierDefinition const*(std::string const&)>;
|
||||||
|
|
||||||
CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {}
|
CallGraph(FunctionOverrideResolver const& _functionOverrideResolver,
|
||||||
|
ModifierOverrideResolver const& _modifierOverrideResolver):
|
||||||
|
m_functionOverrideResolver(&_functionOverrideResolver),
|
||||||
|
m_modifierOverrideResolver(&_modifierOverrideResolver) {}
|
||||||
|
|
||||||
void addNode(ASTNode const& _node);
|
void addNode(ASTNode const& _node);
|
||||||
|
|
||||||
@ -53,11 +57,12 @@ private:
|
|||||||
virtual bool visit(MemberAccess const& _memberAccess) override;
|
virtual bool visit(MemberAccess const& _memberAccess) override;
|
||||||
|
|
||||||
void computeCallGraph();
|
void computeCallGraph();
|
||||||
void addFunction(FunctionDefinition const& _function);
|
|
||||||
|
|
||||||
OverrideResolver const* m_overrideResolver;
|
FunctionOverrideResolver const* m_functionOverrideResolver;
|
||||||
|
ModifierOverrideResolver const* m_modifierOverrideResolver;
|
||||||
|
std::set<ASTNode const*> m_nodesSeen;
|
||||||
std::set<FunctionDefinition const*> m_functionsSeen;
|
std::set<FunctionDefinition const*> m_functionsSeen;
|
||||||
std::queue<FunctionDefinition const*> m_workQueue;
|
std::queue<ASTNode const*> m_workQueue;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
52
Compiler.cpp
52
Compiler.cpp
@ -42,9 +42,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
|
|||||||
initializeContext(_contract, _contracts);
|
initializeContext(_contract, _contracts);
|
||||||
|
|
||||||
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
||||||
|
{
|
||||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||||
if (!function->isConstructor())
|
if (!function->isConstructor())
|
||||||
m_context.addFunction(*function);
|
m_context.addFunction(*function);
|
||||||
|
for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers())
|
||||||
|
m_context.addModifier(*modifier);
|
||||||
|
}
|
||||||
|
|
||||||
appendFunctionSelector(_contract);
|
appendFunctionSelector(_contract);
|
||||||
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
||||||
@ -67,6 +71,13 @@ void Compiler::initializeContext(ContractDefinition const& _contract,
|
|||||||
|
|
||||||
void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext)
|
void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext)
|
||||||
{
|
{
|
||||||
|
std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts();
|
||||||
|
|
||||||
|
// Make all modifiers known to the context.
|
||||||
|
for (ContractDefinition const* contract: bases)
|
||||||
|
for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers())
|
||||||
|
m_context.addModifier(*modifier);
|
||||||
|
|
||||||
// arguments for base constructors, filled in derived-to-base order
|
// arguments for base constructors, filled in derived-to-base order
|
||||||
map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments;
|
map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments;
|
||||||
set<FunctionDefinition const*> neededFunctions;
|
set<FunctionDefinition const*> neededFunctions;
|
||||||
@ -74,10 +85,8 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
|
|||||||
|
|
||||||
// Determine the arguments that are used for the base constructors and also which functions
|
// Determine the arguments that are used for the base constructors and also which functions
|
||||||
// are needed at compile time.
|
// are needed at compile time.
|
||||||
std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts();
|
|
||||||
for (ContractDefinition const* contract: bases)
|
for (ContractDefinition const* contract: bases)
|
||||||
{
|
{
|
||||||
//TODO include modifiers
|
|
||||||
if (FunctionDefinition const* constructor = contract->getConstructor())
|
if (FunctionDefinition const* constructor = contract->getConstructor())
|
||||||
nodesUsedInConstructors.insert(constructor);
|
nodesUsedInConstructors.insert(constructor);
|
||||||
for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts())
|
for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts())
|
||||||
@ -94,7 +103,7 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto overrideResolver = [&](string const& _name) -> FunctionDefinition const*
|
auto functionOverrideResolver = [&](string const& _name) -> FunctionDefinition const*
|
||||||
{
|
{
|
||||||
for (ContractDefinition const* contract: bases)
|
for (ContractDefinition const* contract: bases)
|
||||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||||
@ -102,21 +111,26 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
|
|||||||
return function.get();
|
return function.get();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
};
|
};
|
||||||
|
auto modifierOverrideResolver = [&](string const& _name) -> ModifierDefinition const*
|
||||||
|
{
|
||||||
|
return &m_context.getFunctionModifier(_name);
|
||||||
|
};
|
||||||
|
|
||||||
neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver);
|
neededFunctions = getFunctionsCalled(nodesUsedInConstructors, functionOverrideResolver,
|
||||||
|
modifierOverrideResolver);
|
||||||
|
|
||||||
// First add all overrides (or the functions themselves if there is no override)
|
// First add all overrides (or the functions themselves if there is no override)
|
||||||
for (FunctionDefinition const* fun: neededFunctions)
|
for (FunctionDefinition const* fun: neededFunctions)
|
||||||
{
|
{
|
||||||
FunctionDefinition const* override = nullptr;
|
FunctionDefinition const* override = nullptr;
|
||||||
if (!fun->isConstructor())
|
if (!fun->isConstructor())
|
||||||
override = overrideResolver(fun->getName());
|
override = functionOverrideResolver(fun->getName());
|
||||||
if (!!override && neededFunctions.count(override))
|
if (!!override && neededFunctions.count(override))
|
||||||
m_context.addFunction(*override);
|
m_context.addFunction(*override);
|
||||||
}
|
}
|
||||||
// now add the rest
|
// now add the rest
|
||||||
for (FunctionDefinition const* fun: neededFunctions)
|
for (FunctionDefinition const* fun: neededFunctions)
|
||||||
if (fun->isConstructor() || overrideResolver(fun->getName()) != fun)
|
if (fun->isConstructor() || functionOverrideResolver(fun->getName()) != fun)
|
||||||
m_context.addFunction(*fun);
|
m_context.addFunction(*fun);
|
||||||
|
|
||||||
// Call constructors in base-to-derived order.
|
// Call constructors in base-to-derived order.
|
||||||
@ -176,9 +190,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
|
|||||||
}
|
}
|
||||||
|
|
||||||
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
|
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
|
||||||
function<FunctionDefinition const*(string const&)> const& _resolveOverrides)
|
function<FunctionDefinition const*(string const&)> const& _resolveFunctionOverrides,
|
||||||
|
function<ModifierDefinition const*(string const&)> const& _resolveModifierOverrides)
|
||||||
{
|
{
|
||||||
CallGraph callgraph(_resolveOverrides);
|
CallGraph callgraph(_resolveFunctionOverrides, _resolveModifierOverrides);
|
||||||
for (ASTNode const* node: _nodes)
|
for (ASTNode const* node: _nodes)
|
||||||
callgraph.addNode(*node);
|
callgraph.addNode(*node);
|
||||||
return callgraph.getCalls();
|
return callgraph.getCalls();
|
||||||
@ -471,25 +486,22 @@ void Compiler::appendModifierOrFunctionCode()
|
|||||||
{
|
{
|
||||||
ASTPointer<ModifierInvocation> const& modifierInvocation = m_currentFunction->getModifiers()[m_modifierDepth];
|
ASTPointer<ModifierInvocation> const& modifierInvocation = m_currentFunction->getModifiers()[m_modifierDepth];
|
||||||
|
|
||||||
// TODO get the most derived override of the modifier
|
ModifierDefinition const& modifier = m_context.getFunctionModifier(modifierInvocation->getName()->getName());
|
||||||
ModifierDefinition const* modifier = dynamic_cast<ModifierDefinition const*>(
|
solAssert(modifier.getParameters().size() == modifierInvocation->getArguments().size(), "");
|
||||||
modifierInvocation->getName()->getReferencedDeclaration());
|
for (unsigned i = 0; i < modifier.getParameters().size(); ++i)
|
||||||
solAssert(!!modifier, "Modifier not found.");
|
|
||||||
solAssert(modifier->getParameters().size() == modifierInvocation->getArguments().size(), "");
|
|
||||||
for (unsigned i = 0; i < modifier->getParameters().size(); ++i)
|
|
||||||
{
|
{
|
||||||
m_context.addVariable(*modifier->getParameters()[i]);
|
m_context.addVariable(*modifier.getParameters()[i]);
|
||||||
compileExpression(*modifierInvocation->getArguments()[i],
|
compileExpression(*modifierInvocation->getArguments()[i],
|
||||||
modifier->getParameters()[i]->getType());
|
modifier.getParameters()[i]->getType());
|
||||||
}
|
}
|
||||||
for (VariableDeclaration const* localVariable: modifier->getLocalVariables())
|
for (VariableDeclaration const* localVariable: modifier.getLocalVariables())
|
||||||
m_context.addAndInitializeVariable(*localVariable);
|
m_context.addAndInitializeVariable(*localVariable);
|
||||||
|
|
||||||
unsigned const c_stackSurplus = CompilerUtils::getSizeOnStack(modifier->getParameters()) +
|
unsigned const c_stackSurplus = CompilerUtils::getSizeOnStack(modifier.getParameters()) +
|
||||||
CompilerUtils::getSizeOnStack(modifier->getLocalVariables());
|
CompilerUtils::getSizeOnStack(modifier.getLocalVariables());
|
||||||
m_stackCleanupForReturn += c_stackSurplus;
|
m_stackCleanupForReturn += c_stackSurplus;
|
||||||
|
|
||||||
modifier->getBody().accept(*this);
|
modifier.getBody().accept(*this);
|
||||||
|
|
||||||
for (unsigned i = 0; i < c_stackSurplus; ++i)
|
for (unsigned i = 0; i < c_stackSurplus; ++i)
|
||||||
m_context << eth::Instruction::POP;
|
m_context << eth::Instruction::POP;
|
||||||
|
@ -53,7 +53,8 @@ private:
|
|||||||
/// Recursively searches the call graph and returns all functions referenced inside _nodes.
|
/// Recursively searches the call graph and returns all functions referenced inside _nodes.
|
||||||
/// _resolveOverride is called to resolve virtual function overrides.
|
/// _resolveOverride is called to resolve virtual function overrides.
|
||||||
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
|
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
|
||||||
std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride);
|
std::function<FunctionDefinition const*(std::string const&)> const& _resolveFunctionOverride,
|
||||||
|
std::function<ModifierDefinition const*(std::string const&)> const& _resolveModifierOverride);
|
||||||
void appendFunctionSelector(ContractDefinition const& _contract);
|
void appendFunctionSelector(ContractDefinition const& _contract);
|
||||||
/// Creates code that unpacks the arguments for the given function, from memory if
|
/// Creates code that unpacks the arguments for the given function, from memory if
|
||||||
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
|
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
|
||||||
|
@ -66,6 +66,11 @@ void CompilerContext::addFunction(FunctionDefinition const& _function)
|
|||||||
m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag));
|
m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CompilerContext::addModifier(ModifierDefinition const& _modifier)
|
||||||
|
{
|
||||||
|
m_functionModifiers.insert(make_pair(_modifier.getName(), &_modifier));
|
||||||
|
}
|
||||||
|
|
||||||
bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const
|
bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const
|
||||||
{
|
{
|
||||||
auto ret = m_compiledContracts.find(&_contract);
|
auto ret = m_compiledContracts.find(&_contract);
|
||||||
@ -92,6 +97,13 @@ eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefiniti
|
|||||||
return res->second.tag();
|
return res->second.tag();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ModifierDefinition const& CompilerContext::getFunctionModifier(string const& _name) const
|
||||||
|
{
|
||||||
|
auto res = m_functionModifiers.find(_name);
|
||||||
|
solAssert(res != m_functionModifiers.end(), "Function modifier override not found.");
|
||||||
|
return *res->second;
|
||||||
|
}
|
||||||
|
|
||||||
unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const
|
unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const
|
||||||
{
|
{
|
||||||
auto res = m_localVariables.find(&_declaration);
|
auto res = m_localVariables.find(&_declaration);
|
||||||
|
@ -45,6 +45,8 @@ public:
|
|||||||
void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0);
|
void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0);
|
||||||
void addAndInitializeVariable(VariableDeclaration const& _declaration);
|
void addAndInitializeVariable(VariableDeclaration const& _declaration);
|
||||||
void addFunction(FunctionDefinition const& _function);
|
void addFunction(FunctionDefinition const& _function);
|
||||||
|
/// Adds the given modifier to the list by name if the name is not present already.
|
||||||
|
void addModifier(ModifierDefinition const& _modifier);
|
||||||
|
|
||||||
void setCompiledContracts(std::map<ContractDefinition const*, bytes const*> const& _contracts) { m_compiledContracts = _contracts; }
|
void setCompiledContracts(std::map<ContractDefinition const*, bytes const*> const& _contracts) { m_compiledContracts = _contracts; }
|
||||||
bytes const& getCompiledContract(ContractDefinition const& _contract) const;
|
bytes const& getCompiledContract(ContractDefinition const& _contract) const;
|
||||||
@ -59,6 +61,7 @@ public:
|
|||||||
eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const;
|
eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const;
|
||||||
/// @returns the entry label of the given function and takes overrides into account.
|
/// @returns the entry label of the given function and takes overrides into account.
|
||||||
eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const;
|
eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const;
|
||||||
|
ModifierDefinition const& getFunctionModifier(std::string const& _name) const;
|
||||||
/// Returns the distance of the given local variable from the bottom of the stack (of the current function).
|
/// Returns the distance of the given local variable from the bottom of the stack (of the current function).
|
||||||
unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const;
|
unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const;
|
||||||
/// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns
|
/// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns
|
||||||
@ -118,6 +121,8 @@ private:
|
|||||||
std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels;
|
std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels;
|
||||||
/// Labels pointing to the entry points of function overrides.
|
/// Labels pointing to the entry points of function overrides.
|
||||||
std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels;
|
std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels;
|
||||||
|
/// Mapping to obtain function modifiers by name. Should be filled from derived to base.
|
||||||
|
std::map<std::string, ModifierDefinition const*> m_functionModifiers;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user