fixup! Implement

This commit is contained in:
Mathias Baumann 2020-12-15 15:04:22 +01:00 committed by Kamil Śliwak
parent 1452076ee2
commit d6ab70c27d
6 changed files with 69 additions and 108 deletions

View File

@ -52,8 +52,8 @@ shared_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder
m_graph = make_shared<ContractCallGraph>(_contract);
// Create graph for constructor, state vars, etc
m_currentNode = SpecialNode::CreationRoot;
m_currentDispatch = SpecialNode::CreationDispatch;
m_currentNode = SpecialNode::EntryCreation;
m_currentDispatch = SpecialNode::InternalCreationDispatch;
for (ContractDefinition const* contract: _contract.annotation().linearizedBaseContracts)
visitConstructor(*contract);
@ -67,24 +67,19 @@ shared_ptr<FunctionCallGraphBuilder::ContractCallGraph> FunctionCallGraphBuilder
if (auto const* funcDef = dynamic_cast<FunctionDefinition const*>(&functionType->declaration()))
if (!m_graph->edges.count(funcDef))
visitCallable(funcDef);
// Add all external functions to the RuntimeDispatch
add(SpecialNode::Entry, &functionType->declaration());
}
// Add all CreationDispatch calls to the RuntimeDispatch as well
for (auto node: m_graph->edges[SpecialNode::CreationDispatch])
add(SpecialNode::InternalDispatch, node);
// Add all external functions to the RuntimeDispatch
for (auto& [hash, functionType]: _contract.interfaceFunctionList())
{
(void)hash;
add(SpecialNode::ExternalDispatch, &functionType->declaration());
}
// Add all InternalCreationDispatch calls to the RuntimeDispatch as well
add(SpecialNode::InternalDispatch, SpecialNode::InternalCreationDispatch);
if (_contract.fallbackFunction())
add(SpecialNode::ExternalDispatch, _contract.fallbackFunction());
add(SpecialNode::Entry, _contract.fallbackFunction());
if (_contract.receiveFunction())
add(SpecialNode::ExternalDispatch, _contract.receiveFunction());
add(SpecialNode::Entry, _contract.receiveFunction());
m_contract = nullptr;
solAssert(!m_currentNode.has_value(), "Current node not properly reset.");
@ -129,37 +124,31 @@ void FunctionCallGraphBuilder::endVisit(MemberAccess const& _memberAccess)
// Super functions
if (*_memberAccess.annotation().requiredLookup == VirtualLookup::Super)
{
if (ContractType const* type = dynamic_cast<ContractType const*>(_memberAccess.expression().annotation().type))
{
solAssert(type->isSuper(), "");
functionDef = &functionDef->resolveVirtual(*m_contract, type->contractDefinition().superContract(*m_contract));
}
}
else
solAssert(*_memberAccess.annotation().requiredLookup == VirtualLookup::Static, "");
processFunction(*functionDef, _memberAccess.annotation());
return;
}
void FunctionCallGraphBuilder::endVisit(FunctionCall const& _functionCall)
{
auto* functionType = dynamic_cast<FunctionType const*>(_functionCall.expression().annotation().type);
if (
functionType &&
functionType->kind() == FunctionType::Kind::Internal &&
!functionType->hasDeclaration()
)
add(m_currentDispatch, &_functionCall);
}
void FunctionCallGraphBuilder::visitCallable(CallableDeclaration const* _callable)
void FunctionCallGraphBuilder::visitCallable(CallableDeclaration const* _callable, bool _directCall)
{
solAssert(!m_graph->edges.count(_callable), "");
auto previousNode = m_currentNode;
std::optional<Node> previousNode = m_currentNode;
m_currentNode = _callable;
if (previousNode.has_value())
if (previousNode.has_value() && _directCall)
add(*previousNode, _callable);
if (!_directCall)
add(*m_currentNode, m_currentDispatch);
_callable->accept(*this);
@ -181,18 +170,9 @@ void FunctionCallGraphBuilder::visitConstructor(ContractDefinition const& _contr
}
}
bool FunctionCallGraphBuilder::add(Node _caller, ASTNode const* _callee)
bool FunctionCallGraphBuilder::add(Node _caller, Node _callee)
{
solAssert(_callee != nullptr, "");
auto result = m_graph->edges.find(_caller);
if (result == m_graph->edges.end())
{
m_graph->edges.emplace(_caller, std::set<ASTNode const*, ASTNode::CompareByID>{_callee});
return true;
}
return result->second.emplace(_callee).second;
return m_graph->edges[_caller].insert(_callee).second;
}
void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation)
@ -203,5 +183,5 @@ void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _calla
// Create edge to creation dispatch
if (!_annotation.calledDirectly)
add(m_currentDispatch, &_callable);
visitCallable(&_callable);
visitCallable(&_callable, _annotation.calledDirectly);
}

View File

