fixup! Imp

This commit is contained in:
Mathias Baumann 2021-01-05 16:47:12 +01:00 committed by Kamil Śliwak
parent 1bfc766139
commit 3838a3a29f
3 changed files with 36 additions and 23 deletions

View File

@ -17,6 +17,7 @@
// SPDX-License-Identifier: GPL-3.0 // SPDX-License-Identifier: GPL-3.0
#include <libsolidity/analysis/FunctionCallGraph.h> #include <libsolidity/analysis/FunctionCallGraph.h>
#include <boost/range/adaptor/reversed.hpp>
using namespace std; using namespace std;
using namespace solidity::frontend; using namespace solidity::frontend;
@ -45,46 +46,59 @@ bool FunctionCallGraphBuilder::CompareByID::operator()(int64_t _lhs, Node const&
return _lhs < std::get<ASTNode const*>(_rhs)->id(); return _lhs < std::get<ASTNode const*>(_rhs)->id();
} }
shared_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder::create(ContractDefinition const& _contract) unique_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder::create(ContractDefinition const& _contract)
{ {
m_contract = &_contract; FunctionCallGraphBuilder builder;
m_graph = make_shared<ContractCallGraph>(_contract); builder.m_contract = &_contract;
builder.m_graph = make_unique<ContractCallGraph>(_contract);
// Create graph for constructor, state vars, etc // Create graph for constructor, state vars, etc
m_currentNode = SpecialNode::EntryCreation; builder.m_currentNode = SpecialNode::EntryCreation;
m_currentDispatch = SpecialNode::InternalCreationDispatch; builder.m_currentDispatch = SpecialNode::InternalCreationDispatch;
for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts)
visitConstructor(*contract);
m_currentNode.reset(); for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts | boost::adaptors::reversed)
m_currentDispatch = SpecialNode::InternalDispatch; {
for (auto const* stateVar: contract->stateVariables())
stateVar->accept(builder);
for (auto arg: contract->baseContracts())
arg->accept(builder);
if (contract->constructor())
{
builder.add(*builder.m_currentNode, contract->constructor());
contract->constructor()->accept(builder);
builder.m_currentNode = contract->constructor();
}
}
builder.m_currentNode.reset();
builder.m_currentDispatch = SpecialNode::InternalDispatch;
// Create graph for all publicly reachable functions // Create graph for all publicly reachable functions
for (auto& [hash, functionType]: _contract.interfaceFunctionList()) for (auto& [hash, functionType]: _contract.interfaceFunctionList())
{ {
(void)hash; (void)hash;
if (auto const* funcDef = dynamic_cast<FunctionDefinition const*>(&functionType->declaration())) if (auto const* funcDef = dynamic_cast<FunctionDefinition const*>(&functionType->declaration()))
if (!m_graph->edges.count(funcDef)) if (!builder.m_graph->edges.count(funcDef))
visitCallable(funcDef); builder.visitCallable(funcDef);
// Add all external functions to the RuntimeDispatch // Add all external functions to the RuntimeDispatch
add(SpecialNode::Entry, &functionType->declaration()); builder.add(SpecialNode::Entry, &functionType->declaration());
} }
// Add all InternalCreationDispatch calls to the RuntimeDispatch as well // Add all InternalCreationDispatch calls to the RuntimeDispatch as well
add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch); builder.add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch);
if (_contract.fallbackFunction()) if (_contract.fallbackFunction())
add(SpecialNode::Entry, _contract.fallbackFunction()); builder.add(SpecialNode::Entry, _contract.fallbackFunction());
if (_contract.receiveFunction()) if (_contract.receiveFunction())
add(SpecialNode::Entry, _contract.receiveFunction()); builder.add(SpecialNode::Entry, _contract.receiveFunction());
m_contract = nullptr; return std::move(builder.m_graph);
solAssert(!m_currentNode.has_value(), "Current node not properly reset.");
return m_graph;
} }
bool FunctionCallGraphBuilder::visit(Identifier const& _identifier) bool FunctionCallGraphBuilder::visit(Identifier const& _identifier)

View File

@ -75,7 +75,7 @@ public:
std::set<ContractDefinition const*, ASTNode::CompareByID> createdContracts; std::set<ContractDefinition const*, ASTNode::CompareByID> createdContracts;
}; };
std::shared_ptr<ContractCallGraph> create(ContractDefinition const& _contract); static std::unique_ptr<ContractCallGraph> create(ContractDefinition const& _contract);
private: private:
bool visit(Identifier const& _identifier) override; bool visit(Identifier const& _identifier) override;
@ -91,7 +91,7 @@ private:
ContractDefinition const* m_contract = nullptr; ContractDefinition const* m_contract = nullptr;
std::optional<Node> m_currentNode; std::optional<Node> m_currentNode;
std::shared_ptr<ContractCallGraph> m_graph = nullptr; std::unique_ptr<ContractCallGraph> m_graph = nullptr;
Node m_currentDispatch = SpecialNode::InternalCreationDispatch; Node m_currentDispatch = SpecialNode::InternalCreationDispatch;
}; };

View File

@ -402,12 +402,11 @@ bool CompilerStack::analyze()
if (noErrors) if (noErrors)
{ {
FunctionCallGraphBuilder builder;
for (Source const* source: m_sourceOrder) for (Source const* source: m_sourceOrder)
if (source->ast) if (source->ast)
for (ASTPointer<ASTNode> const& node: source->ast->nodes()) for (ASTPointer<ASTNode> const& node: source->ast->nodes())
if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get())) if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get()))
m_contractCallGraphs.emplace(contract, builder.create(*contract)); m_contractCallGraphs.emplace(contract, FunctionCallGraphBuilder::create(*contract));
} }
if (noErrors) if (noErrors)