Merge pull request #9151 from ethereum/wasm-binary-transform-refactor-index-registration

Refactor the index assignment logic in wasm::BinaryTransform
This commit is contained in:
chriseth 2020-06-10 10:39:33 +02:00 committed by GitHub
commit d2e9b4e946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 143 additions and 63 deletions

View File

@ -24,6 +24,7 @@
#include <libsolutil/CommonData.h>
#include <boost/range/adaptor/reversed.hpp>
#include <boost/range/adaptor/map.hpp>
using namespace std;
using namespace solidity;
@ -130,7 +131,7 @@ bytes toBytes(Opcode _o)
return toBytes(uint8_t(_o));
}
static std::map<string, uint8_t> const builtins = {
static map<string, uint8_t> const builtins = {
{"i32.load", 0x28},
{"i64.load", 0x29},
{"i32.load8_s", 0x2c},
@ -240,46 +241,56 @@ bytes lebEncodeSigned(int64_t _n)
bytes prefixSize(bytes _data)
{
size_t size = _data.size();
return lebEncode(size) + std::move(_data);
return lebEncode(size) + move(_data);
}
bytes makeSection(Section _section, bytes _data)
{
return toBytes(_section) + prefixSize(std::move(_data));
return toBytes(_section) + prefixSize(move(_data));
}
}
bytes BinaryTransform::run(Module const& _module)
{
BinaryTransform bt;
map<Type, vector<string>> const types = typeToFunctionMap(_module.imports, _module.functions);
for (size_t i = 0; i < _module.globals.size(); ++i)
bt.m_globals[_module.globals[i].variableName] = i;
map<string, size_t> const globalIDs = enumerateGlobals(_module);
map<string, size_t> const functionIDs = enumerateFunctions(_module);
map<string, size_t> const functionTypes = enumerateFunctionTypes(types);
size_t funID = 0;
for (FunctionImport const& fun: _module.imports)
bt.m_functions[fun.internalName] = funID++;
for (FunctionDefinition const& fun: _module.functions)
bt.m_functions[fun.name] = funID++;
yulAssert(globalIDs.size() == _module.globals.size(), "");
yulAssert(functionIDs.size() == _module.imports.size() + _module.functions.size(), "");
yulAssert(functionTypes.size() == functionIDs.size(), "");
yulAssert(functionTypes.size() >= types.size(), "");
bytes ret{0, 'a', 's', 'm'};
// version
ret += bytes{1, 0, 0, 0};
ret += bt.typeSection(_module.imports, _module.functions);
ret += bt.importSection(_module.imports);
ret += bt.functionSection(_module.functions);
ret += bt.memorySection();
ret += bt.globalSection();
ret += bt.exportSection();
ret += typeSection(types);
ret += importSection(_module.imports, functionTypes);
ret += functionSection(_module.functions, functionTypes);
ret += memorySection();
ret += globalSection(_module.globals);
ret += exportSection(functionIDs);
map<string, pair<size_t, size_t>> subModulePosAndSize;
for (auto const& sub: _module.subModules)
{
// TODO should we prefix and / or shorten the name?
bytes data = BinaryTransform::run(sub.second);
size_t length = data.size();
ret += bt.customSection(sub.first, std::move(data));
bt.m_subModulePosAndSize[sub.first] = {ret.size() - length, length};
ret += customSection(sub.first, move(data));
subModulePosAndSize[sub.first] = {ret.size() - length, length};
}
BinaryTransform bt(
move(globalIDs),
move(functionIDs),
move(functionTypes),
move(subModulePosAndSize)
);
ret += bt.codeSection(_module.functions);
return ret;
}
@ -302,7 +313,7 @@ bytes BinaryTransform::operator()(LocalVariable const& _variable)
bytes BinaryTransform::operator()(GlobalVariable const& _variable)
{
return toBytes(Opcode::GlobalGet) + lebEncode(m_globals.at(_variable.name));
return toBytes(Opcode::GlobalGet) + lebEncode(m_globalIDs.at(_variable.name));
}
bytes BinaryTransform::operator()(BuiltinCall const& _call)
@ -311,12 +322,12 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call)
// they are references to object names that should not end up in the code.
if (_call.functionName == "dataoffset")
{
string name = std::get<StringLiteral>(_call.arguments.at(0)).value;
string name = get<StringLiteral>(_call.arguments.at(0)).value;
return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).first);
}
else if (_call.functionName == "datasize")
{
string name = std::get<StringLiteral>(_call.arguments.at(0)).value;
string name = get<StringLiteral>(_call.arguments.at(0)).value;
return toBytes(Opcode::I64Const) + lebEncodeSigned(m_subModulePosAndSize.at(name).second);
}
@ -331,7 +342,7 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call)
else
{
yulAssert(builtins.count(_call.functionName), "Builtin " + _call.functionName + " not found");
bytes ret = std::move(args) + toBytes(builtins.at(_call.functionName));
bytes ret = move(args) + toBytes(builtins.at(_call.functionName));
if (
_call.functionName.find(".load") != string::npos ||
_call.functionName.find(".store") != string::npos
@ -348,7 +359,7 @@ bytes BinaryTransform::operator()(BuiltinCall const& _call)
bytes BinaryTransform::operator()(FunctionCall const& _call)
{
return visit(_call.arguments) + toBytes(Opcode::Call) + lebEncode(m_functions.at(_call.functionName));
return visit(_call.arguments) + toBytes(Opcode::Call) + lebEncode(m_functionIDs.at(_call.functionName));
}
bytes BinaryTransform::operator()(LocalAssignment const& _assignment)
@ -364,7 +375,7 @@ bytes BinaryTransform::operator()(GlobalAssignment const& _assignment)
return
std::visit(*this, *_assignment.value) +
toBytes(Opcode::GlobalSet) +
lebEncode(m_globals.at(_assignment.variableName));
lebEncode(m_globalIDs.at(_assignment.variableName));
}
bytes BinaryTransform::operator()(If const& _if)
@ -457,7 +468,7 @@ bytes BinaryTransform::operator()(FunctionDefinition const& _function)
yulAssert(m_labels.empty(), "Stray labels.");
return prefixSize(std::move(ret));
return prefixSize(move(ret));
}
BinaryTransform::Type BinaryTransform::typeOf(FunctionImport const& _import)
@ -496,9 +507,9 @@ vector<uint8_t> BinaryTransform::encodeTypes(vector<string> const& _typeNames)
return result;
}
bytes BinaryTransform::typeSection(
vector<FunctionImport> const& _imports,
vector<FunctionDefinition> const& _functions
map<BinaryTransform::Type, vector<string>> BinaryTransform::typeToFunctionMap(
vector<wasm::FunctionImport> const& _imports,
vector<wasm::FunctionDefinition> const& _functions
)
{
map<Type, vector<string>> types;
@ -507,12 +518,50 @@ bytes BinaryTransform::typeSection(
for (auto const& fun: _functions)
types[typeOf(fun)].emplace_back(fun.name);
bytes result;
size_t index = 0;
for (auto const& [type, funNames]: types)
return types;
}
map<string, size_t> BinaryTransform::enumerateGlobals(Module const& _module)
{
map<string, size_t> globals;
for (size_t i = 0; i < _module.globals.size(); ++i)
globals[_module.globals[i].variableName] = i;
return globals;
}
map<string, size_t> BinaryTransform::enumerateFunctions(Module const& _module)
{
map<string, size_t> functions;
size_t funID = 0;
for (FunctionImport const& fun: _module.imports)
functions[fun.internalName] = funID++;
for (FunctionDefinition const& fun: _module.functions)
functions[fun.name] = funID++;
return functions;
}
map<string, size_t> BinaryTransform::enumerateFunctionTypes(map<Type, vector<string>> const& _typeToFunctionMap)
{
map<string, size_t> functionTypes;
size_t typeID = 0;
for (vector<string> const& funNames: _typeToFunctionMap | boost::adaptors::map_values)
{
for (string const& name: funNames)
m_functionTypes[name] = index;
functionTypes[name] = typeID;
++typeID;
}
return functionTypes;
}
bytes BinaryTransform::typeSection(map<BinaryTransform::Type, vector<string>> const& _typeToFunctionMap)
{
bytes result;
size_t index = 0;
for (Type const& type: _typeToFunctionMap | boost::adaptors::map_keys)
{
result += toBytes(ValueType::Function);
result += lebEncode(type.first.size()) + type.first;
result += lebEncode(type.second.size()) + type.second;
@ -520,11 +569,12 @@ bytes BinaryTransform::typeSection(
index++;
}
return makeSection(Section::TYPE, lebEncode(index) + std::move(result));
return makeSection(Section::TYPE, lebEncode(index) + move(result));
}
bytes BinaryTransform::importSection(
vector<FunctionImport> const& _imports
vector<FunctionImport> const& _imports,
map<string, size_t> const& _functionTypes
)
{
bytes result = lebEncode(_imports.size());
@ -535,17 +585,20 @@ bytes BinaryTransform::importSection(
encodeName(import.module) +
encodeName(import.externalName) +
toBytes(importKind) +
lebEncode(m_functionTypes[import.internalName]);
lebEncode(_functionTypes.at(import.internalName));
}
return makeSection(Section::IMPORT, std::move(result));
return makeSection(Section::IMPORT, move(result));
}
bytes BinaryTransform::functionSection(vector<FunctionDefinition> const& _functions)
bytes BinaryTransform::functionSection(
vector<FunctionDefinition> const& _functions,
map<string, size_t> const& _functionTypes
)
{
bytes result = lebEncode(_functions.size());
for (auto const& fun: _functions)
result += lebEncode(m_functionTypes.at(fun.name));
return makeSection(Section::FUNCTION, std::move(result));
result += lebEncode(_functionTypes.at(fun.name));
return makeSection(Section::FUNCTION, move(result));
}
bytes BinaryTransform::memorySection()
@ -553,13 +606,13 @@ bytes BinaryTransform::memorySection()
bytes result = lebEncode(1);
result.push_back(static_cast<uint8_t>(LimitsKind::Min));
result.push_back(1); // initial length
return makeSection(Section::MEMORY, std::move(result));
return makeSection(Section::MEMORY, move(result));
}
bytes BinaryTransform::globalSection()
bytes BinaryTransform::globalSection(vector<wasm::GlobalVariableDeclaration> const& _globals)
{
bytes result = lebEncode(m_globals.size());
for (size_t i = 0; i < m_globals.size(); ++i)
bytes result = lebEncode(_globals.size());
for (size_t i = 0; i < _globals.size(); ++i)
result +=
toBytes(ValueType::I64) +
lebEncode(static_cast<uint8_t>(Mutability::Var)) +
@ -567,21 +620,21 @@ bytes BinaryTransform::globalSection()
lebEncodeSigned(0) +
toBytes(Opcode::End);
return makeSection(Section::GLOBAL, std::move(result));
return makeSection(Section::GLOBAL, move(result));
}
bytes BinaryTransform::exportSection()
bytes BinaryTransform::exportSection(map<string, size_t> const& _functionIDs)
{
bytes result = lebEncode(2);
result += encodeName("memory") + toBytes(Export::Memory) + lebEncode(0);
result += encodeName("main") + toBytes(Export::Function) + lebEncode(m_functions.at("main"));
return makeSection(Section::EXPORT, std::move(result));
result += encodeName("main") + toBytes(Export::Function) + lebEncode(_functionIDs.at("main"));
return makeSection(Section::EXPORT, move(result));
}
bytes BinaryTransform::customSection(string const& _name, bytes _data)
{
bytes result = encodeName(_name) + std::move(_data);
return makeSection(Section::CUSTOM, std::move(result));
bytes result = encodeName(_name) + move(_data);
return makeSection(Section::CUSTOM, move(result));
}
bytes BinaryTransform::codeSection(vector<wasm::FunctionDefinition> const& _functions)
@ -589,7 +642,7 @@ bytes BinaryTransform::codeSection(vector<wasm::FunctionDefinition> const& _func
bytes result = lebEncode(_functions.size());
for (FunctionDefinition const& fun: _functions)
result += (*this)(fun);
return makeSection(Section::CODE, std::move(result));
return makeSection(Section::CODE, move(result));
}
bytes BinaryTransform::visit(vector<Expression> const& _expressions)
@ -620,7 +673,7 @@ bytes BinaryTransform::encodeLabelIdx(string const& _label) const
yulAssert(false, "Label not found.");
}
bytes BinaryTransform::encodeName(std::string const& _name)
bytes BinaryTransform::encodeName(string const& _name)
{
// UTF-8 is allowed here by the Wasm spec, but since all names here should stem from
// Solidity or Yul identifiers or similar, non-ascii characters ending up here

View File

@ -55,23 +55,49 @@ public:
bytes operator()(wasm::FunctionDefinition const& _function);
private:
BinaryTransform(
std::map<std::string, size_t> _globalIDs,
std::map<std::string, size_t> _functionIDs,
std::map<std::string, size_t> _functionTypes,
std::map<std::string, std::pair<size_t, size_t>> _subModulePosAndSize
):
m_globalIDs(std::move(_globalIDs)),
m_functionIDs(std::move(_functionIDs)),
m_functionTypes(std::move(_functionTypes)),
m_subModulePosAndSize(std::move(_subModulePosAndSize))
{}
using Type = std::pair<std::vector<std::uint8_t>, std::vector<std::uint8_t>>;
static Type typeOf(wasm::FunctionImport const& _import);
static Type typeOf(wasm::FunctionDefinition const& _funDef);
static uint8_t encodeType(std::string const& _typeName);
static std::vector<uint8_t> encodeTypes(std::vector<std::string> const& _typeNames);
bytes typeSection(
static std::map<Type, std::vector<std::string>> typeToFunctionMap(
std::vector<wasm::FunctionImport> const& _imports,
std::vector<wasm::FunctionDefinition> const& _functions
);
bytes importSection(std::vector<wasm::FunctionImport> const& _imports);
bytes functionSection(std::vector<wasm::FunctionDefinition> const& _functions);
bytes memorySection();
bytes globalSection();
bytes exportSection();
bytes customSection(std::string const& _name, bytes _data);
static std::map<std::string, size_t> enumerateGlobals(Module const& _module);
static std::map<std::string, size_t> enumerateFunctions(Module const& _module);
static std::map<std::string, size_t> enumerateFunctionTypes(
std::map<Type, std::vector<std::string>> const& _typeToFunctionMap
);
static bytes typeSection(std::map<Type, std::vector<std::string>> const& _typeToFunctionMap);
static bytes importSection(
std::vector<wasm::FunctionImport> const& _imports,
std::map<std::string, size_t> const& _functionTypes
);
static bytes functionSection(
std::vector<wasm::FunctionDefinition> const& _functions,
std::map<std::string, size_t> const& _functionTypes
);
static bytes memorySection();
static bytes globalSection(std::vector<wasm::GlobalVariableDeclaration> const& _globals);
static bytes exportSection(std::map<std::string, size_t> const& _functionIDs);
static bytes customSection(std::string const& _name, bytes _data);
bytes codeSection(std::vector<wasm::FunctionDefinition> const& _functions);
bytes visit(std::vector<wasm::Expression> const& _expressions);
@ -81,12 +107,13 @@ private:
static bytes encodeName(std::string const& _name);
std::map<std::string, size_t> const m_globalIDs;
std::map<std::string, size_t> const m_functionIDs;
std::map<std::string, size_t> const m_functionTypes;
std::map<std::string, std::pair<size_t, size_t>> const m_subModulePosAndSize;
std::map<std::string, size_t> m_locals;
std::map<std::string, size_t> m_globals;
std::map<std::string, size_t> m_functions;
std::map<std::string, size_t> m_functionTypes;
std::vector<std::string> m_labels;
std::map<std::string, std::pair<size_t, size_t>> m_subModulePosAndSize;
};
}