Merge pull request #9151 from ethereum/wasm-binary-transform-refactor-index-registration

Refactor the index assignment logic in wasm::BinaryTransform
This commit is contained in:
chriseth 2020-06-10 10:39:33 +02:00 committed by GitHub
commit d2e9b4e946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 143 additions and 63 deletions

View File

@ -24,6 +24,7 @@
#include <libsolutil/CommonData.h> #include <libsolutil/CommonData.h>
#include <boost/range/adaptor/reversed.hpp> #include <boost/range/adaptor/reversed.hpp>
#include <boost/range/adaptor/map.hpp>
using namespace std; using namespace std;
using namespace solidity; using namespace solidity;
@ -130,7 +131,7 @@ bytes toBytes(Opcode _o)
return toBytes(uint8_t(_o)); return toBytes(uint8_t(_o));
} }
static std::map<string, uint8_t> const builtins = { static map<string, uint8_t> const builtins = {
{"i32.load", 0x28}, {"i32.load", 0x28},
{"i64.load", 0x29}, {"i64.load", 0x29},
{"i32.load8_s", 0x2c}, {"i32.load8_s", 0x2c},
@ -240,46 +241,56 @@ bytes lebEncodeSigned(int64_t _n)
bytes prefixSize(bytes _data) bytes prefixSize(bytes _data)
{ {
size_t size = _data.size(); size_t size = _data.size();
return lebEncode(size) + std::move(_data); return lebEncode(size) + move(_data);
} }
bytes makeSection(Section _section, bytes _data) bytes makeSection(Section _section, bytes _data)
{ {
return toBytes(_section) + prefixSize(std::move(_data)); return toBytes(_section) + prefixSize(move(_data));
} }
} }
bytes BinaryTransform::run(Module const& _module) bytes BinaryTransform::run(Module const& _module)
{ {
BinaryTransform bt; map<Type, vector<string>> const types = typeToFunctionMap(_module.imports, _module.functions);
for (size_t i = 0; i < _module.globals.size(); ++i) map<string, size_t> const globalIDs = enumerateGlobals(_module);
bt.m_globals[_module.globals[i].variableName] = i; map<string, size_t> const functionIDs = enumerateFunctions(_module);
map<string, size_t> const functionTypes = enumerateFunctionTypes(types);
size_t funID = 0; yulAssert(globalIDs.size() == _module.globals.size(), "");
for (FunctionImport const& fun: _module.imports) yulAssert(functionIDs.size() == _module.imports.size() + _module.functions.size(), "");
bt.m_functions[fun.internalName] = funID++; yulAssert(functionTypes.size() == functionIDs.size(), "");
for (FunctionDefinition const& fun: _module.functions) yulAssert(functionTypes.size() >= types.size(), "");
bt.m_functions[fun.name] = funID++;
bytes ret{0, 'a', 's', 'm'}; bytes ret{0, 'a', 's', 'm'};
// version // version
ret += bytes{1, 0, 0, 0}; ret += bytes{1, 0, 0, 0};
ret += bt.typeSection(_module.imports, _module.functions); ret += typeSection(types);
ret += bt.importSection(_module.imports); ret += importSection(_module.imports, functionTypes);
ret += bt.functionSection(_module.functions); ret += functionSection(_module.functions, functionTypes);
ret += bt.memorySection(); ret += memorySection();
ret += bt.globalSection(); ret += globalSection(_module.globals);
ret += bt.exportSection(); ret += exportSection(functionIDs);
map<string, pair<size_t, size_t>> subModulePosAndSize;
for (auto const& sub: _module.subModules) for (auto const& sub: _module.subModules)
{ {
// TODO should we prefix and / or shorten the name? // TODO should we prefix and / or shorten the name?
bytes data = BinaryTransform::run(sub.second); bytes data = BinaryTransform::run(sub.second);
size_t length = data.size(); size_t length = data.size();
ret += bt.customSection(sub.first, std::move(data)); ret += customSection(sub.first, move(data));
bt.m_subModulePosAndSize[sub.first] = {ret.size() - length, length}; subModulePosAndSize[sub.first] = {ret.size() - length, length};
} }
BinaryTransform bt(
move(globalIDs),
move(functionIDs),
move(functionTypes),
move(subModulePosAndSize)
);
ret += bt.codeSection(_module.functions); ret += bt.codeSection(_module.functions);
return ret; return ret;
} }
@ -302,7 +313,7 @@ bytes BinaryTransform::operator()(LocalVariable const& _variable)
bytes BinaryTransform::operator()(GlobalVariable const& _variable) bytes BinaryTransform::operator()(GlobalVariable const& _variable)
{ {
return toBytes(Opcode::GlobalGet) + lebEncode(m_globals.at(_variable.name)); return toBytes(Opcode::GlobalGet) + lebEncode(m_globalIDs.at(_variable.name));
} }
bytes BinaryTransform::operator()(BuiltinCall const& _call) bytes BinaryTransform::operator()(BuiltinCall const& _call)
@ -311,12 +322,12 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call)
// they are references to object names that should not end up in the code. // they are references to object names that should not end up in the code.
if (_call.functionName == "dataoffset") if (_call.functionName == "dataoffset")
{ {
string name = std::get<StringLiteral>(_call.arguments.at(0)).value; string name = get<StringLiteral>(_call.arguments.at(0)).value;
return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).first); return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).first);
} }
else if (_call.functionName == "datasize") else if (_call.functionName == "datasize")
{ {
string name = std::get<StringLiteral>(_call.arguments.at(0)).value; string name = get<StringLiteral>(_call.arguments.at(0)).value;
return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).second); return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).second);
} }
@ -331,7 +342,7 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call)
else else
{ {
yulAssert(builtins.count(_call.functionName), "Builtin " + _call.functionName + " not found"); yulAssert(builtins.count(_call.functionName), "Builtin " + _call.functionName + " not found");
bytes ret = std::move(args) + toBytes(builtins.at(_call.functionName)); bytes ret = move(args) + toBytes(builtins.at(_call.functionName));
if ( if (
_call.functionName.find(".load") != string::npos || _call.functionName.find(".load") != string::npos ||
_call.functionName.find(".store") != string::npos _call.functionName.find(".store") != string::npos
@ -348,7 +359,7 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call)
bytes BinaryTransform::operator()(FunctionCall const& _call) bytes BinaryTransform::operator()(FunctionCall const& _call)
{ {
return visit(_call.arguments) + toBytes(Opcode::Call) + lebEncode(m_functions.at(_call.functionName)); return visit(_call.arguments) + toBytes(Opcode::Call) + lebEncode(m_functionIDs.at(_call.functionName));
} }
bytes BinaryTransform::operator()(LocalAssignment const& _assignment) bytes BinaryTransform::operator()(LocalAssignment const& _assignment)
@ -364,7 +375,7 @@ bytes BinaryTransform::operator()(GlobalAssignment const& _assignment)
return return
std::visit(*this, *_assignment.value) + std::visit(*this, *_assignment.value) +
toBytes(Opcode::GlobalSet) + toBytes(Opcode::GlobalSet) +
lebEncode(m_globals.at(_assignment.variableName)); lebEncode(m_globalIDs.at(_assignment.variableName));
} }
bytes BinaryTransform::operator()(If const& _if) bytes BinaryTransform::operator()(If const& _if)
@ -457,7 +468,7 @@ bytes BinaryTransform::operator()(FunctionDefinition const& _function)
yulAssert(m_labels.empty(), "Stray labels."); yulAssert(m_labels.empty(), "Stray labels.");
return prefixSize(std::move(ret)); return prefixSize(move(ret));
} }
BinaryTransform::Type BinaryTransform::typeOf(FunctionImport const& _import) BinaryTransform::Type BinaryTransform::typeOf(FunctionImport const& _import)
@ -496,9 +507,9 @@ vector<uint8_t> BinaryTransform::encodeTypes(vector<string> const& _typeNames)
return result; return result;
} }
bytes BinaryTransform::typeSection( map<BinaryTransform::Type, vector<string>> BinaryTransform::typeToFunctionMap(
vector<FunctionImport> const& _imports, vector<wasm::FunctionImport> const& _imports,
vector<FunctionDefinition> const& _functions vector<wasm::FunctionDefinition> const& _functions
) )
{ {
map<Type, vector<string>> types; map<Type, vector<string>> types;
@ -507,12 +518,50 @@ bytes BinaryTransform::typeSection(
for (auto const& fun: _functions) for (auto const& fun: _functions)
types[typeOf(fun)].emplace_back(fun.name); types[typeOf(fun)].emplace_back(fun.name);
bytes result; return types;
size_t index = 0; }
for (auto const& [type, funNames]: types)
map<string, size_t> BinaryTransform::enumerateGlobals(Module const& _module)
{
map<string, size_t> globals;
for (size_t i = 0; i < _module.globals.size(); ++i)
globals[_module.globals[i].variableName] = i;
return globals;
}
map<string, size_t> BinaryTransform::enumerateFunctions(Module const& _module)
{
map<string, size_t> functions;
size_t funID = 0;
for (FunctionImport const& fun: _module.imports)
functions[fun.internalName] = funID++;
for (FunctionDefinition const& fun: _module.functions)
functions[fun.name] = funID++;
return functions;
}
map<string, size_t> BinaryTransform::enumerateFunctionTypes(map<Type, vector<string>> const& _typeToFunctionMap)
{
map<string, size_t> functionTypes;
size_t typeID = 0;
for (vector<string> const& funNames: _typeToFunctionMap | boost::adaptors::map_values)
{ {
for (string const& name: funNames) for (string const& name: funNames)
m_functionTypes[name] = index; functionTypes[name] = typeID;
++typeID;
}
return functionTypes;
}
bytes BinaryTransform::typeSection(map<BinaryTransform::Type, vector<string>> const& _typeToFunctionMap)
{
bytes result;
size_t index = 0;
for (Type const& type: _typeToFunctionMap | boost::adaptors::map_keys)
{
result += toBytes(ValueType::Function); result += toBytes(ValueType::Function);
result += lebEncode(type.first.size()) + type.first; result += lebEncode(type.first.size()) + type.first;
result += lebEncode(type.second.size()) + type.second; result += lebEncode(type.second.size()) + type.second;
@ -520,11 +569,12 @@ bytes BinaryTransform::typeSection(
index++; index++;
} }
return makeSection(Section::TYPE, lebEncode(index) + std::move(result)); return makeSection(Section::TYPE, lebEncode(index) + move(result));
} }
bytes BinaryTransform::importSection( bytes BinaryTransform::importSection(
vector<FunctionImport> const& _imports vector<FunctionImport> const& _imports,
map<string, size_t> const& _functionTypes
) )
{ {
bytes result = lebEncode(_imports.size()); bytes result = lebEncode(_imports.size());
@ -535,17 +585,20 @@ bytes BinaryTransform::importSection(
encodeName(import.module) + encodeName(import.module) +
encodeName(import.externalName) + encodeName(import.externalName) +
toBytes(importKind) + toBytes(importKind) +
lebEncode(m_functionTypes[import.internalName]); lebEncode(_functionTypes.at(import.internalName));
} }
return makeSection(Section::IMPORT, std::move(result)); return makeSection(Section::IMPORT, move(result));
} }
bytes BinaryTransform::functionSection(vector<FunctionDefinition> const& _functions) bytes BinaryTransform::functionSection(
vector<FunctionDefinition> const& _functions,
map<string, size_t> const& _functionTypes
)
{ {
bytes result = lebEncode(_functions.size()); bytes result = lebEncode(_functions.size());
for (auto const& fun: _functions) for (auto const& fun: _functions)
result += lebEncode(m_functionTypes.at(fun.name)); result += lebEncode(_functionTypes.at(fun.name));
return makeSection(Section::FUNCTION, std::move(result)); return makeSection(Section::FUNCTION, move(result));
} }
bytes BinaryTransform::memorySection() bytes BinaryTransform::memorySection()
@ -553,13 +606,13 @@ bytes BinaryTransform::memorySection()
bytes result = lebEncode(1); bytes result = lebEncode(1);
result.push_back(static_cast<uint8_t>(LimitsKind::Min)); result.push_back(static_cast<uint8_t>(LimitsKind::Min));
result.push_back(1); // initial length result.push_back(1); // initial length
return makeSection(Section::MEMORY, std::move(result)); return makeSection(Section::MEMORY, move(result));
} }
bytes BinaryTransform::globalSection() bytes BinaryTransform::globalSection(vector<wasm::GlobalVariableDeclaration> const& _globals)
{ {
bytes result = lebEncode(m_globals.size()); bytes result = lebEncode(_globals.size());
for (size_t i = 0; i < m_globals.size(); ++i) for (size_t i = 0; i < _globals.size(); ++i)
result += result +=
toBytes(ValueType::I64) + toBytes(ValueType::I64) +
lebEncode(static_cast<uint8_t>(Mutability::Var)) + lebEncode(static_cast<uint8_t>(Mutability::Var)) +
@ -567,21 +620,21 @@ bytes BinaryTransform::globalSection()
lebEncodeSigned(0) + lebEncodeSigned(0) +
toBytes(Opcode::End); toBytes(Opcode::End);
return makeSection(Section::GLOBAL, std::move(result)); return makeSection(Section::GLOBAL, move(result));
} }
bytes BinaryTransform::exportSection() bytes BinaryTransform::exportSection(map<string, size_t> const& _functionIDs)
{ {
bytes result = lebEncode(2); bytes result = lebEncode(2);
result += encodeName("memory") + toBytes(Export::Memory) + lebEncode(0); result += encodeName("memory") + toBytes(Export::Memory) + lebEncode(0);
result += encodeName("main") + toBytes(Export::Function) + lebEncode(m_functions.at("main")); result += encodeName("main") + toBytes(Export::Function) + lebEncode(_functionIDs.at("main"));
return makeSection(Section::EXPORT, std::move(result)); return makeSection(Section::EXPORT, move(result));
} }
bytes BinaryTransform::customSection(string const& _name, bytes _data) bytes BinaryTransform::customSection(string const& _name, bytes _data)
{ {
bytes result = encodeName(_name) + std::move(_data); bytes result = encodeName(_name) + move(_data);
return makeSection(Section::CUSTOM, std::move(result)); return makeSection(Section::CUSTOM, move(result));
} }
bytes BinaryTransform::codeSection(vector<wasm::FunctionDefinition> const& _functions) bytes BinaryTransform::codeSection(vector<wasm::FunctionDefinition> const& _functions)
@ -589,7 +642,7 @@ bytes BinaryTransform::codeSection(vector<wasm::FunctionDefinition> const& _func
bytes result = lebEncode(_functions.size()); bytes result = lebEncode(_functions.size());
for (FunctionDefinition const& fun: _functions) for (FunctionDefinition const& fun: _functions)
result += (*this)(fun); result += (*this)(fun);
return makeSection(Section::CODE, std::move(result)); return makeSection(Section::CODE, move(result));
} }
bytes BinaryTransform::visit(vector<Expression> const& _expressions) bytes BinaryTransform::visit(vector<Expression> const& _expressions)
@ -620,7 +673,7 @@ bytes BinaryTransform::encodeLabelIdx(string const& _label) const
yulAssert(false, "Label not found."); yulAssert(false, "Label not found.");
} }
bytes BinaryTransform::encodeName(std::string const& _name) bytes BinaryTransform::encodeName(string const& _name)
{ {
// UTF-8 is allowed here by the Wasm spec, but since all names here should stem from // UTF-8 is allowed here by the Wasm spec, but since all names here should stem from
// Solidity or Yul identifiers or similar, non-ascii characters ending up here // Solidity or Yul identifiers or similar, non-ascii characters ending up here

View File

@ -55,23 +55,49 @@ public:
bytes operator()(wasm::FunctionDefinition const& _function); bytes operator()(wasm::FunctionDefinition const& _function);
private: private:
BinaryTransform(
std::map<std::string, size_t> _globalIDs,
std::map<std::string, size_t> _functionIDs,
std::map<std::string, size_t> _functionTypes,
std::map<std::string, std::pair<size_t, size_t>> _subModulePosAndSize
):
m_globalIDs(std::move(_globalIDs)),
m_functionIDs(std::move(_functionIDs)),
m_functionTypes(std::move(_functionTypes)),
m_subModulePosAndSize(std::move(_subModulePosAndSize))
{}
using Type = std::pair<std::vector<std::uint8_t>, std::vector<std::uint8_t>>; using Type = std::pair<std::vector<std::uint8_t>, std::vector<std::uint8_t>>;
static Type typeOf(wasm::FunctionImport const& _import); static Type typeOf(wasm::FunctionImport const& _import);
static Type typeOf(wasm::FunctionDefinition const& _funDef); static Type typeOf(wasm::FunctionDefinition const& _funDef);
static uint8_t encodeType(std::string const& _typeName); static uint8_t encodeType(std::string const& _typeName);
static std::vector<uint8_t> encodeTypes(std::vector<std::string> const& _typeNames); static std::vector<uint8_t> encodeTypes(std::vector<std::string> const& _typeNames);
bytes typeSection(
static std::map<Type, std::vector<std::string>> typeToFunctionMap(
std::vector<wasm::FunctionImport> const& _imports, std::vector<wasm::FunctionImport> const& _imports,
std::vector<wasm::FunctionDefinition> const& _functions std::vector<wasm::FunctionDefinition> const& _functions
); );
bytes importSection(std::vector<wasm::FunctionImport> const& _imports); static std::map<std::string, size_t> enumerateGlobals(Module const& _module);
bytes functionSection(std::vector<wasm::FunctionDefinition> const& _functions); static std::map<std::string, size_t> enumerateFunctions(Module const& _module);
bytes memorySection(); static std::map<std::string, size_t> enumerateFunctionTypes(
bytes globalSection(); std::map<Type, std::vector<std::string>> const& _typeToFunctionMap
bytes exportSection(); );
bytes customSection(std::string const& _name, bytes _data);
static bytes typeSection(std::map<Type, std::vector<std::string>> const& _typeToFunctionMap);
static bytes importSection(
std::vector<wasm::FunctionImport> const& _imports,
std::map<std::string, size_t> const& _functionTypes
);
static bytes functionSection(
std::vector<wasm::FunctionDefinition> const& _functions,
std::map<std::string, size_t> const& _functionTypes
);
static bytes memorySection();
static bytes globalSection(std::vector<wasm::GlobalVariableDeclaration> const& _globals);
static bytes exportSection(std::map<std::string, size_t> const& _functionIDs);
static bytes customSection(std::string const& _name, bytes _data);
bytes codeSection(std::vector<wasm::FunctionDefinition> const& _functions); bytes codeSection(std::vector<wasm::FunctionDefinition> const& _functions);
bytes visit(std::vector<wasm::Expression> const& _expressions); bytes visit(std::vector<wasm::Expression> const& _expressions);
@ -81,12 +107,13 @@ private:
static bytes encodeName(std::string const& _name); static bytes encodeName(std::string const& _name);
std::map<std::string, size_t> const m_globalIDs;
std::map<std::string, size_t> const m_functionIDs;
std::map<std::string, size_t> const m_functionTypes;
std::map<std::string, std::pair<size_t, size_t>> const m_subModulePosAndSize;
std::map<std::string, size_t> m_locals; std::map<std::string, size_t> m_locals;
std::map<std::string, size_t> m_globals;
std::map<std::string, size_t> m_functions;
std::map<std::string, size_t> m_functionTypes;
std::vector<std::string> m_labels; std::vector<std::string> m_labels;
std::map<std::string, std::pair<size_t, size_t>> m_subModulePosAndSize;
}; };
} }