diff --git a/test/libyul/yulInterpreterTests/function_scopes.yul b/test/libyul/yulInterpreterTests/function_scopes.yul new file mode 100644 index 000000000..0150fa990 --- /dev/null +++ b/test/libyul/yulInterpreterTests/function_scopes.yul @@ -0,0 +1,20 @@ +{ + f(1) + function f(i) { + if i { g(1) } + function g(j) { + if j { h() } + f(0) + function h() { + g(0) + } + } + sstore(i, add(i, 7)) + } +} +// ---- +// Trace: +// Memory dump: +// Storage dump: +// 0000000000000000000000000000000000000000000000000000000000000000: 0000000000000000000000000000000000000000000000000000000000000007 +// 0000000000000000000000000000000000000000000000000000000000000001: 0000000000000000000000000000000000000000000000000000000000000008 diff --git a/test/libyul/yulInterpreterTests/recursion.yul b/test/libyul/yulInterpreterTests/recursion.yul new file mode 100644 index 000000000..0d80faf30 --- /dev/null +++ b/test/libyul/yulInterpreterTests/recursion.yul @@ -0,0 +1,14 @@ +{ + function fib(i) -> y { + y := 1 + if gt(i, 2) { + y := add(fib(sub(i, 1)), fib(sub(i, 2))) + } + } + sstore(0, fib(8)) +} +// ---- +// Trace: +// Memory dump: +// Storage dump: +// 0000000000000000000000000000000000000000000000000000000000000000: 0000000000000000000000000000000000000000000000000000000000000015 diff --git a/test/tools/yulInterpreter/Interpreter.cpp b/test/tools/yulInterpreter/Interpreter.cpp index 73c17a6c3..9089e6df5 100644 --- a/test/tools/yulInterpreter/Interpreter.cpp +++ b/test/tools/yulInterpreter/Interpreter.cpp @@ -90,7 +90,8 @@ void Interpreter::operator()(VariableDeclaration const& _declaration) YulString varName = _declaration.variables.at(i).name; solAssert(!m_variables.count(varName), ""); m_variables[varName] = values.at(i); - m_scopes.back().insert(varName); + solAssert(!m_scopes.back().count(varName), ""); + m_scopes.back().emplace(varName, nullptr); } } @@ -164,8 +165,8 @@ void Interpreter::operator()(Block const& _block) if (statement.type() == typeid(FunctionDefinition)) { FunctionDefinition const& funDef = boost::get(statement); - m_functions[funDef.name] = &funDef; - m_scopes.back().insert(funDef.name); + solAssert(!m_scopes.back().count(funDef.name), ""); + m_scopes.back().emplace(funDef.name, &funDef); } for (auto const& statement: _block.statements) @@ -180,25 +181,23 @@ void Interpreter::operator()(Block const& _block) u256 Interpreter::evaluate(Expression const& _expression) { - ExpressionEvaluator ev(m_state, m_dialect, m_variables, m_functions, m_scopes); + ExpressionEvaluator ev(m_state, m_dialect, m_variables, m_scopes); ev.visit(_expression); return ev.value(); } vector Interpreter::evaluateMulti(Expression const& _expression) { - ExpressionEvaluator ev(m_state, m_dialect, m_variables, m_functions, m_scopes); + ExpressionEvaluator ev(m_state, m_dialect, m_variables, m_scopes); ev.visit(_expression); return ev.values(); } void Interpreter::closeScope() { - for (auto const& var: m_scopes.back()) - { - size_t erased = m_variables.erase(var) + m_functions.erase(var); - solAssert(erased == 1, ""); - } + for (auto const& [var, funDeclaration]: m_scopes.back()) + if (!funDeclaration) + solAssert(m_variables.erase(var) == 1, ""); m_scopes.pop_back(); } @@ -237,22 +236,21 @@ void ExpressionEvaluator::operator()(FunctionCall const& _funCall) return; } - solAssert(m_functions.count(_funCall.functionName.name), ""); - FunctionDefinition const& fun = *m_functions.at(_funCall.functionName.name); - solAssert(m_values.size() == fun.parameters.size(), ""); - map variables; - for (size_t i = 0; i < fun.parameters.size(); ++i) - variables[fun.parameters.at(i).name] = m_values.at(i); - for (size_t i = 0; i < fun.returnVariables.size(); ++i) - variables[fun.returnVariables.at(i).name] = 0; + auto [functionScopes, fun] = findFunctionAndScope(_funCall.functionName.name); - // TODO function name lookup could be a little more efficient, - // we have to copy the list here. - Interpreter interpreter(m_state, m_dialect, variables, visibleFunctionsFor(fun.name)); - interpreter(fun.body); + solAssert(fun, "Function not found."); + solAssert(m_values.size() == fun->parameters.size(), ""); + map variables; + for (size_t i = 0; i < fun->parameters.size(); ++i) + variables[fun->parameters.at(i).name] = m_values.at(i); + for (size_t i = 0; i < fun->returnVariables.size(); ++i) + variables[fun->returnVariables.at(i).name] = 0; + + Interpreter interpreter(m_state, m_dialect, variables, functionScopes); + interpreter(fun->body); m_values.clear(); - for (auto const& retVar: fun.returnVariables) + for (auto const& retVar: fun->returnVariables) m_values.emplace_back(interpreter.valueOfVariable(retVar.name)); } @@ -281,19 +279,26 @@ void ExpressionEvaluator::evaluateArgs(vector const& _expr) std::reverse(m_values.begin(), m_values.end()); } -std::map ExpressionEvaluator::visibleFunctionsFor(YulString const& _name) +pair< + vector>, + FunctionDefinition const* +> ExpressionEvaluator::findFunctionAndScope(YulString _functionName) const { - std::map functions; - + FunctionDefinition const* fun = nullptr; + std::vector> newScopes; for (auto const& scope: m_scopes) { - for (auto const& symbol: scope) - if (m_functions.count(symbol) > 0) - functions[symbol] = m_functions.at(symbol); - - if (scope.count(_name)) + // Copy over all functions. + newScopes.push_back({}); + for (auto const& [name, funDef]: scope) + if (funDef) + newScopes.back().emplace(name, funDef); + // Stop at the called function. + if (scope.count(_functionName)) + { + fun = scope.at(_functionName); break; + } } - - return functions; + return {move(newScopes), fun}; } diff --git a/test/tools/yulInterpreter/Interpreter.h b/test/tools/yulInterpreter/Interpreter.h index d7c6b1f76..92e00b42f 100644 --- a/test/tools/yulInterpreter/Interpreter.h +++ b/test/tools/yulInterpreter/Interpreter.h @@ -106,12 +106,12 @@ public: InterpreterState& _state, Dialect const& _dialect, std::map _variables = {}, - std::map _functions = {} + std::vector> _scopes = {} ): m_dialect(_dialect), m_state(_state), m_variables(std::move(_variables)), - m_functions(std::move(_functions)) + m_scopes(std::move(_scopes)) {} void operator()(ExpressionStatement const& _statement) override; @@ -136,17 +136,17 @@ private: std::vector evaluateMulti(Expression const& _expression); void openScope() { m_scopes.push_back({}); } - /// Unregisters variables. + /// Unregisters variables and functions. void closeScope(); Dialect const& m_dialect; InterpreterState& m_state; /// Values of variables. std::map m_variables; - /// Meanings of functions. - std::map m_functions; - /// Scopes of variables and functions, used to clear them at end of blocks. - std::vector> m_scopes; + /// Scopes of variables and functions. Used for lookup, clearing at end of blocks + /// and passing over the visible functions across function calls. + /// The pointer is nullptr if and only if the key is a variable. + std::vector> m_scopes; }; /** @@ -159,13 +159,11 @@ public: InterpreterState& _state, Dialect const& _dialect, std::map const& _variables, - std::map const& _functions, - std::vector> const& _scopes + std::vector> const& _scopes ): m_state(_state), m_dialect(_dialect), m_variables(_variables), - m_functions(_functions), m_scopes(_scopes) {} @@ -186,16 +184,19 @@ private: /// stores it in m_value. void evaluateArgs(std::vector const& _expr); - /// Extracts functions from the earlier scopes that are visible for the given function - std::map visibleFunctionsFor(YulString const& _name); + /// Finds the function called @a _functionName in the current scope stack and returns + /// the function's scope stack (with variables removed) and definition. + std::pair< + std::vector>, + FunctionDefinition const* + > findFunctionAndScope(YulString _functionName) const; InterpreterState& m_state; Dialect const& m_dialect; /// Values of variables. std::map const& m_variables; - /// Meanings of functions. - std::map const& m_functions; - std::vector> const& m_scopes; + /// Stack of scopes in the current context. + std::vector> const& m_scopes; /// Current value of the expression std::vector m_values; };