fixup! Implement

This commit is contained in:
Mathias Baumann 2020-12-15 15:04:22 +01:00 committed by Kamil Śliwak
parent 1452076ee2
commit d6ab70c27d
6 changed files with 69 additions and 108 deletions

View File

@ -52,8 +52,8 @@ shared_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder
m_graph = make_shared<ContractCallGraph>(_contract); m_graph = make_shared<ContractCallGraph>(_contract);
// Create graph for constructor, state vars, etc // Create graph for constructor, state vars, etc
m_currentNode = SpecialNode::CreationRoot; m_currentNode = SpecialNode::EntryCreation;
m_currentDispatch = SpecialNode::CreationDispatch; m_currentDispatch = SpecialNode::InternalCreationDispatch;
for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts) for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts)
visitConstructor(*contract); visitConstructor(*contract);
@ -67,24 +67,19 @@ shared_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder
if (auto const* funcDef = dynamic_cast<FunctionDefinition const*>(&functionType->declaration())) if (auto const* funcDef = dynamic_cast<FunctionDefinition const*>(&functionType->declaration()))
if (!m_graph->edges.count(funcDef)) if (!m_graph->edges.count(funcDef))
visitCallable(funcDef); visitCallable(funcDef);
// Add all external functions to the RuntimeDispatch
add(SpecialNode::Entry, &functionType->declaration());
} }
// Add all CreationDispatch calls to the RuntimeDispatch as well // Add all InternalCreationDispatch calls to the RuntimeDispatch as well
for (auto node: m_graph->edges[SpecialNode::CreationDispatch]) add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch);
add(SpecialNode::InternalDispatch, node);
// Add all external functions to the RuntimeDispatch
for (auto& [hash, functionType]: _contract.interfaceFunctionList())
{
(void)hash;
add(SpecialNode::ExternalDispatch, &functionType->declaration());
}
if (_contract.fallbackFunction()) if (_contract.fallbackFunction())
add(SpecialNode::ExternalDispatch, _contract.fallbackFunction()); add(SpecialNode::Entry, _contract.fallbackFunction());
if (_contract.receiveFunction()) if (_contract.receiveFunction())
add(SpecialNode::ExternalDispatch, _contract.receiveFunction()); add(SpecialNode::Entry, _contract.receiveFunction());
m_contract = nullptr; m_contract = nullptr;
solAssert(!m_currentNode.has_value(), "Current node not properly reset."); solAssert(!m_currentNode.has_value(), "Current node not properly reset.");
@ -129,37 +124,31 @@ void FunctionCallGraphBuilder::endVisit(MemberAccess const& _memberAccess)
// Super functions // Super functions
if (*_memberAccess.annotation().requiredLookup == VirtualLookup::Super) if (*_memberAccess.annotation().requiredLookup == VirtualLookup::Super)
{
if (ContractType const* type = dynamic_cast<ContractType const*>(_memberAccess.expression().annotation().type)) if (ContractType const* type = dynamic_cast<ContractType const*>(_memberAccess.expression().annotation().type))
{ {
solAssert(type->isSuper(), ""); solAssert(type->isSuper(), "");
functionDef = &functionDef->resolveVirtual(*m_contract, type->contractDefinition().superContract(*m_contract)); functionDef = &functionDef->resolveVirtual(*m_contract, type->contractDefinition().superContract(*m_contract));
} }
}
else
solAssert(*_memberAccess.annotation().requiredLookup == VirtualLookup::Static, "");
processFunction(*functionDef, _memberAccess.annotation()); processFunction(*functionDef, _memberAccess.annotation());
return; return;
} }
void FunctionCallGraphBuilder::endVisit(FunctionCall const& _functionCall) void FunctionCallGraphBuilder::visitCallable(CallableDeclaration const* _callable, bool _directCall)
{
auto* functionType = dynamic_cast<FunctionType const*>(_functionCall.expression().annotation().type);
if (
functionType &&
functionType->kind() == FunctionType::Kind::Internal &&
!functionType->hasDeclaration()
)
add(m_currentDispatch, &_functionCall);
}
void FunctionCallGraphBuilder::visitCallable(CallableDeclaration const* _callable)
{ {
solAssert(!m_graph->edges.count(_callable), ""); solAssert(!m_graph->edges.count(_callable), "");
auto previousNode = m_currentNode; std::optional<Node> previousNode = m_currentNode;
m_currentNode = _callable; m_currentNode = _callable;
if (previousNode.has_value()) if (previousNode.has_value() && _directCall)
add(*previousNode, _callable); add(*previousNode, _callable);
if (!_directCall)
add(*m_currentNode, m_currentDispatch);
_callable->accept(*this); _callable->accept(*this);
@ -181,18 +170,9 @@ void FunctionCallGraphBuilder::visitConstructor(ContractDefinition const& _contr
} }
} }
bool FunctionCallGraphBuilder::add(Node _caller, ASTNode const* _callee) bool FunctionCallGraphBuilder::add(Node _caller, Node _callee)
{ {
solAssert(_callee != nullptr, ""); return m_graph->edges[_caller].insert(_callee).second;
auto result = m_graph->edges.find(_caller);
if (result == m_graph->edges.end())
{
m_graph->edges.emplace(_caller, std::set<ASTNode const*, ASTNode::CompareByID>{_callee});
return true;
}
return result->second.emplace(_callee).second;
} }
void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation) void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation)
@ -203,5 +183,5 @@ void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _calla
// Create edge to creation dispatch // Create edge to creation dispatch
if (!_annotation.calledDirectly) if (!_annotation.calledDirectly)
add(m_currentDispatch, &_callable); add(m_currentDispatch, &_callable);
visitCallable(&_callable); visitCallable(&_callable, _annotation.calledDirectly);
} }

