[WASM] Inject type conversions on the fly if needed.

This commit is contained in:
chriseth 2019-10-31 17:31:35 +01:00
parent 8780f2d595
commit 8337de5189
3 changed files with 80 additions and 10 deletions

View File

@ -133,6 +133,8 @@ wasm::Expression EWasmCodeTransform::operator()(FunctionalInstruction const& _f)
wasm::Expression EWasmCodeTransform::operator()(FunctionCall const& _call) wasm::Expression EWasmCodeTransform::operator()(FunctionCall const& _call)
{ {
bool typeConversionNeeded = false;
if (BuiltinFunction const* builtin = m_dialect.builtin(_call.functionName.name)) if (BuiltinFunction const* builtin = m_dialect.builtin(_call.functionName.name))
{ {
if (_call.functionName.name.str().substr(0, 4) == "eth.") if (_call.functionName.name.str().substr(0, 4) == "eth.")
@ -152,6 +154,7 @@ wasm::Expression EWasmCodeTransform::operator()(FunctionCall const& _call)
imp.paramTypes.emplace_back(param.str()); imp.paramTypes.emplace_back(param.str());
m_functionsToImport[builtin->name] = std::move(imp); m_functionsToImport[builtin->name] = std::move(imp);
} }
typeConversionNeeded = true;
} }
else if (builtin->literalArguments) else if (builtin->literalArguments)
{ {
@ -161,14 +164,32 @@ wasm::Expression EWasmCodeTransform::operator()(FunctionCall const& _call)
return wasm::BuiltinCall{_call.functionName.name.str(), std::move(literals)}; return wasm::BuiltinCall{_call.functionName.name.str(), std::move(literals)};
} }
else else
return wasm::BuiltinCall{_call.functionName.name.str(), visit(_call.arguments)}; {
wasm::BuiltinCall call{
_call.functionName.name.str(),
injectTypeConversionIfNeeded(visit(_call.arguments), builtin->parameters)
};
if (!builtin->returns.empty() && !builtin->returns.front().empty() && builtin->returns.front() != "i64"_yulstring)
{
yulAssert(builtin->returns.front() == "i32"_yulstring, "Invalid type " + builtin->returns.front().str());
call = wasm::BuiltinCall{"i64.extend_i32_u", make_vector<wasm::Expression>(std::move(call))};
}
return {std::move(call)};
}
} }
// If this function returns multiple values, then the first one will // If this function returns multiple values, then the first one will
// be returned in the expression itself and the others in global variables. // be returned in the expression itself and the others in global variables.
// The values have to be used right away in an assignment or variable declaration, // The values have to be used right away in an assignment or variable declaration,
// so it is handled there. // so it is handled there.
return wasm::FunctionCall{_call.functionName.name.str(), visit(_call.arguments)};
wasm::FunctionCall funCall{_call.functionName.name.str(), visit(_call.arguments)};
if (typeConversionNeeded)
// Inject type conversion if needed on the fly. This is just a temporary measure
// and can be removed once we have proper types in Yul.
return injectTypeConversionIfNeeded(std::move(funCall));
else
return {std::move(funCall)};
} }
wasm::Expression EWasmCodeTransform::operator()(Identifier const& _identifier) wasm::Expression EWasmCodeTransform::operator()(Identifier const& _identifier)
@ -191,7 +212,16 @@ wasm::Expression EWasmCodeTransform::operator()(yul::Instruction const&)
wasm::Expression EWasmCodeTransform::operator()(If const& _if) wasm::Expression EWasmCodeTransform::operator()(If const& _if)
{ {
return wasm::If{visit(*_if.condition), visit(_if.body.statements), {}}; // TODO converting i64 to i32 might not always be needed.
vector<wasm::Expression> args;
args.emplace_back(visitReturnByValue(*_if.condition));
args.emplace_back(wasm::Literal{0});
return wasm::If{
make_unique<wasm::Expression>(wasm::BuiltinCall{"i64.ne", std::move(args)}),
visit(_if.body.statements),
{}
};
} }
wasm::Expression EWasmCodeTransform::operator()(Switch const& _switch) wasm::Expression EWasmCodeTransform::operator()(Switch const& _switch)
@ -336,6 +366,40 @@ wasm::FunctionDefinition EWasmCodeTransform::translateFunction(yul::FunctionDefi
return fun; return fun;
} }
wasm::Expression EWasmCodeTransform::injectTypeConversionIfNeeded(wasm::FunctionCall _call) const
{
wasm::FunctionImport const& import = m_functionsToImport.at(YulString{_call.functionName});
for (size_t i = 0; i < _call.arguments.size(); ++i)
if (import.paramTypes.at(i) == "i32")
_call.arguments[i] = wasm::BuiltinCall{"i32.wrap_i64", make_vector<wasm::Expression>(std::move(_call.arguments[i]))};
else
yulAssert(import.paramTypes.at(i) == "i64", "Unknown type " + import.paramTypes.at(i));
if (import.returnType && *import.returnType != "i64")
{
yulAssert(*import.returnType == "i32", "Invalid type " + *import.returnType);
return wasm::BuiltinCall{"i64.extend_i32_u", make_vector<wasm::Expression>(std::move(_call))};
}
return {std::move(_call)};
}
vector<wasm::Expression> EWasmCodeTransform::injectTypeConversionIfNeeded(
vector<wasm::Expression> _arguments,
vector<Type> const& _parameterTypes
) const
{
for (size_t i = 0; i < _arguments.size(); ++i)
if (_parameterTypes.at(i) == "i32"_yulstring)
_arguments[i] = wasm::BuiltinCall{"i32.wrap_i64", make_vector<wasm::Expression>(std::move(_arguments[i]))};
else
yulAssert(
_parameterTypes.at(i).empty() || _parameterTypes.at(i) == "i64"_yulstring,
"Unknown type " + _parameterTypes.at(i).str()
);
return _arguments;
}
string EWasmCodeTransform::newLabel() string EWasmCodeTransform::newLabel()
{ {
return m_nameDispenser.newName("label_"_yulstring).str(); return m_nameDispenser.newName("label_"_yulstring).str();

View File

@ -82,6 +82,12 @@ private:
wasm::FunctionDefinition translateFunction(yul::FunctionDefinition const& _funDef); wasm::FunctionDefinition translateFunction(yul::FunctionDefinition const& _funDef);
wasm::Expression injectTypeConversionIfNeeded(wasm::FunctionCall _call) const;
std::vector<wasm::Expression> injectTypeConversionIfNeeded(
std::vector<wasm::Expression> _arguments,
std::vector<yul::Type> const& _parameterTypes
) const;
std::string newLabel(); std::string newLabel();
/// Makes sure that there are at least @a _amount global variables. /// Makes sure that there are at least @a _amount global variables.
void allocateGlobals(size_t _amount); void allocateGlobals(size_t _amount);

View File

@ -20,7 +20,7 @@
(i64.store (i64.add (local.get $_2) (i64.const 16)) (local.get $y)) (i64.store (i64.add (local.get $_2) (i64.const 16)) (local.get $y))
(local.set $hi_1 (i64.shl (i64.or (i64.shl (i64.or (i64.and (i64.shl (i64.const 128) (i64.const 8)) (local.get $_3)) (i64.and (i64.shr_u (i64.const 128) (i64.const 8)) (i64.const 255))) (i64.const 16)) (call $endian_swap_16 (i64.shr_u (i64.const 128) (i64.const 16)))) (i64.const 32))) (local.set $hi_1 (i64.shl (i64.or (i64.shl (i64.or (i64.and (i64.shl (i64.const 128) (i64.const 8)) (local.get $_3)) (i64.and (i64.shr_u (i64.const 128) (i64.const 8)) (i64.const 255))) (i64.const 16)) (call $endian_swap_16 (i64.shr_u (i64.const 128) (i64.const 16)))) (i64.const 32)))
(i64.store (i64.add (local.get $_2) (i64.const 24)) (i64.or (local.get $hi_1) (call $endian_swap_32 (i64.shr_u (i64.const 128) (i64.const 32))))) (i64.store (i64.add (local.get $_2) (i64.const 24)) (i64.or (local.get $hi_1) (call $endian_swap_32 (i64.shr_u (i64.const 128) (i64.const 32)))))
(call $eth.revert (i64.add (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1)) (i64.const 64)) (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1))) (call $eth.revert (i32.wrap_i64 (i64.add (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1)) (i64.const 64))) (i32.wrap_i64 (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1))))
) )
(func $u256_to_i32 (func $u256_to_i32
@ -30,9 +30,9 @@
(param $x4 i64) (param $x4 i64)
(result i64) (result i64)
(local $v i64) (local $v i64)
(if (i64.ne (local.get $v) (i64.or (i64.or (local.get $x1) (local.get $x2)) (local.get $x3))) (then (if (i64.ne (i64.ne (local.get $v) (i64.or (i64.or (local.get $x1) (local.get $x2)) (local.get $x3))) (i64.const 0)) (then
(unreachable))) (unreachable)))
(if (i64.ne (local.get $v) (i64.shr_u (local.get $x4) (i64.const 32))) (then (if (i64.ne (i64.ne (local.get $v) (i64.shr_u (local.get $x4) (i64.const 32))) (i64.const 0)) (then
(unreachable))) (unreachable)))
(local.set $v (local.get $x4)) (local.set $v (local.get $x4))
(local.get $v) (local.get $v)
@ -80,8 +80,8 @@
(local.set $hi_1 (i64.shl (call $endian_swap_32 (i64.const 128)) (i64.const 32))) (local.set $hi_1 (i64.shl (call $endian_swap_32 (i64.const 128)) (i64.const 32)))
(i64.store (i64.add (local.get $_2) (i64.const 24)) (i64.or (local.get $hi_1) (call $endian_swap_32 (i64.shr_u (i64.const 128) (i64.const 32))))) (i64.store (i64.add (local.get $_2) (i64.const 24)) (i64.or (local.get $hi_1) (call $endian_swap_32 (i64.shr_u (i64.const 128) (i64.const 32)))))
(local.set $_3 (datasize \"C_2_deployed\")) (local.set $_3 (datasize \"C_2_deployed\"))
(call $eth.codeCopy (i64.add (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1)) (i64.const 64)) (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (dataoffset \"C_2_deployed\")) (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_3))) (call $eth.codeCopy (i32.wrap_i64 (i64.add (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1)) (i64.const 64))) (i32.wrap_i64 (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (dataoffset \"C_2_deployed\"))) (i32.wrap_i64 (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_3))))
(call $eth.finish (i64.add (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1)) (i64.const 64)) (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_3))) (call $eth.finish (i32.wrap_i64 (i64.add (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_1)) (i64.const 64))) (i32.wrap_i64 (call $u256_to_i32 (local.get $_1) (local.get $_1) (local.get $_1) (local.get $_3))))
) )
(func $u256_to_i32 (func $u256_to_i32
@ -91,9 +91,9 @@
(param $x4 i64) (param $x4 i64)
(result i64) (result i64)
(local $v i64) (local $v i64)
(if (i64.ne (local.get $v) (i64.or (i64.or (local.get $x1) (local.get $x2)) (local.get $x3))) (then (if (i64.ne (i64.ne (local.get $v) (i64.or (i64.or (local.get $x1) (local.get $x2)) (local.get $x3))) (i64.const 0)) (then
(unreachable))) (unreachable)))
(if (i64.ne (local.get $v) (i64.shr_u (local.get $x4) (i64.const 32))) (then (if (i64.ne (i64.ne (local.get $v) (i64.shr_u (local.get $x4) (i64.const 32))) (i64.const 0)) (then
(unreachable))) (unreachable)))
(local.set $v (local.get $x4)) (local.set $v (local.get $x4))
(local.get $v) (local.get $v)