Generate internal dispatch only for functions that might actually get called via pointers

- This also adds support for internal library calls as a side-effect since they'll now be pulled into the internal dispatch automatically.
This commit is contained in:
Kamil Śliwak 2020-05-19 21:50:22 +02:00
parent b7aa6cb1f7
commit 1a2e441bc5
11 changed files with 283 additions and 63 deletions

View File

@ -29,6 +29,8 @@
#include <libsolutil/Whiskers.h> #include <libsolutil/Whiskers.h>
#include <libsolutil/StringUtils.h> #include <libsolutil/StringUtils.h>
#include <boost/range/adaptor/map.hpp>
using namespace std; using namespace std;
using namespace solidity; using namespace solidity;
using namespace solidity::util; using namespace solidity::util;
@ -121,49 +123,55 @@ string IRGenerationContext::newYulVariable()
return "_" + to_string(++m_varCounter); return "_" + to_string(++m_varCounter);
} }
string IRGenerationContext::generateInternalDispatchFunction(YulArity const& _arity) void IRGenerationContext::initializeInternalDispatch(InternalDispatchMap _internalDispatch)
{ {
string funName = IRNames::internalDispatch(_arity); solAssert(internalDispatchClean(), "");
return m_functions.createFunction(funName, [&]() {
Whiskers templ(R"(
function <functionName>(fun<?+in>, <in></+in>) <?+out>-> <out></+out> {
switch fun
<#cases>
case <funID>
{
<?+out> <out> :=</+out> <name>(<in>)
}
</cases>
default { invalid() }
}
)");
templ("functionName", funName);
templ("in", suffixedVariableNameList("in_", 0, _arity.in));
templ("out", suffixedVariableNameList("out_", 0, _arity.out));
vector<map<string, string>> cases;
for (FunctionDefinition const* function: collectFunctionsOfArity(_arity))
{
solAssert(function, "");
solAssert(
YulArity::fromType(*TypeProvider::function(*function, FunctionType::Kind::Internal)) == _arity,
"A single dispatch function can only handle functions of one arity"
);
solAssert(!function->isConstructor(), "");
// 0 is reserved for uninitialized function pointers
solAssert(function->id() != 0, "Unexpected function ID: 0");
cases.emplace_back(map<string, string>{
{"funID", to_string(function->id())},
{"name", IRNames::function(*function)}
});
for (set<FunctionDefinition const*> const& functions: _internalDispatch | boost::adaptors::map_values)
for (auto function: functions)
enqueueFunctionForCodeGeneration(*function); enqueueFunctionForCodeGeneration(*function);
m_internalDispatchMap = move(_internalDispatch);
} }
templ("cases", move(cases)); InternalDispatchMap IRGenerationContext::consumeInternalDispatchMap()
return templ.render(); {
}); m_directInternalFunctionCalls.clear();
InternalDispatchMap internalDispatch = move(m_internalDispatchMap);
m_internalDispatchMap.clear();
return internalDispatch;
}
void IRGenerationContext::internalFunctionCalledDirectly(Expression const& _expression)
{
solAssert(m_directInternalFunctionCalls.count(&_expression) == 0, "");
m_directInternalFunctionCalls.insert(&_expression);
}
void IRGenerationContext::internalFunctionAccessed(Expression const& _expression, FunctionDefinition const& _function)
{
solAssert(
IRHelpers::referencedFunctionDeclaration(_expression) &&
_function.resolveVirtual(mostDerivedContract()) ==
IRHelpers::referencedFunctionDeclaration(_expression)->resolveVirtual(mostDerivedContract()),
"Function definition does not match the expression"
);
if (m_directInternalFunctionCalls.count(&_expression) == 0)
{
FunctionType const* functionType = TypeProvider::function(_function, FunctionType::Kind::Internal);
solAssert(functionType, "");
m_internalDispatchMap[YulArity::fromType(*functionType)].insert(&_function);
enqueueFunctionForCodeGeneration(_function);
}
}
void IRGenerationContext::internalFunctionCalledThroughDispatch(YulArity const& _arity)
{
m_internalDispatchMap.try_emplace(_arity);
} }
YulUtilFunctions IRGenerationContext::utils() YulUtilFunctions IRGenerationContext::utils()
@ -180,21 +188,3 @@ std::string IRGenerationContext::revertReasonIfDebug(std::string const& _message
{ {
return YulUtilFunctions::revertReasonIfDebug(m_revertStrings, _message); return YulUtilFunctions::revertReasonIfDebug(m_revertStrings, _message);
} }
set<FunctionDefinition const*> IRGenerationContext::collectFunctionsOfArity(YulArity const& _arity)
{
// UNIMPLEMENTED: Internal library calls via pointers are not implemented yet.
// We're not returning any internal library functions here even though it's possible
// to call them via pointers. Right now such calls end will up triggering the `default` case in
// the switch in the generated dispatch function.
set<FunctionDefinition const*> functions;
for (auto const& contract: mostDerivedContract().annotation().linearizedBaseContracts)
for (FunctionDefinition const* function: contract->definedFunctions())
if (
!function->isConstructor() &&
YulArity::fromType(*TypeProvider::function(*function, FunctionType::Kind::Internal)) == _arity
)
functions.insert(function);
return functions;
}

View File

@ -43,6 +43,8 @@ namespace solidity::frontend
class YulUtilFunctions; class YulUtilFunctions;
class ABIFunctions; class ABIFunctions;
using InternalDispatchMap = std::map<YulArity, std::set<FunctionDefinition const*>>;
/** /**
* Class that contains contextual information during IR generation. * Class that contains contextual information during IR generation.
*/ */
@ -102,7 +104,26 @@ public:
std::string newYulVariable(); std::string newYulVariable();
std::string generateInternalDispatchFunction(YulArity const& _arity); void initializeInternalDispatch(InternalDispatchMap _internalDispatchMap);
InternalDispatchMap consumeInternalDispatchMap();
bool internalDispatchClean() const { return m_internalDispatchMap.empty() && m_directInternalFunctionCalls.empty(); }
/// Notifies the context that a function call that needs to go through internal dispatch was
/// encountered while visiting the AST. This ensures that the corresponding dispatch function
/// gets added to the dispatch map even if there are no entries in it (which may happen if
/// the code contains a call to an uninitialized function variable).
void internalFunctionCalledThroughDispatch(YulArity const& _arity);
/// Notifies the context that a direct function call (i.e. not through internal dispatch) was
/// encountered while visiting the AST. This lets the context know that the function should
/// not be added to the dispatch (unless there are also indirect calls to it elsewhere else).
void internalFunctionCalledDirectly(Expression const& _expression);
/// Notifies the context that a name representing an internal function has been found while
/// visiting the AST. If the name has not been reported as a direct call using
/// @a internalFunctionCalledDirectly(), it's assumed to represent function variable access
/// and the function gets added to internal dispatch.
void internalFunctionAccessed(Expression const& _expression, FunctionDefinition const& _function);
/// @returns a new copy of the utility function generator (but using the same function set). /// @returns a new copy of the utility function generator (but using the same function set).
YulUtilFunctions utils(); YulUtilFunctions utils();
@ -120,8 +141,6 @@ public:
std::set<ContractDefinition const*, ASTNode::CompareByID>& subObjectsCreated() { return m_subObjects; } std::set<ContractDefinition const*, ASTNode::CompareByID>& subObjectsCreated() { return m_subObjects; }
private: private:
std::set<FunctionDefinition const*> collectFunctionsOfArity(YulArity const& _arity);
langutil::EVMVersion m_evmVersion; langutil::EVMVersion m_evmVersion;
RevertStrings m_revertStrings; RevertStrings m_revertStrings;
OptimiserSettings m_optimiserSettings; OptimiserSettings m_optimiserSettings;
@ -147,6 +166,13 @@ private:
/// all platforms - which is a property guaranteed by MultiUseYulFunctionCollector. /// all platforms - which is a property guaranteed by MultiUseYulFunctionCollector.
std::set<FunctionDefinition const*> m_functionGenerationQueue; std::set<FunctionDefinition const*> m_functionGenerationQueue;
/// Collection of functions that need to be callable via internal dispatch.
/// Note that having a key with an empty set of functions is a valid situation. It means that
/// the code contains a call via a pointer even though a specific function is never assigned to it.
/// It will fail at runtime but the code must still compile.
InternalDispatchMap m_internalDispatchMap;
std::set<Expression const*> m_directInternalFunctionCalls;
std::set<ContractDefinition const*, ASTNode::CompareByID> m_subObjects; std::set<ContractDefinition const*, ASTNode::CompareByID> m_subObjects;
}; };