View File

@ -32,10 +32,10 @@ namespace solidity::frontend
* functions and modifiers * functions and modifiers
* *
* Includes the following special nodes: * Includes the following special nodes:
* - CreationRoot: All calls made at contract creation originate from this node * - EntryCreation: All calls made at contract creation originate from this node. Constructors are modelled to be all called by this node instead of calling each other due to implicit constructors that don't exist at this stage.
* - CreationDispatch: Represents the internal dispatch function at creation time * - InternalCreationDispatch: Represents the internal dispatch function at creation time
* - ExternalDispatch: Represents the runtime dispatch for all external functions
* - InternalDispatch: Represents the runtime dispatch for internal function pointers and complex expressions * - InternalDispatch: Represents the runtime dispatch for internal function pointers and complex expressions
* - Entry: Represents the runtime dispatch for all external functions
* *
* Nodes are a variant of either the enum SpecialNode or an ASTNode pointer. * Nodes are a variant of either the enum SpecialNode or an ASTNode pointer.
* ASTNodes are usually inherited from CallableDeclarations * ASTNodes are usually inherited from CallableDeclarations
@ -50,7 +50,7 @@ namespace solidity::frontend
class FunctionCallGraphBuilder: private ASTConstVisitor class FunctionCallGraphBuilder: private ASTConstVisitor
{ {
public: public:
enum class SpecialNode { CreationRoot, CreationDispatch, InternalDispatch, ExternalDispatch }; enum class SpecialNode { EntryCreation, InternalCreationDispatch, InternalDispatch, Entry };
using Node = std::variant<ASTNode const*, SpecialNode>; using Node = std::variant<ASTNode const*, SpecialNode>;
@ -69,7 +69,7 @@ public:
/// Contract for which this is the graph /// Contract for which this is the graph
ContractDefinition const& contract; ContractDefinition const& contract;
std::map<Node, std::set<ASTNode const*, ASTNode::CompareByID>, CompareByID> edges; std::map<Node, std::set<Node, CompareByID>, CompareByID> edges;
/// Set of contracts created /// Set of contracts created
std::set<ContractDefinition const*, ASTNode::CompareByID> createdContracts; std::set<ContractDefinition const*, ASTNode::CompareByID> createdContracts;
@ -81,18 +81,17 @@ private:
bool visit(Identifier const& _identifier) override; bool visit(Identifier const& _identifier) override;
bool visit(NewExpression const& _newExpression) override; bool visit(NewExpression const& _newExpression) override;
void endVisit(MemberAccess const& _memberAccess) override; void endVisit(MemberAccess const& _memberAccess) override;
void endVisit(FunctionCall const& _functionCall) override;
void visitCallable(CallableDeclaration const* _callable); void visitCallable(CallableDeclaration const* _callable, bool _directCall = true);
void visitConstructor(ContractDefinition const& _contract); void visitConstructor(ContractDefinition const& _contract);
bool add(Node _caller, ASTNode const* _callee); bool add(Node _caller, Node _callee);
void processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation); void processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation);
ContractDefinition const* m_contract = nullptr; ContractDefinition const* m_contract = nullptr;
std::optional<Node> m_currentNode; std::optional<Node> m_currentNode;
std::shared_ptr<ContractCallGraph> m_graph = nullptr; std::shared_ptr<ContractCallGraph> m_graph = nullptr;
Node m_currentDispatch = SpecialNode::CreationDispatch; Node m_currentDispatch = SpecialNode::InternalCreationDispatch;
}; };
} }

View File

