diff --git a/libsolidity/analysis/FunctionCallGraph.cpp b/libsolidity/analysis/FunctionCallGraph.cpp index 117d18efa..9fe685b5f 100644 --- a/libsolidity/analysis/FunctionCallGraph.cpp +++ b/libsolidity/analysis/FunctionCallGraph.cpp @@ -17,6 +17,7 @@ // SPDX-License-Identifier: GPL-3.0 #include +#include using namespace std; using namespace solidity::frontend; @@ -45,46 +46,59 @@ bool FunctionCallGraphBuilder::CompareByID::operator()(int64_t _lhs, Node const& return _lhs < std::get(_rhs)->id(); } -shared_ptr FunctionCallGraphBuilder::create(ContractDefinition const& _contract) +unique_ptr FunctionCallGraphBuilder::create(ContractDefinition const& _contract) { - m_contract = &_contract; + FunctionCallGraphBuilder builder; - m_graph = make_shared(_contract); + builder.m_contract = &_contract; + + builder.m_graph = make_unique(_contract); // Create graph for constructor, state vars, etc - m_currentNode = SpecialNode::EntryCreation; - m_currentDispatch = SpecialNode::InternalCreationDispatch; - for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts) - visitConstructor(*contract); + builder.m_currentNode = SpecialNode::EntryCreation; + builder.m_currentDispatch = SpecialNode::InternalCreationDispatch; - m_currentNode.reset(); - m_currentDispatch = SpecialNode::InternalDispatch; + for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts | boost::adaptors::reversed) + { + for (auto const* stateVar: contract->stateVariables()) + stateVar->accept(builder); + + for (auto arg: contract->baseContracts()) + arg->accept(builder); + + if (contract->constructor()) + { + builder.add(*builder.m_currentNode, contract->constructor()); + contract->constructor()->accept(builder); + builder.m_currentNode = contract->constructor(); + } + } + + builder.m_currentNode.reset(); + builder.m_currentDispatch = SpecialNode::InternalDispatch; // Create graph for all publicly reachable functions for (auto& [hash, functionType]: _contract.interfaceFunctionList()) { (void)hash; if (auto const* funcDef = dynamic_cast(&functionType->declaration())) - if (!m_graph->edges.count(funcDef)) - visitCallable(funcDef); + if (!builder.m_graph->edges.count(funcDef)) + builder.visitCallable(funcDef); // Add all external functions to the RuntimeDispatch - add(SpecialNode::Entry, &functionType->declaration()); + builder.add(SpecialNode::Entry, &functionType->declaration()); } // Add all InternalCreationDispatch calls to the RuntimeDispatch as well - add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch); + builder.add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch); if (_contract.fallbackFunction()) - add(SpecialNode::Entry, _contract.fallbackFunction()); + builder.add(SpecialNode::Entry, _contract.fallbackFunction()); if (_contract.receiveFunction()) - add(SpecialNode::Entry, _contract.receiveFunction()); + builder.add(SpecialNode::Entry, _contract.receiveFunction()); - m_contract = nullptr; - solAssert(!m_currentNode.has_value(), "Current node not properly reset."); - - return m_graph; + return std::move(builder.m_graph); } bool FunctionCallGraphBuilder::visit(Identifier const& _identifier) diff --git a/libsolidity/analysis/FunctionCallGraph.h b/libsolidity/analysis/FunctionCallGraph.h index 6133fc19e..ee00e7edb 100644 --- a/libsolidity/analysis/FunctionCallGraph.h +++ b/libsolidity/analysis/FunctionCallGraph.h @@ -75,7 +75,7 @@ public: std::set createdContracts; }; - std::shared_ptr create(ContractDefinition const& _contract); + static std::unique_ptr create(ContractDefinition const& _contract); private: bool visit(Identifier const& _identifier) override; @@ -91,7 +91,7 @@ private: ContractDefinition const* m_contract = nullptr; std::optional m_currentNode; - std::shared_ptr m_graph = nullptr; + std::unique_ptr m_graph = nullptr; Node m_currentDispatch = SpecialNode::InternalCreationDispatch; }; diff --git a/libsolidity/interface/CompilerStack.cpp b/libsolidity/interface/CompilerStack.cpp index bfcc0fefd..9648dcd1d 100644 --- a/libsolidity/interface/CompilerStack.cpp +++ b/libsolidity/interface/CompilerStack.cpp @@ -402,12 +402,11 @@ bool CompilerStack::analyze() if (noErrors) { - FunctionCallGraphBuilder builder; for (Source const* source: m_sourceOrder) if (source->ast) for (ASTPointer const& node: source->ast->nodes()) if (ContractDefinition* contract = dynamic_cast(node.get())) - m_contractCallGraphs.emplace(contract, builder.create(*contract)); + m_contractCallGraphs.emplace(contract, FunctionCallGraphBuilder::create(*contract)); } if (noErrors)