View File

@ -38,6 +38,8 @@
#include <liblangutil/SourceReferenceFormatter.h> #include <liblangutil/SourceReferenceFormatter.h>
#include <boost/range/adaptor/map.hpp>
#include <sstream> #include <sstream>
using namespace std; using namespace std;
@ -137,14 +139,22 @@ string IRGenerator::generate(
t("deploy", deployCode(_contract)); t("deploy", deployCode(_contract));
generateImplicitConstructors(_contract); generateImplicitConstructors(_contract);
generateQueuedFunctions(); generateQueuedFunctions();
InternalDispatchMap internalDispatchMap = generateInternalDispatchFunctions();
t("functions", m_context.functionCollector().requestedFunctions()); t("functions", m_context.functionCollector().requestedFunctions());
t("subObjects", subObjectSources(m_context.subObjectsCreated())); t("subObjects", subObjectSources(m_context.subObjectsCreated()));
resetContext(_contract); resetContext(_contract);
// NOTE: Function pointers can be passed from creation code via storage variables. We need to
// get all the functions they could point to into the dispatch functions even if they're never
// referenced by name in the runtime code.
m_context.initializeInternalDispatch(move(internalDispatchMap));
// Do not register immutables to avoid assignment. // Do not register immutables to avoid assignment.
t("RuntimeObject", IRNames::runtimeObject(_contract)); t("RuntimeObject", IRNames::runtimeObject(_contract));
t("dispatch", dispatchRoutine(_contract)); t("dispatch", dispatchRoutine(_contract));
generateQueuedFunctions(); generateQueuedFunctions();
generateInternalDispatchFunctions();
t("runtimeFunctions", m_context.functionCollector().requestedFunctions()); t("runtimeFunctions", m_context.functionCollector().requestedFunctions());
t("runtimeSubObjects", subObjectSources(m_context.subObjectsCreated())); t("runtimeSubObjects", subObjectSources(m_context.subObjectsCreated()));
return t.render(); return t.render();
@ -164,6 +174,68 @@ void IRGenerator::generateQueuedFunctions()
generateFunction(*m_context.dequeueFunctionForCodeGeneration()); generateFunction(*m_context.dequeueFunctionForCodeGeneration());
} }
InternalDispatchMap IRGenerator::generateInternalDispatchFunctions()
{
solAssert(
m_context.functionGenerationQueueEmpty(),
"At this point all the enqueued functions should have been generated. "
"Otherwise the dispatch may be incomplete."
);
InternalDispatchMap internalDispatchMap = m_context.consumeInternalDispatchMap();
for (YulArity const& arity: internalDispatchMap | boost::adaptors::map_keys)
{
string funName = IRNames::internalDispatch(arity);
m_context.functionCollector().createFunction(funName, [&]() {
Whiskers templ(R"(
function <functionName>(fun<?+in>, <in></+in>) <?+out>-> <out></+out> {
switch fun
<#cases>
case <funID>
{
<?+out> <out> :=</+out> <name>(<in>)
}
</cases>
default { invalid() }
}
)");
templ("functionName", funName);
templ("in", suffixedVariableNameList("in_", 0, arity.in));
templ("out", suffixedVariableNameList("out_", 0, arity.out));
vector<map<string, string>> cases;
for (FunctionDefinition const* function: internalDispatchMap.at(arity))
{
solAssert(function, "");
solAssert(
YulArity::fromType(*TypeProvider::function(*function, FunctionType::Kind::Internal)) == arity,
"A single dispatch function can only handle functions of one arity"
);
solAssert(!function->isConstructor(), "");
// 0 is reserved for uninitialized function pointers
solAssert(function->id() != 0, "Unexpected function ID: 0");
solAssert(m_context.functionCollector().contains(IRNames::function(*function)), "");
cases.emplace_back(map<string, string>{
{"funID", to_string(function->id())},
{"name", IRNames::function(*function)}
});
}
templ("cases", move(cases));
return templ.render();
});
}
solAssert(m_context.internalDispatchClean(), "");
solAssert(
m_context.functionGenerationQueueEmpty(),
"Internal dispatch generation must not add new functions to generation queue because they won't be proeessed."
);
return internalDispatchMap;
}
string IRGenerator::generateFunction(FunctionDefinition const& _function) string IRGenerator::generateFunction(FunctionDefinition const& _function)
{ {
string functionName = IRNames::function(_function); string functionName = IRNames::function(_function);
@ -556,6 +628,10 @@ void IRGenerator::resetContext(ContractDefinition const& _contract)
m_context.functionCollector().requestedFunctions().empty(), m_context.functionCollector().requestedFunctions().empty(),
"Reset context while it still had functions." "Reset context while it still had functions."
); );
solAssert(
m_context.internalDispatchClean(),
"Reset internal dispatch map without consuming it."
);
m_context = IRGenerationContext(m_evmVersion, m_context.revertStrings(), m_optimiserSettings); m_context = IRGenerationContext(m_evmVersion, m_context.revertStrings(), m_optimiserSettings);
m_context.setMostDerivedContract(_contract); m_context.setMostDerivedContract(_contract);

