diff --git a/libsolidity/analysis/FunctionCallGraph.cpp b/libsolidity/analysis/FunctionCallGraph.cpp index 9fe685b5f..017c26cdb 100644 --- a/libsolidity/analysis/FunctionCallGraph.cpp +++ b/libsolidity/analysis/FunctionCallGraph.cpp @@ -22,6 +22,55 @@ using namespace std; using namespace solidity::frontend; +FunctionCallGraphBuilder::FunctionCallGraphBuilder(ContractDefinition const& _contract): + m_contract(&_contract), + m_graph(std::make_unique(_contract)) +{ + // Create graph for constructor, state vars, etc + m_currentNode = SpecialNode::EntryCreation; + m_currentDispatch = SpecialNode::InternalCreationDispatch; + + for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts | boost::adaptors::reversed) + { + for (auto const* stateVar: contract->stateVariables()) + stateVar->accept(*this); + + for (auto arg: contract->baseContracts()) + arg->accept(*this); + + if (contract->constructor()) + { + add(*m_currentNode, contract->constructor()); + contract->constructor()->accept(*this); + m_currentNode = contract->constructor(); + } + } + + m_currentNode.reset(); + m_currentDispatch = SpecialNode::InternalDispatch; + + // Create graph for all publicly reachable functions + for (auto& [hash, functionType]: _contract.interfaceFunctionList()) + { + (void)hash; + if (auto const* funcDef = dynamic_cast(&functionType->declaration())) + if (!m_graph->edges.count(funcDef)) + visitCallable(funcDef); + + // Add all external functions to the RuntimeDispatch + add(SpecialNode::Entry, &functionType->declaration()); + } + + // Add all InternalCreationDispatch calls to the RuntimeDispatch as well + add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch); + + if (_contract.fallbackFunction()) + add(SpecialNode::Entry, _contract.fallbackFunction()); + + if (_contract.receiveFunction()) + add(SpecialNode::Entry, _contract.receiveFunction()); +} + bool FunctionCallGraphBuilder::CompareByID::operator()(Node const& _lhs, Node const& _rhs) const { if (_lhs.index() != _rhs.index()) @@ -48,57 +97,7 @@ bool FunctionCallGraphBuilder::CompareByID::operator()(int64_t _lhs, Node const& unique_ptr FunctionCallGraphBuilder::create(ContractDefinition const& _contract) { - FunctionCallGraphBuilder builder; - - builder.m_contract = &_contract; - - builder.m_graph = make_unique(_contract); - - // Create graph for constructor, state vars, etc - builder.m_currentNode = SpecialNode::EntryCreation; - builder.m_currentDispatch = SpecialNode::InternalCreationDispatch; - - for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts | boost::adaptors::reversed) - { - for (auto const* stateVar: contract->stateVariables()) - stateVar->accept(builder); - - for (auto arg: contract->baseContracts()) - arg->accept(builder); - - if (contract->constructor()) - { - builder.add(*builder.m_currentNode, contract->constructor()); - contract->constructor()->accept(builder); - builder.m_currentNode = contract->constructor(); - } - } - - builder.m_currentNode.reset(); - builder.m_currentDispatch = SpecialNode::InternalDispatch; - - // Create graph for all publicly reachable functions - for (auto& [hash, functionType]: _contract.interfaceFunctionList()) - { - (void)hash; - if (auto const* funcDef = dynamic_cast(&functionType->declaration())) - if (!builder.m_graph->edges.count(funcDef)) - builder.visitCallable(funcDef); - - // Add all external functions to the RuntimeDispatch - builder.add(SpecialNode::Entry, &functionType->declaration()); - } - - // Add all InternalCreationDispatch calls to the RuntimeDispatch as well - builder.add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch); - - if (_contract.fallbackFunction()) - builder.add(SpecialNode::Entry, _contract.fallbackFunction()); - - if (_contract.receiveFunction()) - builder.add(SpecialNode::Entry, _contract.receiveFunction()); - - return std::move(builder.m_graph); + return FunctionCallGraphBuilder(_contract).m_graph; } bool FunctionCallGraphBuilder::visit(Identifier const& _identifier) @@ -159,14 +158,16 @@ void FunctionCallGraphBuilder::endVisit(MemberAccess const& _memberAccess) void FunctionCallGraphBuilder::endVisit(ModifierInvocation const& _modifierInvocation) { + VirtualLookup const& requiredLookup = *_modifierInvocation.name().annotation().requiredLookup; + if (auto const* modifier = dynamic_cast(_modifierInvocation.name().annotation().referencedDeclaration)) { - if (*_modifierInvocation.name().annotation().requiredLookup == VirtualLookup::Virtual) + if (requiredLookup == VirtualLookup::Virtual) modifier = &modifier->resolveVirtual(*m_contract); else - solAssert(*_modifierInvocation.name().annotation().requiredLookup == VirtualLookup::Static, ""); + solAssert(requiredLookup == VirtualLookup::Static, ""); - processFunction(*modifier); + processFunction(*modifier, requiredLookup == VirtualLookup::Static); } } @@ -187,21 +188,6 @@ void FunctionCallGraphBuilder::visitCallable(CallableDeclaration const* _callabl m_currentNode = previousNode; } -void FunctionCallGraphBuilder::visitConstructor(ContractDefinition const& _contract) -{ - for (auto const* stateVar: _contract.stateVariables()) - stateVar->accept(*this); - - for (auto arg: _contract.baseContracts()) - arg->accept(*this); - - if (_contract.constructor()) - { - add(*m_currentNode, _contract.constructor()); - _contract.constructor()->accept(*this); - } -} - bool FunctionCallGraphBuilder::add(Node _caller, Node _callee) { return m_graph->edges[_caller].insert(_callee).second; @@ -215,5 +201,6 @@ void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _calla // Create edge to creation dispatch if (!_calledDirectly) add(m_currentDispatch, &_callable); + visitCallable(&_callable, _calledDirectly); } diff --git a/libsolidity/analysis/FunctionCallGraph.h b/libsolidity/analysis/FunctionCallGraph.h index ee00e7edb..8a092c7e5 100644 --- a/libsolidity/analysis/FunctionCallGraph.h +++ b/libsolidity/analysis/FunctionCallGraph.h @@ -78,19 +78,20 @@ public: static std::unique_ptr create(ContractDefinition const& _contract); private: + FunctionCallGraphBuilder(ContractDefinition const& _contract); + bool visit(Identifier const& _identifier) override; bool visit(NewExpression const& _newExpression) override; void endVisit(MemberAccess const& _memberAccess) override; void endVisit(ModifierInvocation const& _modifierInvocation) override; void visitCallable(CallableDeclaration const* _callable, bool _directCall = true); - void visitConstructor(ContractDefinition const& _contract); bool add(Node _caller, Node _callee); void processFunction(CallableDeclaration const& _callable, bool _calledDirectly = true); ContractDefinition const* m_contract = nullptr; - std::optional m_currentNode; + std::optional m_currentNode = SpecialNode::EntryCreation; std::unique_ptr m_graph = nullptr; Node m_currentDispatch = SpecialNode::InternalCreationDispatch; }; diff --git a/libsolidity/interface/CompilerStack.cpp b/libsolidity/interface/CompilerStack.cpp index 9648dcd1d..31db52399 100644 --- a/libsolidity/interface/CompilerStack.cpp +++ b/libsolidity/interface/CompilerStack.cpp @@ -406,7 +406,7 @@ bool CompilerStack::analyze() if (source->ast) for (ASTPointer const& node: source->ast->nodes()) if (ContractDefinition* contract = dynamic_cast(node.get())) - m_contractCallGraphs.emplace(contract, FunctionCallGraphBuilder::create(*contract)); + m_contractCallGraphs.emplace(std::piecewise_construct, std::forward_as_tuple(contract), std::forward_as_tuple(FunctionCallGraphBuilder::create(*contract))); } if (noErrors) diff --git a/libsolidity/interface/CompilerStack.h b/libsolidity/interface/CompilerStack.h index ba0c794b0..53f3b95dd 100644 --- a/libsolidity/interface/CompilerStack.h +++ b/libsolidity/interface/CompilerStack.h @@ -476,7 +476,7 @@ private: bool m_generateIR = false; bool m_generateEwasm = false; std::map m_libraries; - std::map const> m_contractCallGraphs; + std::map const> m_contractCallGraphs; /// list of path prefix remappings, e.g. mylibrary: github.com/ethereum = /usr/local/ethereum /// "context:prefix=target" std::vector m_remappings;