@ -32,10 +32,10 @@ namespace solidity::frontend
* functions and modifiers
*
* Includes the following special nodes:
* - CreationRoot: All calls made at contract creation originate from this node
* - CreationDispatch: Represents the internal dispatch function at creation time
* - ExternalDispatch: Represents the runtime dispatch for all external functions
* - EntryCreation: All calls made at contract creation originate from this node. Constructors are modelled to be all called by this node instead of calling each other due to implicit constructors that don't exist at this stage.
* - InternalCreationDispatch: Represents the internal dispatch function at creation time
* - InternalDispatch: Represents the runtime dispatch for internal function pointers and complex expressions
* - Entry: Represents the runtime dispatch for all external functions
*
* Nodes are a variant of either the enum SpecialNode or an ASTNode pointer.
* ASTNodes are usually inherited from CallableDeclarations
@ -50,7 +50,7 @@ namespace solidity::frontend
class FunctionCallGraphBuilder: private ASTConstVisitor
{
public:
enum class SpecialNode { CreationRoot, CreationDispatch, InternalDispatch, ExternalDispatch };
enum class SpecialNode { EntryCreation, InternalCreationDispatch, InternalDispatch, Entry };
using Node = std::variant<ASTNode const*, SpecialNode>;
@ -69,7 +69,7 @@ public:
/// Contract for which this is the graph
ContractDefinition const& contract;
std::map<Node, std::set<ASTNode const*, ASTNode::CompareByID>, CompareByID> edges;
std::map<Node, std::set<Node, CompareByID>, CompareByID> edges;
/// Set of contracts created
std::set<ContractDefinition const*, ASTNode::CompareByID> createdContracts;
@ -81,18 +81,17 @@ private:
bool visit(Identifier const& _identifier) override;
bool visit(NewExpression const& _newExpression) override;
void endVisit(MemberAccess const& _memberAccess) override;
void endVisit(FunctionCall const& _functionCall) override;
void visitCallable(CallableDeclaration const* _callable);
void visitCallable(CallableDeclaration const* _callable, bool _directCall = true);
void visitConstructor(ContractDefinition const& _contract);
bool add(Node _caller, ASTNode const* _callee);
bool add(Node _caller, Node _callee);
void processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation);
ContractDefinition const* m_contract = nullptr;
std::optional<Node> m_currentNode;
std::shared_ptr<ContractCallGraph> m_graph = nullptr;
Node m_currentDispatch = SpecialNode::CreationDispatch;
Node m_currentDispatch = SpecialNode::InternalCreationDispatch;
};
}

View File