View File

@ -65,6 +65,11 @@ private:
/// Generates code for all the functions from the function generation queue. /// Generates code for all the functions from the function generation queue.
/// The resulting code is stored in the function collector in IRGenerationContext. /// The resulting code is stored in the function collector in IRGenerationContext.
void generateQueuedFunctions(); void generateQueuedFunctions();
/// Generates all the internal dispatch functions necessary to handle any function that could
/// possibly be called via a pointer.
/// @return The content of the dispatch for reuse in runtime code. Reuse is necessary because
/// pointers to functions can be passed from the creation code in storage variables.
InternalDispatchMap generateInternalDispatchFunctions();
/// Generates code for and returns the name of the function. /// Generates code for and returns the name of the function.
std::string generateFunction(FunctionDefinition const& _function); std::string generateFunction(FunctionDefinition const& _function);
/// Generates a getter for the given declaration and returns its name /// Generates a getter for the given declaration and returns its name

View File

@ -583,6 +583,20 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
return false; return false;
} }
bool IRGeneratorForStatements::visit(FunctionCall const& _functionCall)
{
FunctionTypePointer functionType = dynamic_cast<FunctionType const*>(&type(_functionCall.expression()));
if (
functionType &&
functionType->kind() == FunctionType::Kind::Internal &&
!functionType->bound() &&
IRHelpers::referencedFunctionDeclaration(_functionCall.expression())
)
m_context.internalFunctionCalledDirectly(_functionCall.expression());
return true;
}
void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall) void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
{ {
solUnimplementedAssert( solUnimplementedAssert(
@ -688,9 +702,10 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
else else
{ {
YulArity arity = YulArity::fromType(*functionType); YulArity arity = YulArity::fromType(*functionType);
m_context.internalFunctionCalledThroughDispatch(arity);
define(_functionCall) << define(_functionCall) <<
// NOTE: generateInternalDispatchFunction() takes care of adding the function to function generation queue IRNames::internalDispatch(arity) <<
m_context.generateInternalDispatchFunction(arity) <<
"(" << "(" <<
IRVariable(_functionCall.expression()).part("functionIdentifier").name() << IRVariable(_functionCall.expression()).part("functionIdentifier").name() <<
joinHumanReadablePrefixed(args) << joinHumanReadablePrefixed(args) <<
@ -1492,7 +1507,10 @@ void IRGeneratorForStatements::endVisit(MemberAccess const& _memberAccess)
break; break;
case FunctionType::Kind::Internal: case FunctionType::Kind::Internal:
if (auto const* function = dynamic_cast<FunctionDefinition const*>(_memberAccess.annotation().referencedDeclaration)) if (auto const* function = dynamic_cast<FunctionDefinition const*>(_memberAccess.annotation().referencedDeclaration))
{
define(_memberAccess) << to_string(function->id()) << "\n"; define(_memberAccess) << to_string(function->id()) << "\n";
m_context.internalFunctionAccessed(_memberAccess, *function);
}
else else
solAssert(false, "Function not found in member access"); solAssert(false, "Function not found in member access");
break; break;
@ -1756,7 +1774,14 @@ void IRGeneratorForStatements::endVisit(Identifier const& _identifier)
return; return;
} }
else if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration)) else if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration))
define(_identifier) << to_string(functionDef->resolveVirtual(m_context.mostDerivedContract()).id()) << "\n"; {
FunctionDefinition const& resolvedFunctionDef = functionDef->resolveVirtual(m_context.mostDerivedContract());
define(_identifier) << to_string(resolvedFunctionDef.id()) << "\n";
solAssert(resolvedFunctionDef.functionType(true), "");
solAssert(resolvedFunctionDef.functionType(true)->kind() == FunctionType::Kind::Internal, "");
m_context.internalFunctionAccessed(_identifier, resolvedFunctionDef);
}
else if (VariableDeclaration const* varDecl = dynamic_cast<VariableDeclaration const*>(declaration)) else if (VariableDeclaration const* varDecl = dynamic_cast<VariableDeclaration const*>(declaration))
handleVariableReference(*varDecl, _identifier); handleVariableReference(*varDecl, _identifier);
else if (dynamic_cast<ContractDefinition const*>(declaration)) else if (dynamic_cast<ContractDefinition const*>(declaration))

