diff --git a/libsolidity/codegen/ir/IRGenerationContext.cpp b/libsolidity/codegen/ir/IRGenerationContext.cpp index 493ed7b53..b85e2f7d3 100644 --- a/libsolidity/codegen/ir/IRGenerationContext.cpp +++ b/libsolidity/codegen/ir/IRGenerationContext.cpp @@ -29,6 +29,8 @@ #include #include +#include + using namespace std; using namespace solidity; using namespace solidity::util; @@ -121,49 +123,55 @@ string IRGenerationContext::newYulVariable() return "_" + to_string(++m_varCounter); } -string IRGenerationContext::generateInternalDispatchFunction(YulArity const& _arity) +void IRGenerationContext::initializeInternalDispatch(InternalDispatchMap _internalDispatch) { - string funName = IRNames::internalDispatch(_arity); - return m_functions.createFunction(funName, [&]() { - Whiskers templ(R"( - function (fun, ) -> { - switch fun - <#cases> - case - { - := () - } - - default { invalid() } - } - )"); - templ("functionName", funName); - templ("in", suffixedVariableNameList("in_", 0, _arity.in)); - templ("out", suffixedVariableNameList("out_", 0, _arity.out)); - - vector> cases; - for (FunctionDefinition const* function: collectFunctionsOfArity(_arity)) - { - solAssert(function, ""); - solAssert( - YulArity::fromType(*TypeProvider::function(*function, FunctionType::Kind::Internal)) == _arity, - "A single dispatch function can only handle functions of one arity" - ); - solAssert(!function->isConstructor(), ""); - // 0 is reserved for uninitialized function pointers - solAssert(function->id() != 0, "Unexpected function ID: 0"); - - cases.emplace_back(map{ - {"funID", to_string(function->id())}, - {"name", IRNames::function(*function)} - }); + solAssert(internalDispatchClean(), ""); + for (set const& functions: _internalDispatch | boost::adaptors::map_values) + for (auto function: functions) enqueueFunctionForCodeGeneration(*function); - } - templ("cases", move(cases)); - return templ.render(); - }); + m_internalDispatchMap = move(_internalDispatch); +} + +InternalDispatchMap IRGenerationContext::consumeInternalDispatchMap() +{ + m_directInternalFunctionCalls.clear(); + + InternalDispatchMap internalDispatch = move(m_internalDispatchMap); + m_internalDispatchMap.clear(); + return internalDispatch; +} + +void IRGenerationContext::internalFunctionCalledDirectly(Expression const& _expression) +{ + solAssert(m_directInternalFunctionCalls.count(&_expression) == 0, ""); + + m_directInternalFunctionCalls.insert(&_expression); +} + +void IRGenerationContext::internalFunctionAccessed(Expression const& _expression, FunctionDefinition const& _function) +{ + solAssert( + IRHelpers::referencedFunctionDeclaration(_expression) && + _function.resolveVirtual(mostDerivedContract()) == + IRHelpers::referencedFunctionDeclaration(_expression)->resolveVirtual(mostDerivedContract()), + "Function definition does not match the expression" + ); + + if (m_directInternalFunctionCalls.count(&_expression) == 0) + { + FunctionType const* functionType = TypeProvider::function(_function, FunctionType::Kind::Internal); + solAssert(functionType, ""); + + m_internalDispatchMap[YulArity::fromType(*functionType)].insert(&_function); + enqueueFunctionForCodeGeneration(_function); + } +} + +void IRGenerationContext::internalFunctionCalledThroughDispatch(YulArity const& _arity) +{ + m_internalDispatchMap.try_emplace(_arity); } YulUtilFunctions IRGenerationContext::utils() @@ -180,21 +188,3 @@ std::string IRGenerationContext::revertReasonIfDebug(std::string const& _message { return YulUtilFunctions::revertReasonIfDebug(m_revertStrings, _message); } - -set IRGenerationContext::collectFunctionsOfArity(YulArity const& _arity) -{ - // UNIMPLEMENTED: Internal library calls via pointers are not implemented yet. - // We're not returning any internal library functions here even though it's possible - // to call them via pointers. Right now such calls end will up triggering the `default` case in - // the switch in the generated dispatch function. - set functions; - for (auto const& contract: mostDerivedContract().annotation().linearizedBaseContracts) - for (FunctionDefinition const* function: contract->definedFunctions()) - if ( - !function->isConstructor() && - YulArity::fromType(*TypeProvider::function(*function, FunctionType::Kind::Internal)) == _arity - ) - functions.insert(function); - - return functions; -} diff --git a/libsolidity/codegen/ir/IRGenerationContext.h b/libsolidity/codegen/ir/IRGenerationContext.h index 30aa5e220..d6d8022df 100644 --- a/libsolidity/codegen/ir/IRGenerationContext.h +++ b/libsolidity/codegen/ir/IRGenerationContext.h @@ -43,6 +43,8 @@ namespace solidity::frontend class YulUtilFunctions; class ABIFunctions; +using InternalDispatchMap = std::map>; + /** * Class that contains contextual information during IR generation. */ @@ -102,7 +104,26 @@ public: std::string newYulVariable(); - std::string generateInternalDispatchFunction(YulArity const& _arity); + void initializeInternalDispatch(InternalDispatchMap _internalDispatchMap); + InternalDispatchMap consumeInternalDispatchMap(); + bool internalDispatchClean() const { return m_internalDispatchMap.empty() && m_directInternalFunctionCalls.empty(); } + + /// Notifies the context that a function call that needs to go through internal dispatch was + /// encountered while visiting the AST. This ensures that the corresponding dispatch function + /// gets added to the dispatch map even if there are no entries in it (which may happen if + /// the code contains a call to an uninitialized function variable). + void internalFunctionCalledThroughDispatch(YulArity const& _arity); + + /// Notifies the context that a direct function call (i.e. not through internal dispatch) was + /// encountered while visiting the AST. This lets the context know that the function should + /// not be added to the dispatch (unless there are also indirect calls to it elsewhere else). + void internalFunctionCalledDirectly(Expression const& _expression); + + /// Notifies the context that a name representing an internal function has been found while + /// visiting the AST. If the name has not been reported as a direct call using + /// @a internalFunctionCalledDirectly(), it's assumed to represent function variable access + /// and the function gets added to internal dispatch. + void internalFunctionAccessed(Expression const& _expression, FunctionDefinition const& _function); /// @returns a new copy of the utility function generator (but using the same function set). YulUtilFunctions utils(); @@ -120,8 +141,6 @@ public: std::set& subObjectsCreated() { return m_subObjects; } private: - std::set collectFunctionsOfArity(YulArity const& _arity); - langutil::EVMVersion m_evmVersion; RevertStrings m_revertStrings; OptimiserSettings m_optimiserSettings; @@ -147,6 +166,13 @@ private: /// all platforms - which is a property guaranteed by MultiUseYulFunctionCollector. std::set m_functionGenerationQueue; + /// Collection of functions that need to be callable via internal dispatch. + /// Note that having a key with an empty set of functions is a valid situation. It means that + /// the code contains a call via a pointer even though a specific function is never assigned to it. + /// It will fail at runtime but the code must still compile. + InternalDispatchMap m_internalDispatchMap; + std::set m_directInternalFunctionCalls; + std::set m_subObjects; }; diff --git a/libsolidity/codegen/ir/IRGenerator.cpp b/libsolidity/codegen/ir/IRGenerator.cpp index 69aad0af6..1bb6c7089 100644 --- a/libsolidity/codegen/ir/IRGenerator.cpp +++ b/libsolidity/codegen/ir/IRGenerator.cpp @@ -38,6 +38,8 @@ #include +#include + #include using namespace std; @@ -137,14 +139,22 @@ string IRGenerator::generate( t("deploy", deployCode(_contract)); generateImplicitConstructors(_contract); generateQueuedFunctions(); + InternalDispatchMap internalDispatchMap = generateInternalDispatchFunctions(); t("functions", m_context.functionCollector().requestedFunctions()); t("subObjects", subObjectSources(m_context.subObjectsCreated())); resetContext(_contract); + + // NOTE: Function pointers can be passed from creation code via storage variables. We need to + // get all the functions they could point to into the dispatch functions even if they're never + // referenced by name in the runtime code. + m_context.initializeInternalDispatch(move(internalDispatchMap)); + // Do not register immutables to avoid assignment. t("RuntimeObject", IRNames::runtimeObject(_contract)); t("dispatch", dispatchRoutine(_contract)); generateQueuedFunctions(); + generateInternalDispatchFunctions(); t("runtimeFunctions", m_context.functionCollector().requestedFunctions()); t("runtimeSubObjects", subObjectSources(m_context.subObjectsCreated())); return t.render(); @@ -164,6 +174,68 @@ void IRGenerator::generateQueuedFunctions() generateFunction(*m_context.dequeueFunctionForCodeGeneration()); } +InternalDispatchMap IRGenerator::generateInternalDispatchFunctions() +{ + solAssert( + m_context.functionGenerationQueueEmpty(), + "At this point all the enqueued functions should have been generated. " + "Otherwise the dispatch may be incomplete." + ); + + InternalDispatchMap internalDispatchMap = m_context.consumeInternalDispatchMap(); + for (YulArity const& arity: internalDispatchMap | boost::adaptors::map_keys) + { + string funName = IRNames::internalDispatch(arity); + m_context.functionCollector().createFunction(funName, [&]() { + Whiskers templ(R"( + function (fun, ) -> { + switch fun + <#cases> + case + { + := () + } + + default { invalid() } + } + )"); + templ("functionName", funName); + templ("in", suffixedVariableNameList("in_", 0, arity.in)); + templ("out", suffixedVariableNameList("out_", 0, arity.out)); + + vector> cases; + for (FunctionDefinition const* function: internalDispatchMap.at(arity)) + { + solAssert(function, ""); + solAssert( + YulArity::fromType(*TypeProvider::function(*function, FunctionType::Kind::Internal)) == arity, + "A single dispatch function can only handle functions of one arity" + ); + solAssert(!function->isConstructor(), ""); + // 0 is reserved for uninitialized function pointers + solAssert(function->id() != 0, "Unexpected function ID: 0"); + solAssert(m_context.functionCollector().contains(IRNames::function(*function)), ""); + + cases.emplace_back(map{ + {"funID", to_string(function->id())}, + {"name", IRNames::function(*function)} + }); + } + + templ("cases", move(cases)); + return templ.render(); + }); + } + + solAssert(m_context.internalDispatchClean(), ""); + solAssert( + m_context.functionGenerationQueueEmpty(), + "Internal dispatch generation must not add new functions to generation queue because they won't be proeessed." + ); + + return internalDispatchMap; +} + string IRGenerator::generateFunction(FunctionDefinition const& _function) { string functionName = IRNames::function(_function); @@ -556,6 +628,10 @@ void IRGenerator::resetContext(ContractDefinition const& _contract) m_context.functionCollector().requestedFunctions().empty(), "Reset context while it still had functions." ); + solAssert( + m_context.internalDispatchClean(), + "Reset internal dispatch map without consuming it." + ); m_context = IRGenerationContext(m_evmVersion, m_context.revertStrings(), m_optimiserSettings); m_context.setMostDerivedContract(_contract); diff --git a/libsolidity/codegen/ir/IRGenerator.h b/libsolidity/codegen/ir/IRGenerator.h index bff0c0739..6bf94e575 100644 --- a/libsolidity/codegen/ir/IRGenerator.h +++ b/libsolidity/codegen/ir/IRGenerator.h @@ -65,6 +65,11 @@ private: /// Generates code for all the functions from the function generation queue. /// The resulting code is stored in the function collector in IRGenerationContext. void generateQueuedFunctions(); + /// Generates all the internal dispatch functions necessary to handle any function that could + /// possibly be called via a pointer. + /// @return The content of the dispatch for reuse in runtime code. Reuse is necessary because + /// pointers to functions can be passed from the creation code in storage variables. + InternalDispatchMap generateInternalDispatchFunctions(); /// Generates code for and returns the name of the function. std::string generateFunction(FunctionDefinition const& _function); /// Generates a getter for the given declaration and returns its name diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index 5472ec7ed..98c0b2169 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -583,6 +583,20 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp) return false; } +bool IRGeneratorForStatements::visit(FunctionCall const& _functionCall) +{ + FunctionTypePointer functionType = dynamic_cast(&type(_functionCall.expression())); + if ( + functionType && + functionType->kind() == FunctionType::Kind::Internal && + !functionType->bound() && + IRHelpers::referencedFunctionDeclaration(_functionCall.expression()) + ) + m_context.internalFunctionCalledDirectly(_functionCall.expression()); + + return true; +} + void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall) { solUnimplementedAssert( @@ -688,9 +702,10 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall) else { YulArity arity = YulArity::fromType(*functionType); + m_context.internalFunctionCalledThroughDispatch(arity); + define(_functionCall) << - // NOTE: generateInternalDispatchFunction() takes care of adding the function to function generation queue - m_context.generateInternalDispatchFunction(arity) << + IRNames::internalDispatch(arity) << "(" << IRVariable(_functionCall.expression()).part("functionIdentifier").name() << joinHumanReadablePrefixed(args) << @@ -1492,7 +1507,10 @@ void IRGeneratorForStatements::endVisit(MemberAccess const& _memberAccess) break; case FunctionType::Kind::Internal: if (auto const* function = dynamic_cast(_memberAccess.annotation().referencedDeclaration)) + { define(_memberAccess) << to_string(function->id()) << "\n"; + m_context.internalFunctionAccessed(_memberAccess, *function); + } else solAssert(false, "Function not found in member access"); break; @@ -1756,7 +1774,14 @@ void IRGeneratorForStatements::endVisit(Identifier const& _identifier) return; } else if (FunctionDefinition const* functionDef = dynamic_cast(declaration)) - define(_identifier) << to_string(functionDef->resolveVirtual(m_context.mostDerivedContract()).id()) << "\n"; + { + FunctionDefinition const& resolvedFunctionDef = functionDef->resolveVirtual(m_context.mostDerivedContract()); + define(_identifier) << to_string(resolvedFunctionDef.id()) << "\n"; + + solAssert(resolvedFunctionDef.functionType(true), ""); + solAssert(resolvedFunctionDef.functionType(true)->kind() == FunctionType::Kind::Internal, ""); + m_context.internalFunctionAccessed(_identifier, resolvedFunctionDef); + } else if (VariableDeclaration const* varDecl = dynamic_cast(declaration)) handleVariableReference(*varDecl, _identifier); else if (dynamic_cast(declaration)) diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.h b/libsolidity/codegen/ir/IRGeneratorForStatements.h index 42d355ddd..bdf21f783 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.h +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.h @@ -70,6 +70,7 @@ public: void endVisit(Return const& _return) override; void endVisit(UnaryOperation const& _unaryOperation) override; bool visit(BinaryOperation const& _binOp) override; + bool visit(FunctionCall const& _funCall) override; void endVisit(FunctionCall const& _funCall) override; void endVisit(FunctionCallOptions const& _funCallOptions) override; void endVisit(MemberAccess const& _memberAccess) override; diff --git a/test/libsolidity/semanticTests/constructor/functions_called_by_constructor_through_dispatch.sol b/test/libsolidity/semanticTests/constructor/functions_called_by_constructor_through_dispatch.sol new file mode 100644 index 000000000..56f92cd50 --- /dev/null +++ b/test/libsolidity/semanticTests/constructor/functions_called_by_constructor_through_dispatch.sol @@ -0,0 +1,31 @@ +contract Test { + bytes6 name; + + constructor() public { + function (bytes6 _name) internal setter = setName; + setter("abcdef"); + + applyShift(leftByteShift, 3); + } + + function getName() public returns (bytes6 ret) { + return name; + } + + function setName(bytes6 _name) private { + name = _name; + } + + function leftByteShift(bytes6 _value, uint _shift) public returns (bytes6) { + return _value << _shift * 8; + } + + function applyShift(function (bytes6 _value, uint _shift) internal returns (bytes6) _shiftOperator, uint _bytes) internal { + name = _shiftOperator(name, _bytes); + } +} + +// ==== +// compileViaYul: also +// ---- +// getName() -> "def\x00\x00\x00" diff --git a/test/libsolidity/semanticTests/functionTypes/function_type_library_internal.sol b/test/libsolidity/semanticTests/functionTypes/function_type_library_internal.sol index f096a4979..c4554740d 100644 --- a/test/libsolidity/semanticTests/functionTypes/function_type_library_internal.sol +++ b/test/libsolidity/semanticTests/functionTypes/function_type_library_internal.sol @@ -22,5 +22,7 @@ contract C { } } +// ==== +// compileViaYul: also // ---- // f(uint256[]): 0x20, 0x3, 0x1, 0x7, 0x3 -> 11 diff --git a/test/libsolidity/semanticTests/intheritance/inherited_function_through_dispatch.sol b/test/libsolidity/semanticTests/intheritance/inherited_function_through_dispatch.sol new file mode 100644 index 000000000..a7aa1fcb8 --- /dev/null +++ b/test/libsolidity/semanticTests/intheritance/inherited_function_through_dispatch.sol @@ -0,0 +1,21 @@ +contract A { + function f() internal virtual returns (uint256) { + return 1; + } +} + + +contract B is A { + function f() internal override returns (uint256) { + return 2; + } + + function g() public returns (uint256) { + function() internal returns (uint256) ptr = A.f; + return ptr(); + } +} +// ==== +// compileViaYul: also +// ---- +// g() -> 1 diff --git a/test/libsolidity/semanticTests/libraries/internal_library_function_pointer.sol b/test/libsolidity/semanticTests/libraries/internal_library_function_pointer.sol new file mode 100644 index 000000000..bc2f2da0a --- /dev/null +++ b/test/libsolidity/semanticTests/libraries/internal_library_function_pointer.sol @@ -0,0 +1,17 @@ +library L { + function f() internal returns (uint) { + return 66; + } +} + +contract C { + function g() public returns (uint) { + function() internal returns(uint) ptr; + ptr = L.f; + return ptr(); + } +} +// ==== +// compileViaYul: also +// ---- +// g() -> 66 diff --git a/test/libsolidity/semanticTests/virtualFunctions/internal_virtual_function_calls_through_dispatch.sol b/test/libsolidity/semanticTests/virtualFunctions/internal_virtual_function_calls_through_dispatch.sol new file mode 100644 index 000000000..f3bc0845f --- /dev/null +++ b/test/libsolidity/semanticTests/virtualFunctions/internal_virtual_function_calls_through_dispatch.sol @@ -0,0 +1,26 @@ +contract Base { + function f() internal returns (uint256 i) { + function() internal returns (uint256) ptr = g; + return ptr(); + } + + function g() internal virtual returns (uint256 i) { + return 1; + } +} + + +contract Derived is Base { + function g() internal override returns (uint256 i) { + return 2; + } + + function h() public returns (uint256 i) { + return f(); + } +} + +// ==== +// compileViaYul: also +// ---- +// h() -> 2