@ -43,18 +43,6 @@ string MultiUseYulFunctionCollector::requestedFunctions()
return result;
}
set<string> MultiUseYulFunctionCollector::requestedFunctionsNames()
{
set<string> names;
for (auto const& [name, code]: m_requestedFunctions)
{
(void) code;
names.emplace(name);
}
return names;
}
string MultiUseYulFunctionCollector::createFunction(string const& _name, function<string ()> const& _creator)
{
if (!m_requestedFunctions.count(_name))

View File

@ -49,9 +49,6 @@ public:
/// empty return value.
std::string requestedFunctions();
/// Helper function to get the names of all requested functions
std::set<std::string> requestedFunctionsNames();
/// @returns true IFF a function with the specified name has already been collected.
bool contains(std::string const& _name) const { return m_requestedFunctions.count(_name) > 0; }
private:

View File

@ -36,12 +36,14 @@
#include <libsolutil/CommonData.h>
#include <libsolutil/Whiskers.h>
#include <libsolutil/StringUtils.h>
#include <libsolutil/Algorithms.h>
#include <liblangutil/SourceReferenceFormatter.h>
#include <boost/range/adaptor/map.hpp>
#include <sstream>
#include <variant>
using namespace std;
using namespace solidity;
@ -51,41 +53,37 @@ using namespace solidity::frontend;
namespace
{
void verifyCallGraph(set<ASTNode const*, ASTNode::CompareByID> const& _nodes, set<string>& _functionList)
void verifyCallGraph(set<ASTNode const*, ASTNode::CompareByID> const& _nodes, set<ASTNode const*>& _functionList)
{
for (auto const& node: _nodes)
if (auto const* functionDef = dynamic_cast<FunctionDefinition const*>(node))
solAssert(_functionList.erase(IRNames::function(*functionDef)) == 1, "Function not found in generated code");
solAssert(functionDef->isConstructor() || _functionList.erase(functionDef) == 1, "Function not found in generated code");
static string const funPrefix = "fun_";
for (string const& name: _functionList)
solAssert(name.substr(0, funPrefix.size()) != funPrefix, "Functions found in code gen that were not in the call graph");
for (ASTNode const* node: _functionList)
if (auto functionDefinition = dynamic_cast<FunctionDefinition const*>(node))
solAssert(functionDefinition->isConstructor(), "Functions found in code gen that were not in the call graph");
}
void collectCalls(FunctionCallGraphBuilder::ContractCallGraph const& _graph, ASTNode const* _root, set<ASTNode const*, ASTNode::CompareByID>& _functions)
void collectCalls(FunctionCallGraphBuilder::ContractCallGraph const& _graph, FunctionCallGraphBuilder::Node _root, set<ASTNode const*, ASTNode::CompareByID>& _functions)
{
if (_functions.count(_root) > 0)
return;
using Node = FunctionCallGraphBuilder::Node;
set<ASTNode const*, ASTNode::CompareByID> toVisit{_root};
set<Node const*> functions = BreadthFirstSearch<Node const*>{{&_root}}.run([&](Node const* _node, auto&& _addChild) {
auto callees = _graph.edges.find(*_node);
_functions.emplace(_root);
while (!toVisit.empty())
{
ASTNode const* function = *toVisit.begin();
toVisit.erase(toVisit.begin());
auto callees = _graph.edges.find(function);
if (callees == _graph.edges.end())
continue;
return;
for (auto& callee: callees->second)
if (_functions.emplace(callee).second)
toVisit.emplace(callee);
}
for (Node const& _child: callees->second)
_addChild(&_child);
}).visited;
for (Node const* node: functions)
if (auto* astNode = get_if<ASTNode const*>(node))
_functions.emplace(*astNode);
}
}
@ -122,22 +120,13 @@ void IRGenerator::verifyCallGraph(FunctionCallGraphBuilder::ContractCallGraph co
{
set<ASTNode const*, ASTNode::CompareByID> functions;
auto collectFromNode = [&](FunctionCallGraphBuilder::SpecialNode _node)
{
auto callees = _graph.edges.find(_node);
if (callees != _graph.edges.end())
for (auto callee: callees->second)
collectCalls(_graph, callee, functions);
};
collectFromNode(FunctionCallGraphBuilder::SpecialNode::CreationRoot);
collectFromNode(FunctionCallGraphBuilder::SpecialNode::CreationDispatch);
collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::EntryCreation, functions);
collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::InternalCreationDispatch, functions);
::verifyCallGraph(functions, m_creationFunctionList);
functions.clear();
collectFromNode(FunctionCallGraphBuilder::SpecialNode::ExternalDispatch);
collectFromNode(FunctionCallGraphBuilder::SpecialNode::InternalDispatch);
collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::Entry, functions);
collectCalls(_graph, FunctionCallGraphBuilder::SpecialNode::InternalDispatch, functions);
::verifyCallGraph(functions, m_deployedFunctionList);
}
@ -207,10 +196,9 @@ string IRGenerator::generate(
t("deploy", deployCode(_contract));
generateImplicitConstructors(_contract);
generateQueuedFunctions();
m_creationFunctionList = generateQueuedFunctions();
InternalDispatchMap internalDispatchMap = generateInternalDispatchFunctions();
m_creationFunctionList = m_context.functionCollector().requestedFunctionsNames();
t("functions", m_context.functionCollector().requestedFunctions());
t("subObjects", subObjectSources(m_context.subObjectsCreated()));
@ -229,9 +217,8 @@ string IRGenerator::generate(
t("DeployedObject", IRNames::deployedObject(_contract));
t("library_address", IRNames::libraryAddressImmutable());
t("dispatch", dispatchRoutine(_contract));
generateQueuedFunctions();
m_deployedFunctionList = generateQueuedFunctions();
generateInternalDispatchFunctions();
m_deployedFunctionList = m_context.functionCollector().requestedFunctionsNames();
t("deployedFunctions", m_context.functionCollector().requestedFunctions());
t("deployedSubObjects", subObjectSources(m_context.subObjectsCreated()));
@ -249,11 +236,20 @@ string IRGenerator::generate(Block const& _block)
return generator.code();
}
void IRGenerator::generateQueuedFunctions()
set<ASTNode const*> IRGenerator::generateQueuedFunctions()
{
set<ASTNode const*> functions;
while (!m_context.functionGenerationQueueEmpty())
{
FunctionDefinition const& functionDefinition = *m_context.dequeueFunctionForCodeGeneration();
functions.emplace(&functionDefinition);
// NOTE: generateFunction() may modify function generation queue
generateFunction(*m_context.dequeueFunctionForCodeGeneration());
generateFunction(functionDefinition);
}
return functions;
}
InternalDispatchMap IRGenerator::generateInternalDispatchFunctions()

View File

@ -68,7 +68,8 @@ private:
/// Generates code for all the functions from the function generation queue.
/// The resulting code is stored in the function collector in IRGenerationContext.
void generateQueuedFunctions();
/// @returns A set of ast nodes of the generated functions.
std::set<ASTNode const*> generateQueuedFunctions();
/// Generates all the internal dispatch functions necessary to handle any function that could
/// possibly be called via a pointer.
/// @return The content of the dispatch for reuse in runtime code. Reuse is necessary because
@ -118,8 +119,8 @@ private:
langutil::EVMVersion const m_evmVersion;
OptimiserSettings const m_optimiserSettings;
std::set<std::string> m_creationFunctionList;
std::set<std::string> m_deployedFunctionList;
std::set<ASTNode const*> m_creationFunctionList;
std::set<ASTNode const*> m_deployedFunctionList;
IRGenerationContext m_context;
YulUtilFunctions m_utils;