View File

@ -70,6 +70,7 @@ public:
void endVisit(Return const& _return) override; void endVisit(Return const& _return) override;
void endVisit(UnaryOperation const& _unaryOperation) override; void endVisit(UnaryOperation const& _unaryOperation) override;
bool visit(BinaryOperation const& _binOp) override; bool visit(BinaryOperation const& _binOp) override;
bool visit(FunctionCall const& _funCall) override;
void endVisit(FunctionCall const& _funCall) override; void endVisit(FunctionCall const& _funCall) override;
void endVisit(FunctionCallOptions const& _funCallOptions) override; void endVisit(FunctionCallOptions const& _funCallOptions) override;
void endVisit(MemberAccess const& _memberAccess) override; void endVisit(MemberAccess const& _memberAccess) override;

View File

@ -0,0 +1,31 @@
contract Test {
bytes6 name;
constructor() public {
function (bytes6 _name) internal setter = setName;
setter("abcdef");
applyShift(leftByteShift, 3);
}
function getName() public returns (bytes6 ret) {
return name;
}
function setName(bytes6 _name) private {
name = _name;
}
function leftByteShift(bytes6 _value, uint _shift) public returns (bytes6) {
return _value << _shift * 8;
}
function applyShift(function (bytes6 _value, uint _shift) internal returns (bytes6) _shiftOperator, uint _bytes) internal {
name = _shiftOperator(name, _bytes);
}
}
// ====
// compileViaYul: also
// ----
// getName() -> "def\x00\x00\x00"

