diff --git a/libyul/backends/wasm/WasmCodeTransform.cpp b/libyul/backends/wasm/WasmCodeTransform.cpp index 8de291402..d04d84c0e 100644 --- a/libyul/backends/wasm/WasmCodeTransform.cpp +++ b/libyul/backends/wasm/WasmCodeTransform.cpp @@ -20,6 +20,8 @@ #include +#include + #include #include @@ -40,7 +42,8 @@ wasm::Module WasmCodeTransform::run(Dialect const& _dialect, yul::Block const& _ { wasm::Module module; - WasmCodeTransform transform(_dialect, _ast); + TypeInfo typeInfo(_dialect, _ast); + WasmCodeTransform transform(_dialect, _ast, typeInfo); for (auto const& statement: _ast.statements) { @@ -70,14 +73,18 @@ wasm::Expression WasmCodeTransform::generateMultiAssignment( if (_variableNames.size() == 1) return { std::move(assignment) }; - allocateGlobals(_variableNames.size() - 1); + vector typesForGlobals; + for (size_t i = 1; i < _variableNames.size(); ++i) + typesForGlobals.push_back(translatedType(m_typeInfo.typeOfVariable(YulString(_variableNames[i])))); + vector allocatedIndices = allocateGlobals(typesForGlobals); + yulAssert(allocatedIndices.size() == _variableNames.size() - 1, ""); wasm::Block block; block.statements.emplace_back(move(assignment)); for (size_t i = 1; i < _variableNames.size(); ++i) block.statements.emplace_back(wasm::LocalAssignment{ move(_variableNames.at(i)), - make_unique(wasm::GlobalVariable{m_globalVariables.at(i - 1).variableName}) + make_unique(wasm::GlobalVariable{m_globalVariables.at(allocatedIndices[i - 1]).variableName}) }); return { std::move(block) }; } @@ -88,7 +95,7 @@ wasm::Expression WasmCodeTransform::operator()(VariableDeclaration const& _varDe for (auto const& var: _varDecl.variables) { variableNames.emplace_back(var.name.str()); - m_localVariables.emplace_back(wasm::VariableDeclaration{variableNames.back(), wasm::Type::i64}); + m_localVariables.emplace_back(wasm::VariableDeclaration{variableNames.back(), translatedType(var.type)}); } if (_varDecl.value) @@ -165,20 +172,21 @@ wasm::Expression WasmCodeTransform::operator()(Identifier const& _identifier) wasm::Expression WasmCodeTransform::operator()(Literal const& _literal) { - u256 value = valueOfLiteral(_literal); - yulAssert(value <= numeric_limits::max(), "Literal too large: " + value.str()); - return wasm::Literal{static_cast(value)}; + return makeLiteral(translatedType(_literal.type), valueOfLiteral(_literal)); } wasm::Expression WasmCodeTransform::operator()(If const& _if) { - // TODO converting i64 to i32 might not always be needed. + yul::Type conditionType = m_typeInfo.typeOf(*_if.condition); + YulString ne_instruction = YulString(conditionType.str() + ".ne"); + yulAssert(WasmDialect::instance().builtin(ne_instruction), ""); + // TODO converting i64 to i32 might not always be needed. vector args; args.emplace_back(visitReturnByValue(*_if.condition)); - args.emplace_back(wasm::Literal{static_cast(0)}); + args.emplace_back(makeLiteral(translatedType(conditionType), 0)); return wasm::If{ - make_unique(wasm::BuiltinCall{"i64.ne", std::move(args)}), + make_unique(wasm::BuiltinCall{ne_instruction.str(), std::move(args)}), visit(_if.body.statements), {} }; @@ -186,9 +194,13 @@ wasm::Expression WasmCodeTransform::operator()(If const& _if) wasm::Expression WasmCodeTransform::operator()(Switch const& _switch) { + yul::Type expressionType = m_typeInfo.typeOf(*_switch.expression); + YulString eq_instruction = YulString(expressionType.str() + ".eq"); + yulAssert(WasmDialect::instance().builtin(eq_instruction), ""); + wasm::Block block; string condition = m_nameDispenser.newName("condition"_yulstring).str(); - m_localVariables.emplace_back(wasm::VariableDeclaration{condition, wasm::Type::i64}); + m_localVariables.emplace_back(wasm::VariableDeclaration{condition, translatedType(expressionType)}); block.statements.emplace_back(wasm::LocalAssignment{condition, visit(*_switch.expression)}); vector* currentBlock = &block.statements; @@ -197,7 +209,7 @@ wasm::Expression WasmCodeTransform::operator()(Switch const& _switch) Case const& c = _switch.cases.at(i); if (c.value) { - wasm::BuiltinCall comparison{"i64.eq", make_vector( + wasm::BuiltinCall comparison{eq_instruction.str(), make_vector( wasm::LocalVariable{condition}, visitReturnByValue(*c.value) )}; @@ -236,11 +248,15 @@ wasm::Expression WasmCodeTransform::operator()(ForLoop const& _for) string continueLabel = newLabel(); m_breakContinueLabelNames.push({breakLabel, continueLabel}); + yul::Type conditionType = m_typeInfo.typeOf(*_for.condition); + YulString eqz_instruction = YulString(conditionType.str() + ".eqz"); + yulAssert(WasmDialect::instance().builtin(eqz_instruction), ""); + wasm::Loop loop; loop.labelName = newLabel(); loop.statements = visit(_for.pre.statements); loop.statements.emplace_back(wasm::BranchIf{wasm::Label{breakLabel}, make_unique( - wasm::BuiltinCall{"i64.eqz", make_vector( + wasm::BuiltinCall{eqz_instruction.str(), make_vector( visitReturnByValue(*_for.condition) )} )}); @@ -308,11 +324,11 @@ wasm::FunctionDefinition WasmCodeTransform::translateFunction(yul::FunctionDefin wasm::FunctionDefinition fun; fun.name = _fun.name.str(); for (auto const& param: _fun.parameters) - fun.parameters.push_back({param.name.str(), wasm::Type::i64}); + fun.parameters.push_back({param.name.str(), translatedType(param.type)}); for (auto const& retParam: _fun.returnVariables) - fun.locals.emplace_back(wasm::VariableDeclaration{retParam.name.str(), wasm::Type::i64}); + fun.locals.emplace_back(wasm::VariableDeclaration{retParam.name.str(), translatedType(retParam.type)}); if (!_fun.returnVariables.empty()) - fun.returnType = wasm::Type::i64; + fun.returnType = translatedType(_fun.returnVariables[0].type); yulAssert(m_localVariables.empty(), ""); yulAssert(m_functionBodyLabel.empty(), ""); @@ -330,10 +346,15 @@ wasm::FunctionDefinition WasmCodeTransform::translateFunction(yul::FunctionDefin { // First return variable is returned directly, the others are stored // in globals. - allocateGlobals(_fun.returnVariables.size() - 1); + vector typesForGlobals; + for (size_t i = 1; i < _fun.returnVariables.size(); ++i) + typesForGlobals.push_back(translatedType(_fun.returnVariables[i].type)); + vector allocatedIndices = allocateGlobals(typesForGlobals); + yulAssert(allocatedIndices.size() == _fun.returnVariables.size() - 1, ""); + for (size_t i = 1; i < _fun.returnVariables.size(); ++i) fun.body.emplace_back(wasm::GlobalAssignment{ - m_globalVariables.at(i - 1).variableName, + m_globalVariables.at(allocatedIndices[i - 1]).variableName, make_unique(wasm::LocalVariable{_fun.returnVariables.at(i).name.str()}) }); fun.body.emplace_back(wasm::LocalVariable{_fun.returnVariables.front().name.str()}); @@ -346,13 +367,45 @@ string WasmCodeTransform::newLabel() return m_nameDispenser.newName("label_"_yulstring).str(); } -void WasmCodeTransform::allocateGlobals(size_t _amount) +vector WasmCodeTransform::allocateGlobals(vector const& _typesForGlobals) { - while (m_globalVariables.size() < _amount) - m_globalVariables.emplace_back(wasm::GlobalVariableDeclaration{ - m_nameDispenser.newName("global_"_yulstring).str(), - wasm::Type::i64 - }); + map availableGlobals; + for (wasm::GlobalVariableDeclaration const& global: m_globalVariables) + ++availableGlobals[global.type]; + + map neededGlobals; + for (wasm::Type const& type: _typesForGlobals) + ++neededGlobals[type]; + + for (auto [type, neededGlobalCount]: neededGlobals) + while (availableGlobals[type] < neededGlobalCount) + { + m_globalVariables.emplace_back(wasm::GlobalVariableDeclaration{ + m_nameDispenser.newName("global_"_yulstring).str(), + type, + }); + + ++availableGlobals[type]; + } + + vector allocatedIndices; + map nextGlobal; + for (wasm::Type const& type: _typesForGlobals) + { + while (m_globalVariables[nextGlobal[type]].type != type) + ++nextGlobal[type]; + + allocatedIndices.push_back(nextGlobal[type]++); + } + + yulAssert(all_of( + allocatedIndices.begin(), + allocatedIndices.end(), + [this](size_t index){ return index < m_globalVariables.size(); } + ), ""); + yulAssert(allocatedIndices.size() == set(allocatedIndices.begin(), allocatedIndices.end()).size(), "Indices not unique"); + yulAssert(allocatedIndices.size() == _typesForGlobals.size(), ""); + return allocatedIndices; } wasm::Type WasmCodeTransform::translatedType(yul::Type _yulType) @@ -364,3 +417,19 @@ wasm::Type WasmCodeTransform::translatedType(yul::Type _yulType) else yulAssert(false, "This Yul type does not have a corresponding type in Wasm."); } + +wasm::Literal WasmCodeTransform::makeLiteral(wasm::Type _type, u256 _value) +{ + if (_type == wasm::Type::i32) + { + yulAssert(_value <= numeric_limits::max(), "Literal too large: " + _value.str()); + return wasm::Literal{static_cast(_value)}; + } + else if (_type == wasm::Type::i64) + { + yulAssert(_value <= numeric_limits::max(), "Literal too large: " + _value.str()); + return wasm::Literal{static_cast(_value)}; + } + else + yulAssert(false, "Invalid Wasm literal type"); +} diff --git a/libyul/backends/wasm/WasmCodeTransform.h b/libyul/backends/wasm/WasmCodeTransform.h index 68aeb86a3..326778b3d 100644 --- a/libyul/backends/wasm/WasmCodeTransform.h +++ b/libyul/backends/wasm/WasmCodeTransform.h @@ -24,6 +24,9 @@ #include #include #include +#include + +#include #include #include @@ -56,10 +59,12 @@ public: private: WasmCodeTransform( Dialect const& _dialect, - Block const& _ast + Block const& _ast, + TypeInfo& _typeInfo ): m_dialect(_dialect), - m_nameDispenser(_dialect, _ast) + m_nameDispenser(_dialect, _ast), + m_typeInfo(_typeInfo) {} std::unique_ptr visit(yul::Expression const& _expression); @@ -80,10 +85,12 @@ private: wasm::FunctionDefinition translateFunction(yul::FunctionDefinition const& _funDef); std::string newLabel(); - /// Makes sure that there are at least @a _amount global variables. - void allocateGlobals(size_t _amount); + /// Selects a subset of global variables matching specified sequence of variable types. + /// Defines more global variables of a given type if there's not enough. + std::vector allocateGlobals(std::vector const& _typesForGlobals); static wasm::Type translatedType(yul::Type _yulType); + static wasm::Literal makeLiteral(wasm::Type _type, u256 _value); Dialect const& m_dialect; NameDispenser m_nameDispenser; @@ -93,6 +100,7 @@ private: std::map m_functionsToImport; std::string m_functionBodyLabel; std::stack> m_breakContinueLabelNames; + TypeInfo& m_typeInfo; }; }