fixup! Impleme

This commit is contained in:
Mathias Baumann 2021-01-06 11:49:54 +01:00 committed by Kamil Śliwak
parent 3838a3a29f
commit 23ca9c4324
4 changed files with 61 additions and 73 deletions

View File

@ -22,6 +22,55 @@
using namespace std; using namespace std;
using namespace solidity::frontend; using namespace solidity::frontend;
FunctionCallGraphBuilder::FunctionCallGraphBuilder(ContractDefinition const& _contract):
m_contract(&_contract),
m_graph(std::make_unique<ContractCallGraph>(_contract))
{
// Create graph for constructor, state vars, etc
m_currentNode = SpecialNode::EntryCreation;
m_currentDispatch = SpecialNode::InternalCreationDispatch;
for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts | boost::adaptors::reversed)
{
for (auto const* stateVar: contract->stateVariables())
stateVar->accept(*this);
for (auto arg: contract->baseContracts())
arg->accept(*this);
if (contract->constructor())
{
add(*m_currentNode, contract->constructor());
contract->constructor()->accept(*this);
m_currentNode = contract->constructor();
}
}
m_currentNode.reset();
m_currentDispatch = SpecialNode::InternalDispatch;
// Create graph for all publicly reachable functions
for (auto& [hash, functionType]: _contract.interfaceFunctionList())
{
(void)hash;
if (auto const* funcDef = dynamic_cast<FunctionDefinition const*>(&functionType->declaration()))
if (!m_graph->edges.count(funcDef))
visitCallable(funcDef);
// Add all external functions to the RuntimeDispatch
add(SpecialNode::Entry, &functionType->declaration());
}
// Add all InternalCreationDispatch calls to the RuntimeDispatch as well
add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch);
if (_contract.fallbackFunction())
add(SpecialNode::Entry, _contract.fallbackFunction());
if (_contract.receiveFunction())
add(SpecialNode::Entry, _contract.receiveFunction());
}
bool FunctionCallGraphBuilder::CompareByID::operator()(Node const& _lhs, Node const& _rhs) const bool FunctionCallGraphBuilder::CompareByID::operator()(Node const& _lhs, Node const& _rhs) const
{ {
if (_lhs.index() != _rhs.index()) if (_lhs.index() != _rhs.index())
@ -48,57 +97,7 @@ bool FunctionCallGraphBuilder::CompareByID::operator()(int64_t _lhs, Node const&
unique_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder::create(ContractDefinition const& _contract) unique_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder::create(ContractDefinition const& _contract)
{ {
FunctionCallGraphBuilder builder; return FunctionCallGraphBuilder(_contract).m_graph;
builder.m_contract = &_contract;
builder.m_graph = make_unique<ContractCallGraph>(_contract);
// Create graph for constructor, state vars, etc
builder.m_currentNode = SpecialNode::EntryCreation;
builder.m_currentDispatch = SpecialNode::InternalCreationDispatch;
for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts | boost::adaptors::reversed)
{
for (auto const* stateVar: contract->stateVariables())
stateVar->accept(builder);
for (auto arg: contract->baseContracts())
arg->accept(builder);
if (contract->constructor())
{
builder.add(*builder.m_currentNode, contract->constructor());
contract->constructor()->accept(builder);
builder.m_currentNode = contract->constructor();
}
}
builder.m_currentNode.reset();
builder.m_currentDispatch = SpecialNode::InternalDispatch;
// Create graph for all publicly reachable functions
for (auto& [hash, functionType]: _contract.interfaceFunctionList())
{
(void)hash;
if (auto const* funcDef = dynamic_cast<FunctionDefinition const*>(&functionType->declaration()))
if (!builder.m_graph->edges.count(funcDef))
builder.visitCallable(funcDef);
// Add all external functions to the RuntimeDispatch
builder.add(SpecialNode::Entry, &functionType->declaration());
}
// Add all InternalCreationDispatch calls to the RuntimeDispatch as well
builder.add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch);
if (_contract.fallbackFunction())
builder.add(SpecialNode::Entry, _contract.fallbackFunction());
if (_contract.receiveFunction())
builder.add(SpecialNode::Entry, _contract.receiveFunction());
return std::move(builder.m_graph);
} }
bool FunctionCallGraphBuilder::visit(Identifier const& _identifier) bool FunctionCallGraphBuilder::visit(Identifier const& _identifier)
@ -159,14 +158,16 @@ void FunctionCallGraphBuilder::endVisit(MemberAccess const& _memberAccess)
void FunctionCallGraphBuilder::endVisit(ModifierInvocation const& _modifierInvocation) void FunctionCallGraphBuilder::endVisit(ModifierInvocation const& _modifierInvocation)
{ {
VirtualLookup const& requiredLookup = *_modifierInvocation.name().annotation().requiredLookup;
if (auto const* modifier = dynamic_cast<ModifierDefinition const*>(_modifierInvocation.name().annotation().referencedDeclaration)) if (auto const* modifier = dynamic_cast<ModifierDefinition const*>(_modifierInvocation.name().annotation().referencedDeclaration))
{ {
if (*_modifierInvocation.name().annotation().requiredLookup == VirtualLookup::Virtual) if (requiredLookup == VirtualLookup::Virtual)
modifier = &modifier->resolveVirtual(*m_contract); modifier = &modifier->resolveVirtual(*m_contract);
else else
solAssert(*_modifierInvocation.name().annotation().requiredLookup == VirtualLookup::Static, ""); solAssert(requiredLookup == VirtualLookup::Static, "");
processFunction(*modifier); processFunction(*modifier, requiredLookup == VirtualLookup::Static);
} }
} }
@ -187,21 +188,6 @@ void FunctionCallGraphBuilder::visitCallable(CallableDeclaration const* _callabl
m_currentNode = previousNode; m_currentNode = previousNode;
} }
void FunctionCallGraphBuilder::visitConstructor(ContractDefinition const& _contract)
{
for (auto const* stateVar: _contract.stateVariables())
stateVar->accept(*this);
for (auto arg: _contract.baseContracts())
arg->accept(*this);
if (_contract.constructor())
{
add(*m_currentNode, _contract.constructor());
_contract.constructor()->accept(*this);
}
}
bool FunctionCallGraphBuilder::add(Node _caller, Node _callee) bool FunctionCallGraphBuilder::add(Node _caller, Node _callee)
{ {
return m_graph->edges[_caller].insert(_callee).second; return m_graph->edges[_caller].insert(_callee).second;
@ -215,5 +201,6 @@ void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _calla
// Create edge to creation dispatch // Create edge to creation dispatch
if (!_calledDirectly) if (!_calledDirectly)
add(m_currentDispatch, &_callable); add(m_currentDispatch, &_callable);
visitCallable(&_callable, _calledDirectly); visitCallable(&_callable, _calledDirectly);
} }