View File

@ -22,5 +22,7 @@ contract C {
} }
} }
// ====
// compileViaYul: also
// ---- // ----
// f(uint256[]): 0x20, 0x3, 0x1, 0x7, 0x3 -> 11 // f(uint256[]): 0x20, 0x3, 0x1, 0x7, 0x3 -> 11

View File

@ -0,0 +1,21 @@
contract A {
function f() internal virtual returns (uint256) {
return 1;
}
}
contract B is A {
function f() internal override returns (uint256) {
return 2;
}
function g() public returns (uint256) {
function() internal returns (uint256) ptr = A.f;
return ptr();
}
}
// ====
// compileViaYul: also
// ----
// g() -> 1

View File

@ -0,0 +1,17 @@
library L {
function f() internal returns (uint) {
return 66;
}
}
contract C {
function g() public returns (uint) {
function() internal returns(uint) ptr;
ptr = L.f;
return ptr();
}
}
// ====
// compileViaYul: also
// ----
// g() -> 66

View File

@ -0,0 +1,26 @@
contract Base {
function f() internal returns (uint256 i) {
function() internal returns (uint256) ptr = g;
return ptr();
}
function g() internal virtual returns (uint256 i) {
return 1;
}
}
contract Derived is Base {
function g() internal override returns (uint256 i) {
return 2;
}
function h() public returns (uint256 i) {
return f();
}
}
// ====
// compileViaYul: also
// ----
// h() -> 2