diff --git a/libyul/backends/wasm/BinaryTransform.cpp b/libyul/backends/wasm/BinaryTransform.cpp index 1cf825dd4..ce97225e0 100644 --- a/libyul/backends/wasm/BinaryTransform.cpp +++ b/libyul/backends/wasm/BinaryTransform.cpp @@ -24,6 +24,7 @@ #include #include +#include using namespace std; using namespace solidity; @@ -130,7 +131,7 @@ bytes toBytes(Opcode _o) return toBytes(uint8_t(_o)); } -static std::map const builtins = { +static map const builtins = { {"i32.load", 0x28}, {"i64.load", 0x29}, {"i32.load8_s", 0x2c}, @@ -240,46 +241,56 @@ bytes lebEncodeSigned(int64_t _n) bytes prefixSize(bytes _data) { size_t size = _data.size(); - return lebEncode(size) + std::move(_data); + return lebEncode(size) + move(_data); } bytes makeSection(Section _section, bytes _data) { - return toBytes(_section) + prefixSize(std::move(_data)); + return toBytes(_section) + prefixSize(move(_data)); } } bytes BinaryTransform::run(Module const& _module) { - BinaryTransform bt; + map> const types = typeToFunctionMap(_module.imports, _module.functions); - for (size_t i = 0; i < _module.globals.size(); ++i) - bt.m_globals[_module.globals[i].variableName] = i; + map const globalIDs = enumerateGlobals(_module); + map const functionIDs = enumerateFunctions(_module); + map const functionTypes = enumerateFunctionTypes(types); - size_t funID = 0; - for (FunctionImport const& fun: _module.imports) - bt.m_functions[fun.internalName] = funID++; - for (FunctionDefinition const& fun: _module.functions) - bt.m_functions[fun.name] = funID++; + yulAssert(globalIDs.size() == _module.globals.size(), ""); + yulAssert(functionIDs.size() == _module.imports.size() + _module.functions.size(), ""); + yulAssert(functionTypes.size() == functionIDs.size(), ""); + yulAssert(functionTypes.size() >= types.size(), ""); bytes ret{0, 'a', 's', 'm'}; // version ret += bytes{1, 0, 0, 0}; - ret += bt.typeSection(_module.imports, _module.functions); - ret += bt.importSection(_module.imports); - ret += bt.functionSection(_module.functions); - ret += bt.memorySection(); - ret += bt.globalSection(); - ret += bt.exportSection(); + ret += typeSection(types); + ret += importSection(_module.imports, functionTypes); + ret += functionSection(_module.functions, functionTypes); + ret += memorySection(); + ret += globalSection(_module.globals); + ret += exportSection(functionIDs); + + map> subModulePosAndSize; for (auto const& sub: _module.subModules) { // TODO should we prefix and / or shorten the name? bytes data = BinaryTransform::run(sub.second); size_t length = data.size(); - ret += bt.customSection(sub.first, std::move(data)); - bt.m_subModulePosAndSize[sub.first] = {ret.size() - length, length}; + ret += customSection(sub.first, move(data)); + subModulePosAndSize[sub.first] = {ret.size() - length, length}; } + + BinaryTransform bt( + move(globalIDs), + move(functionIDs), + move(functionTypes), + move(subModulePosAndSize) + ); + ret += bt.codeSection(_module.functions); return ret; } @@ -302,7 +313,7 @@ bytes BinaryTransform::operator()(LocalVariable const& _variable) bytes BinaryTransform::operator()(GlobalVariable const& _variable) { - return toBytes(Opcode::GlobalGet) + lebEncode(m_globals.at(_variable.name)); + return toBytes(Opcode::GlobalGet) + lebEncode(m_globalIDs.at(_variable.name)); } bytes BinaryTransform::operator()(BuiltinCall const& _call) @@ -311,12 +322,12 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call) // they are references to object names that should not end up in the code. if (_call.functionName == "dataoffset") { - string name = std::get(_call.arguments.at(0)).value; + string name = get(_call.arguments.at(0)).value; return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).first); } else if (_call.functionName == "datasize") { - string name = std::get(_call.arguments.at(0)).value; + string name = get(_call.arguments.at(0)).value; return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).second); } @@ -331,7 +342,7 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call) else { yulAssert(builtins.count(_call.functionName), "Builtin " + _call.functionName + " not found"); - bytes ret = std::move(args) + toBytes(builtins.at(_call.functionName)); + bytes ret = move(args) + toBytes(builtins.at(_call.functionName)); if ( _call.functionName.find(".load") != string::npos || _call.functionName.find(".store") != string::npos @@ -348,7 +359,7 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call) bytes BinaryTransform::operator()(FunctionCall const& _call) { - return visit(_call.arguments) + toBytes(Opcode::Call) + lebEncode(m_functions.at(_call.functionName)); + return visit(_call.arguments) + toBytes(Opcode::Call) + lebEncode(m_functionIDs.at(_call.functionName)); } bytes BinaryTransform::operator()(LocalAssignment const& _assignment) @@ -364,7 +375,7 @@ bytes BinaryTransform::operator()(GlobalAssignment const& _assignment) return std::visit(*this, *_assignment.value) + toBytes(Opcode::GlobalSet) + - lebEncode(m_globals.at(_assignment.variableName)); + lebEncode(m_globalIDs.at(_assignment.variableName)); } bytes BinaryTransform::operator()(If const& _if) @@ -457,7 +468,7 @@ bytes BinaryTransform::operator()(FunctionDefinition const& _function) yulAssert(m_labels.empty(), "Stray labels."); - return prefixSize(std::move(ret)); + return prefixSize(move(ret)); } BinaryTransform::Type BinaryTransform::typeOf(FunctionImport const& _import) @@ -496,9 +507,9 @@ vector BinaryTransform::encodeTypes(vector const& _typeNames) return result; } -bytes BinaryTransform::typeSection( - vector const& _imports, - vector const& _functions +map> BinaryTransform::typeToFunctionMap( + vector const& _imports, + vector const& _functions ) { map> types; @@ -507,12 +518,50 @@ bytes BinaryTransform::typeSection( for (auto const& fun: _functions) types[typeOf(fun)].emplace_back(fun.name); - bytes result; - size_t index = 0; - for (auto const& [type, funNames]: types) + return types; +} + +map BinaryTransform::enumerateGlobals(Module const& _module) +{ + map globals; + for (size_t i = 0; i < _module.globals.size(); ++i) + globals[_module.globals[i].variableName] = i; + + return globals; +} + +map BinaryTransform::enumerateFunctions(Module const& _module) +{ + map functions; + size_t funID = 0; + for (FunctionImport const& fun: _module.imports) + functions[fun.internalName] = funID++; + for (FunctionDefinition const& fun: _module.functions) + functions[fun.name] = funID++; + + return functions; +} + +map BinaryTransform::enumerateFunctionTypes(map> const& _typeToFunctionMap) +{ + map functionTypes; + size_t typeID = 0; + for (vector const& funNames: _typeToFunctionMap | boost::adaptors::map_values) { for (string const& name: funNames) - m_functionTypes[name] = index; + functionTypes[name] = typeID; + ++typeID; + } + + return functionTypes; +} + +bytes BinaryTransform::typeSection(map> const& _typeToFunctionMap) +{ + bytes result; + size_t index = 0; + for (Type const& type: _typeToFunctionMap | boost::adaptors::map_keys) + { result += toBytes(ValueType::Function); result += lebEncode(type.first.size()) + type.first; result += lebEncode(type.second.size()) + type.second; @@ -520,11 +569,12 @@ bytes BinaryTransform::typeSection( index++; } - return makeSection(Section::TYPE, lebEncode(index) + std::move(result)); + return makeSection(Section::TYPE, lebEncode(index) + move(result)); } bytes BinaryTransform::importSection( - vector const& _imports + vector const& _imports, + map const& _functionTypes ) { bytes result = lebEncode(_imports.size()); @@ -535,17 +585,20 @@ bytes BinaryTransform::importSection( encodeName(import.module) + encodeName(import.externalName) + toBytes(importKind) + - lebEncode(m_functionTypes[import.internalName]); + lebEncode(_functionTypes.at(import.internalName)); } - return makeSection(Section::IMPORT, std::move(result)); + return makeSection(Section::IMPORT, move(result)); } -bytes BinaryTransform::functionSection(vector const& _functions) +bytes BinaryTransform::functionSection( + vector const& _functions, + map const& _functionTypes +) { bytes result = lebEncode(_functions.size()); for (auto const& fun: _functions) - result += lebEncode(m_functionTypes.at(fun.name)); - return makeSection(Section::FUNCTION, std::move(result)); + result += lebEncode(_functionTypes.at(fun.name)); + return makeSection(Section::FUNCTION, move(result)); } bytes BinaryTransform::memorySection() @@ -553,13 +606,13 @@ bytes BinaryTransform::memorySection() bytes result = lebEncode(1); result.push_back(static_cast(LimitsKind::Min)); result.push_back(1); // initial length - return makeSection(Section::MEMORY, std::move(result)); + return makeSection(Section::MEMORY, move(result)); } -bytes BinaryTransform::globalSection() +bytes BinaryTransform::globalSection(vector const& _globals) { - bytes result = lebEncode(m_globals.size()); - for (size_t i = 0; i < m_globals.size(); ++i) + bytes result = lebEncode(_globals.size()); + for (size_t i = 0; i < _globals.size(); ++i) result += toBytes(ValueType::I64) + lebEncode(static_cast(Mutability::Var)) + @@ -567,21 +620,21 @@ bytes BinaryTransform::globalSection() lebEncodeSigned(0) + toBytes(Opcode::End); - return makeSection(Section::GLOBAL, std::move(result)); + return makeSection(Section::GLOBAL, move(result)); } -bytes BinaryTransform::exportSection() +bytes BinaryTransform::exportSection(map const& _functionIDs) { bytes result = lebEncode(2); result += encodeName("memory") + toBytes(Export::Memory) + lebEncode(0); - result += encodeName("main") + toBytes(Export::Function) + lebEncode(m_functions.at("main")); - return makeSection(Section::EXPORT, std::move(result)); + result += encodeName("main") + toBytes(Export::Function) + lebEncode(_functionIDs.at("main")); + return makeSection(Section::EXPORT, move(result)); } bytes BinaryTransform::customSection(string const& _name, bytes _data) { - bytes result = encodeName(_name) + std::move(_data); - return makeSection(Section::CUSTOM, std::move(result)); + bytes result = encodeName(_name) + move(_data); + return makeSection(Section::CUSTOM, move(result)); } bytes BinaryTransform::codeSection(vector const& _functions) @@ -589,7 +642,7 @@ bytes BinaryTransform::codeSection(vector const& _func bytes result = lebEncode(_functions.size()); for (FunctionDefinition const& fun: _functions) result += (*this)(fun); - return makeSection(Section::CODE, std::move(result)); + return makeSection(Section::CODE, move(result)); } bytes BinaryTransform::visit(vector const& _expressions) @@ -620,7 +673,7 @@ bytes BinaryTransform::encodeLabelIdx(string const& _label) const yulAssert(false, "Label not found."); } -bytes BinaryTransform::encodeName(std::string const& _name) +bytes BinaryTransform::encodeName(string const& _name) { // UTF-8 is allowed here by the Wasm spec, but since all names here should stem from // Solidity or Yul identifiers or similar, non-ascii characters ending up here diff --git a/libyul/backends/wasm/BinaryTransform.h b/libyul/backends/wasm/BinaryTransform.h index c1c924dcc..f4444b2fd 100644 --- a/libyul/backends/wasm/BinaryTransform.h +++ b/libyul/backends/wasm/BinaryTransform.h @@ -55,23 +55,49 @@ public: bytes operator()(wasm::FunctionDefinition const& _function); private: + BinaryTransform( + std::map _globalIDs, + std::map _functionIDs, + std::map _functionTypes, + std::map> _subModulePosAndSize + ): + m_globalIDs(std::move(_globalIDs)), + m_functionIDs(std::move(_functionIDs)), + m_functionTypes(std::move(_functionTypes)), + m_subModulePosAndSize(std::move(_subModulePosAndSize)) + {} + using Type = std::pair, std::vector>; static Type typeOf(wasm::FunctionImport const& _import); static Type typeOf(wasm::FunctionDefinition const& _funDef); static uint8_t encodeType(std::string const& _typeName); static std::vector encodeTypes(std::vector const& _typeNames); - bytes typeSection( + + static std::map> typeToFunctionMap( std::vector const& _imports, std::vector const& _functions ); - bytes importSection(std::vector const& _imports); - bytes functionSection(std::vector const& _functions); - bytes memorySection(); - bytes globalSection(); - bytes exportSection(); - bytes customSection(std::string const& _name, bytes _data); + static std::map enumerateGlobals(Module const& _module); + static std::map enumerateFunctions(Module const& _module); + static std::map enumerateFunctionTypes( + std::map> const& _typeToFunctionMap + ); + + static bytes typeSection(std::map> const& _typeToFunctionMap); + static bytes importSection( + std::vector const& _imports, + std::map const& _functionTypes + ); + static bytes functionSection( + std::vector const& _functions, + std::map const& _functionTypes + ); + static bytes memorySection(); + static bytes globalSection(std::vector const& _globals); + static bytes exportSection(std::map const& _functionIDs); + static bytes customSection(std::string const& _name, bytes _data); bytes codeSection(std::vector const& _functions); bytes visit(std::vector const& _expressions); @@ -81,12 +107,13 @@ private: static bytes encodeName(std::string const& _name); + std::map const m_globalIDs; + std::map const m_functionIDs; + std::map const m_functionTypes; + std::map> const m_subModulePosAndSize; + std::map m_locals; - std::map m_globals; - std::map m_functions; - std::map m_functionTypes; std::vector m_labels; - std::map> m_subModulePosAndSize; }; }