View File

@ -78,19 +78,20 @@ public:
static std::unique_ptr<ContractCallGraph> create(ContractDefinition const& _contract); static std::unique_ptr<ContractCallGraph> create(ContractDefinition const& _contract);
private: private:
FunctionCallGraphBuilder(ContractDefinition const& _contract);
bool visit(Identifier const& _identifier) override; bool visit(Identifier const& _identifier) override;
bool visit(NewExpression const& _newExpression) override; bool visit(NewExpression const& _newExpression) override;
void endVisit(MemberAccess const& _memberAccess) override; void endVisit(MemberAccess const& _memberAccess) override;
void endVisit(ModifierInvocation const& _modifierInvocation) override; void endVisit(ModifierInvocation const& _modifierInvocation) override;
void visitCallable(CallableDeclaration const* _callable, bool _directCall = true); void visitCallable(CallableDeclaration const* _callable, bool _directCall = true);
void visitConstructor(ContractDefinition const& _contract);
bool add(Node _caller, Node _callee); bool add(Node _caller, Node _callee);
void processFunction(CallableDeclaration const& _callable, bool _calledDirectly = true); void processFunction(CallableDeclaration const& _callable, bool _calledDirectly = true);
ContractDefinition const* m_contract = nullptr; ContractDefinition const* m_contract = nullptr;
std::optional<Node> m_currentNode; std::optional<Node> m_currentNode = SpecialNode::EntryCreation;
std::unique_ptr<ContractCallGraph> m_graph = nullptr; std::unique_ptr<ContractCallGraph> m_graph = nullptr;
Node m_currentDispatch = SpecialNode::InternalCreationDispatch; Node m_currentDispatch = SpecialNode::InternalCreationDispatch;
}; };

View File

@ -406,7 +406,7 @@ bool CompilerStack::analyze()
if (source->ast) if (source->ast)
for (ASTPointer<ASTNode> const& node: source->ast->nodes()) for (ASTPointer<ASTNode> const& node: source->ast->nodes())
if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get())) if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get()))
m_contractCallGraphs.emplace(contract, FunctionCallGraphBuilder::create(*contract)); m_contractCallGraphs.emplace(std::piecewise_construct, std::forward_as_tuple(contract), std::forward_as_tuple(FunctionCallGraphBuilder::create(*contract)));
} }
if (noErrors) if (noErrors)

View File

@ -476,7 +476,7 @@ private:
bool m_generateIR = false; bool m_generateIR = false;
bool m_generateEwasm = false; bool m_generateEwasm = false;
std::map<std::string, util::h160> m_libraries; std::map<std::string, util::h160> m_libraries;
std::map<ContractDefinition const*, std::shared_ptr<FunctionCallGraphBuilder::ContractCallGraph> const> m_contractCallGraphs; std::map<ContractDefinition const*, std::unique_ptr<FunctionCallGraphBuilder::ContractCallGraph> const> m_contractCallGraphs;
/// list of path prefix remappings, e.g. mylibrary: github.com/ethereum = /usr/local/ethereum /// list of path prefix remappings, e.g. mylibrary: github.com/ethereum = /usr/local/ethereum
/// "context:prefix=target" /// "context:prefix=target"
std::vector<Remapping> m_remappings; std::vector<Remapping> m_remappings;