@ -43,18 +43,6 @@ string MultiUseYulFunctionCollector::requestedFunctions()
return result; return result;
} }
set<string> MultiUseYulFunctionCollector::requestedFunctionsNames()
{
set<string> names;
for (auto const& [name, code]: m_requestedFunctions)
{
(void) code;
names.emplace(name);
}
return names;
}
string MultiUseYulFunctionCollector::createFunction(string const& _name, function<string ()> const& _creator) string MultiUseYulFunctionCollector::createFunction(string const& _name, function<string ()> const& _creator)
{ {
if (!m_requestedFunctions.count(_name)) if (!m_requestedFunctions.count(_name))

View File

@ -49,9 +49,6 @@ public:
/// empty return value. /// empty return value.
std::string requestedFunctions(); std::string requestedFunctions();
/// Helper function to get the names of all requested functions
std::set<std::string> requestedFunctionsNames();
/// @returns true IFF a function with the specified name has already been collected. /// @returns true IFF a function with the specified name has already been collected.
bool contains(std::string const& _name) const { return m_requestedFunctions.count(_name) > 0; } bool contains(std::string const& _name) const { return m_requestedFunctions.count(_name) > 0; }
private: private:

View File

@ -36,12 +36,14 @@
#include <libsolutil/CommonData.h> #include <libsolutil/CommonData.h>
#include <libsolutil/Whiskers.h> #include <libsolutil/Whiskers.h>
#include <libsolutil/StringUtils.h> #include <libsolutil/StringUtils.h>
#include <libsolutil/Algorithms.h>
#include <liblangutil/SourceReferenceFormatter.h> #include <liblangutil/SourceReferenceFormatter.h>
#include <boost/range/adaptor/map.hpp> #include <boost/range/adaptor/map.hpp>
#include <sstream> #include <sstream>
#include <variant>
using namespace std; using namespace std;
using namespace solidity; using namespace solidity;
@ -51,41 +53,37 @@ using namespace solidity::frontend;
namespace namespace
{ {
void verifyCallGraph(set<ASTNode const*, ASTNode::CompareByID> const& _nodes, set<string>& _functionList) void verifyCallGraph(set<ASTNode const*, ASTNode::CompareByID> const& _nodes, set<ASTNode const*>& _functionList)
{ {
for (auto const& node: _nodes) for (auto const& node: _nodes)
if (auto const* functionDef = dynamic_cast<FunctionDefinition const*>(node)) if (auto const* functionDef = dynamic_cast<FunctionDefinition const*>(node))
solAssert(_functionList.erase(IRNames::function(*functionDef)) == 1, "Function not found in generated code"); solAssert(functionDef->isConstructor() || _functionList.erase(functionDef) == 1, "Function not found in generated code");
static string const funPrefix = "fun_"; static string const funPrefix = "fun_";
for (string const& name: _functionList) for (ASTNode const* node: _functionList)
solAssert(name.substr(0, funPrefix.size()) != funPrefix, "Functions found in code gen that were not in the call graph"); if (auto functionDefinition = dynamic_cast<FunctionDefinition const*>(node))
solAssert(functionDefinition->isConstructor(), "Functions found in code gen that were not in the call graph");
} }
void collectCalls(FunctionCallGraphBuilder::ContractCallGraph const& _graph, ASTNode const* _root, set<ASTNode const*, ASTNode::CompareByID>& _functions) void collectCalls(FunctionCallGraphBuilder::ContractCallGraph const& _graph, FunctionCallGraphBuilder::Node _root, set<ASTNode const*, ASTNode::CompareByID>& _functions)
{ {
if (_functions.count(_root) > 0) using Node = FunctionCallGraphBuilder::Node;
return;
set<ASTNode const*, ASTNode::CompareByID> toVisit{_root}; set<Node const*> functions = BreadthFirstSearch<Node const*>{{&_root}}.run([&](Node const* _node, auto&& _addChild) {
auto callees = _graph.edges.find(*_node);
_functions.emplace(_root);
while (!toVisit.empty())
{
ASTNode const* function = *toVisit.begin();
toVisit.erase(toVisit.begin());
auto callees = _graph.edges.find(function);
if (callees == _graph.edges.end()) if (callees == _graph.edges.end())
continue; return;
for (auto& callee: callees->second) for (Node const& _child: callees->second)
if (_functions.emplace(callee).second) _addChild(&_child);
toVisit.emplace(callee); }).visited;
}
for (Node const* node: functions)
if (auto* astNode = get_if<ASTNode const*>(node))
_functions.emplace(*astNode);
} }
} }
@ -122,22 +120,13 @@ void IRGenerator::verifyCallGraph(FunctionCallGraphBuilder::ContractCallGraph co
{ {
set<ASTNode const*, ASTNode::CompareByID> functions; set<ASTNode const*, ASTNode::CompareByID> functions;
auto collectFromNode = [&](FunctionCallGraphBuilder::SpecialNode _node) collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::EntryCreation, functions);
{ collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::InternalCreationDispatch, functions);
auto callees = _graph.edges.find(_node);
if (callees != _graph.edges.end())
for (auto callee: callees->second)
collectCalls(_graph, callee, functions);
};
collectFromNode(FunctionCallGraphBuilder::SpecialNode::CreationRoot);
collectFromNode(FunctionCallGraphBuilder::SpecialNode::CreationDispatch);
::verifyCallGraph(functions, m_creationFunctionList); ::verifyCallGraph(functions, m_creationFunctionList);
functions.clear(); functions.clear();
collectFromNode(FunctionCallGraphBuilder::SpecialNode::ExternalDispatch); collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::Entry, functions);
collectFromNode(FunctionCallGraphBuilder::SpecialNode::InternalDispatch); collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::InternalDispatch, functions);
::verifyCallGraph(functions, m_deployedFunctionList); ::verifyCallGraph(functions, m_deployedFunctionList);
} }
@ -207,10 +196,9 @@ string IRGenerator::generate(
t("deploy", deployCode(_contract)); t("deploy", deployCode(_contract));
generateImplicitConstructors(_contract); generateImplicitConstructors(_contract);
generateQueuedFunctions(); m_creationFunctionList = generateQueuedFunctions();
InternalDispatchMap internalDispatchMap = generateInternalDispatchFunctions(); InternalDispatchMap internalDispatchMap = generateInternalDispatchFunctions();
m_creationFunctionList = m_context.functionCollector().requestedFunctionsNames();
t("functions", m_context.functionCollector().requestedFunctions()); t("functions", m_context.functionCollector().requestedFunctions());
t("subObjects", subObjectSources(m_context.subObjectsCreated())); t("subObjects", subObjectSources(m_context.subObjectsCreated()));
@ -229,9 +217,8 @@ string IRGenerator::generate(
t("DeployedObject", IRNames::deployedObject(_contract)); t("DeployedObject", IRNames::deployedObject(_contract));
t("library_address", IRNames::libraryAddressImmutable()); t("library_address", IRNames::libraryAddressImmutable());
t("dispatch", dispatchRoutine(_contract)); t("dispatch", dispatchRoutine(_contract));
generateQueuedFunctions(); m_deployedFunctionList = generateQueuedFunctions();
generateInternalDispatchFunctions(); generateInternalDispatchFunctions();
m_deployedFunctionList = m_context.functionCollector().requestedFunctionsNames();
t("deployedFunctions", m_context.functionCollector().requestedFunctions()); t("deployedFunctions", m_context.functionCollector().requestedFunctions());
t("deployedSubObjects", subObjectSources(m_context.subObjectsCreated())); t("deployedSubObjects", subObjectSources(m_context.subObjectsCreated()));
@ -249,11 +236,20 @@ string IRGenerator::generate(Block const& _block)
return generator.code(); return generator.code();
} }
void IRGenerator::generateQueuedFunctions() set<ASTNode const*> IRGenerator::generateQueuedFunctions()
{ {
set<ASTNode const*> functions;
while (!m_context.functionGenerationQueueEmpty()) while (!m_context.functionGenerationQueueEmpty())
{
FunctionDefinition const& functionDefinition = *m_context.dequeueFunctionForCodeGeneration();
functions.emplace(&functionDefinition);
// NOTE: generateFunction() may modify function generation queue // NOTE: generateFunction() may modify function generation queue
generateFunction(*m_context.dequeueFunctionForCodeGeneration()); generateFunction(functionDefinition);
}
return functions;
} }
InternalDispatchMap IRGenerator::generateInternalDispatchFunctions() InternalDispatchMap IRGenerator::generateInternalDispatchFunctions()

View File

@ -68,7 +68,8 @@ private:
/// Generates code for all the functions from the function generation queue. /// Generates code for all the functions from the function generation queue.
/// The resulting code is stored in the function collector in IRGenerationContext. /// The resulting code is stored in the function collector in IRGenerationContext.
void generateQueuedFunctions(); /// @returns A set of ast nodes of the generated functions.
std::set<ASTNode const*> generateQueuedFunctions();
/// Generates all the internal dispatch functions necessary to handle any function that could /// Generates all the internal dispatch functions necessary to handle any function that could
/// possibly be called via a pointer. /// possibly be called via a pointer.
/// @return The content of the dispatch for reuse in runtime code. Reuse is necessary because /// @return The content of the dispatch for reuse in runtime code. Reuse is necessary because
@ -118,8 +119,8 @@ private:
langutil::EVMVersion const m_evmVersion; langutil::EVMVersion const m_evmVersion;
OptimiserSettings const m_optimiserSettings; OptimiserSettings const m_optimiserSettings;
std::set<std::string> m_creationFunctionList; std::set<ASTNode const*> m_creationFunctionList;
std::set<std::string> m_deployedFunctionList; std::set<ASTNode const*> m_deployedFunctionList;
IRGenerationContext m_context; IRGenerationContext m_context;
YulUtilFunctions m_utils; YulUtilFunctions m_utils;