Codegen for catch.

This commit is contained in:
chriseth 2021-02-03 08:06:35 +01:00
parent 6fd0cf547e
commit 7aa7e708f9
6 changed files with 183 additions and 90 deletions

View File

@ -190,13 +190,15 @@ string ABIFunctions::tupleEncoderPacked(
return templ.render(); return templ.render();
}); });
} }
string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory) string ABIFunctions::tupleDecoder(TypePointers const& _types, OnError _onError, bool _fromMemory)
{ {
string functionName = string("abi_decode_tuple_"); string functionName = string("abi_decode_tuple_");
for (auto const& t: _types) for (auto const& t: _types)
functionName += t->identifier(); functionName += t->identifier();
if (_fromMemory) if (_fromMemory)
functionName += "_fromMemory"; functionName += "_fromMemory";
if (_onError == OnError::ReturnFalse)
functionName += "_try";
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
TypePointers decodingTypes; TypePointers decodingTypes;
@ -204,17 +206,25 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory)
decodingTypes.emplace_back(t->decodingType()); decodingTypes.emplace_back(t->decodingType());
Whiskers templ(R"( Whiskers templ(R"(
function <functionName>(headStart, dataEnd) <arrow> <valueReturnParams> { function <functionName>(headStart, dataEnd) <+returnParams> -> <returnParams> </+returnParams> {
if slt(sub(dataEnd, headStart), <minimumSize>) { <revertString> } if slt(sub(dataEnd, headStart), <minimumSize>) { <revertString> }
<?try> success := 1 </try>
<decodeElements> <decodeElements>
} }
)"); )");
templ("functionName", functionName); templ("functionName", functionName);
templ("revertString", revertReasonIfDebug("ABI decoding: tuple data too short")); templ("try", _onError == OnError::ReturnFalse);
if (_onError == OnError::ReturnFalse)
// TODO rename
templ("revertString", "success := 0 leave");
else
templ("revertString", revertReasonIfDebug("ABI decoding: tuple data too short"));
templ("minimumSize", to_string(headSize(decodingTypes))); templ("minimumSize", to_string(headSize(decodingTypes)));
string decodeElements; string decodeElements;
vector<string> valueReturnParams; vector<string> valueReturnParams;
if (_onError == OnError::ReturnFalse)
valueReturnParams.emplace_back("success");
size_t headPos = 0; size_t headPos = 0;
size_t stackPos = 0; size_t stackPos = 0;
for (size_t i = 0; i < _types.size(); ++i) for (size_t i = 0; i < _types.size(); ++i)
@ -225,6 +235,8 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory)
solAssert(sizeOnStack == decodingTypes[i]->sizeOnStack(), ""); solAssert(sizeOnStack == decodingTypes[i]->sizeOnStack(), "");
solAssert(sizeOnStack > 0, ""); solAssert(sizeOnStack > 0, "");
vector<string> valueNamesLocal; vector<string> valueNamesLocal;
if (_onError == OnError::ReturnFalse)
valueNamesLocal.push_back("success");
for (size_t j = 0; j < sizeOnStack; j++) for (size_t j = 0; j < sizeOnStack; j++)
{ {
valueNamesLocal.emplace_back("value" + to_string(stackPos)); valueNamesLocal.emplace_back("value" + to_string(stackPos));
@ -240,15 +252,22 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory)
let offset := <pos> let offset := <pos>
</dynamic> </dynamic>
<values> := <abiDecode>(add(headStart, offset), dataEnd) <values> := <abiDecode>(add(headStart, offset), dataEnd)
<?try> if iszero(success) { leave } </try>
} }
)"); )");
elementTempl("dynamic", decodingTypes[i]->isDynamicallyEncoded()); elementTempl("dynamic", decodingTypes[i]->isDynamicallyEncoded());
// TODO add test // TODO add test
elementTempl("revertString", revertReasonIfDebug("ABI decoding: invalid tuple offset")); if (_onError == OnError::ReturnFalse)
// TODO rename
elementTempl("revertString", "success := 0 leave");
else
elementTempl("revertString", revertReasonIfDebug("ABI decoding: invalid tuple offset"));
elementTempl("load", _fromMemory ? "mload" : "calldataload"); elementTempl("load", _fromMemory ? "mload" : "calldataload");
// TODO assign success
elementTempl("values", boost::algorithm::join(valueNamesLocal, ", ")); elementTempl("values", boost::algorithm::join(valueNamesLocal, ", "));
elementTempl("pos", to_string(headPos)); elementTempl("pos", to_string(headPos));
elementTempl("abiDecode", abiDecodingFunction(*_types[i], _fromMemory, true)); elementTempl("abiDecode", abiDecodingFunction(*_types[i], _fromMemory, true));
elementTempl("try", _onError == OnError::ReturnFalse);
decodeElements += elementTempl.render(); decodeElements += elementTempl.render();
headPos += decodingTypes[i]->calldataHeadSize(); headPos += decodingTypes[i]->calldataHeadSize();
} }
@ -1062,7 +1081,7 @@ string ABIFunctions::abiEncodingFunctionFunctionType(
}); });
} }
string ABIFunctions::abiDecodingFunction(Type const& _type, bool _fromMemory, bool _forUseOnStack) string ABIFunctions::abiDecodingFunction(Type const& _type, OnError _onError, bool _fromMemory, bool _forUseOnStack)
{ {
// The decoding function has to perform bounds checks unless it decodes a value type. // The decoding function has to perform bounds checks unless it decodes a value type.
// Conversely, bounds checks have to be performed before the decoding function // Conversely, bounds checks have to be performed before the decoding function
@ -1076,28 +1095,28 @@ string ABIFunctions::abiDecodingFunction(Type const& _type, bool _fromMemory, bo
if (arrayType->dataStoredIn(DataLocation::CallData)) if (arrayType->dataStoredIn(DataLocation::CallData))
{ {
solAssert(!_fromMemory, ""); solAssert(!_fromMemory, "");
return abiDecodingFunctionCalldataArray(*arrayType); return abiDecodingFunctionCalldataArray(*arrayType, _onError);
} }
else else
return abiDecodingFunctionArray(*arrayType, _fromMemory); return abiDecodingFunctionArray(*arrayType, _onError, _fromMemory);
} }
else if (auto const* structType = dynamic_cast<StructType const*>(decodingType)) else if (auto const* structType = dynamic_cast<StructType const*>(decodingType))
{ {
if (structType->dataStoredIn(DataLocation::CallData)) if (structType->dataStoredIn(DataLocation::CallData))
{ {
solAssert(!_fromMemory, ""); solAssert(!_fromMemory, "");
return abiDecodingFunctionCalldataStruct(*structType); return abiDecodingFunctionCalldataStruct(*structType, _onError);
} }
else else
return abiDecodingFunctionStruct(*structType, _fromMemory); return abiDecodingFunctionStruct(*structType, _onError, _fromMemory);
} }
else if (auto const* functionType = dynamic_cast<FunctionType const*>(decodingType)) else if (auto const* functionType = dynamic_cast<FunctionType const*>(decodingType))
return abiDecodingFunctionFunctionType(*functionType, _fromMemory, _forUseOnStack); return abiDecodingFunctionFunctionType(*functionType, _onError, _fromMemory, _forUseOnStack);
else else
return abiDecodingFunctionValueType(_type, _fromMemory); return abiDecodingFunctionValueType(_type, _onError, _fromMemory);
} }
string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, bool _fromMemory) string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, OnError _onError, bool _fromMemory)
{ {
TypePointer decodingType = _type.decodingType(); TypePointer decodingType = _type.decodingType();
solAssert(decodingType, ""); solAssert(decodingType, "");
@ -1109,12 +1128,13 @@ string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, bool _fromM
string functionName = string functionName =
"abi_decode_" + "abi_decode_" +
_type.identifier() + _type.identifier() +
(_fromMemory ? "_fromMemory" : ""); (_fromMemory ? "_fromMemory" : "") +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
Whiskers templ(R"( Whiskers templ(R"(
function <functionName>(offset, end) -> value { function <functionName>(offset, end) -> <?try> success, </try> value {
value := <load>(offset) value := <load>(offset)
<validator>(value) <?try> success := </try> <validator>(value)
} }
)"); )");
templ("functionName", functionName); templ("functionName", functionName);
@ -1122,58 +1142,65 @@ string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, bool _fromM
// Validation should use the type and not decodingType, because e.g. // Validation should use the type and not decodingType, because e.g.
// the decoding type of an enum is a plain int. // the decoding type of an enum is a plain int.
templ("validator", m_utils.validatorFunction(_type, true)); templ("validator", m_utils.validatorFunction(_type, true));
// TODO extend the validator to retrun success
return templ.render(); return templ.render();
}); });
} }
string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, bool _fromMemory) string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, OnError _onError, bool _fromMemory)
{ {
solAssert(_type.dataStoredIn(DataLocation::Memory), ""); solAssert(_type.dataStoredIn(DataLocation::Memory), "");
string functionName = string functionName =
"abi_decode_" + "abi_decode_" +
_type.identifier() + _type.identifier() +
(_fromMemory ? "_fromMemory" : ""); (_fromMemory ? "_fromMemory" : "") +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
string load = _fromMemory ? "mload" : "calldataload"; string load = _fromMemory ? "mload" : "calldataload";
Whiskers templ( Whiskers templ(
R"( R"(
// <readableTypeName> // <readableTypeName>
function <functionName>(offset, end) -> array { function <functionName>(offset, end) -> <?try> success, </try> array {
if iszero(slt(add(offset, 0x1f), end)) { <revertString> } if iszero(slt(add(offset, 0x1f), end)) { <revertString> }
let length := <retrieveLength> let length := <retrieveLength>
array := <abiDecodeAvailableLen>(<offset>, length, end) <?try> success, </try> array := <abiDecodeAvailableLen>(<offset>, length, end)
} }
)" )"
); );
// TODO add test // TODO add test
templ("revertString", revertReasonIfDebug("ABI decoding: invalid calldata array offset")); if (_onError == OnError::ReturnFalse)
// TODO rename
templ("revertString", "success := 0 leave");
else
templ("revertString", revertReasonIfDebug("ABI decoding: invalid calldata array offset"));
templ("functionName", functionName); templ("functionName", functionName);
templ("readableTypeName", _type.toString(true)); templ("readableTypeName", _type.toString(true));
templ("retrieveLength", _type.isDynamicallySized() ? (load + "(offset)") : toCompactHexWithPrefix(_type.length())); templ("retrieveLength", _type.isDynamicallySized() ? (load + "(offset)") : toCompactHexWithPrefix(_type.length()));
templ("offset", _type.isDynamicallySized() ? "add(offset, 0x20)" : "offset"); templ("offset", _type.isDynamicallySized() ? "add(offset, 0x20)" : "offset");
templ("abiDecodeAvailableLen", abiDecodingFunctionArrayAvailableLength(_type, _fromMemory)); templ("abiDecodeAvailableLen", abiDecodingFunctionArrayAvailableLength(_type, _onError, _fromMemory));
return templ.render(); return templ.render();
}); });
} }
string ABIFunctions::abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, bool _fromMemory) string ABIFunctions::abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory)
{ {
solAssert(_type.dataStoredIn(DataLocation::Memory), ""); solAssert(_type.dataStoredIn(DataLocation::Memory), "");
if (_type.isByteArray()) if (_type.isByteArray())
return abiDecodingFunctionByteArrayAvailableLength(_type, _fromMemory); return abiDecodingFunctionByteArrayAvailableLength(_type, _onError, _fromMemory);
string functionName = string functionName =
"abi_decode_available_length_" + "abi_decode_available_length_" +
_type.identifier() + _type.identifier() +
(_fromMemory ? "_fromMemory" : ""); (_fromMemory ? "_fromMemory" : "") +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
Whiskers templ(R"( Whiskers templ(R"(
// <readableTypeName> // <readableTypeName>
function <functionName>(offset, length, end) -> array { function <functionName>(offset, length, end) -> <?try> success, </try> array {
array := <allocate>(<allocationSize>(length)) array := <allocate>(<allocationSize>(length))
let dst := array let dst := array
<storeLength> <storeLength>
@ -1209,17 +1236,20 @@ string ABIFunctions::abiDecodingFunctionArrayAvailableLength(ArrayType const& _t
templ("staticBoundsCheck", "if gt(add(src, mul(length, " + templ("staticBoundsCheck", "if gt(add(src, mul(length, " +
calldataStride + calldataStride +
")), end) { " + ")), end) { " +
revertReasonIfDebug("ABI decoding: invalid calldata array stride") + (
_onError == OnError::ReturnFalse ? "success := 0 leave" :
revertReasonIfDebug("ABI decoding: invalid calldata array stride")
) +
" }" " }"
); );
templ("retrieveElementPos", "src"); templ("retrieveElementPos", "src");
} }
templ("decodingFun", abiDecodingFunction(*_type.baseType(), _fromMemory, false)); templ("decodingFun", abiDecodingFunction(*_type.baseType(), _onError, _fromMemory, false));
return templ.render(); return templ.render();
}); });
} }
string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type, OnError _onError)
{ {
solAssert(_type.dataStoredIn(DataLocation::CallData), ""); solAssert(_type.dataStoredIn(DataLocation::CallData), "");
if (!_type.isDynamicallySized()) if (!_type.isDynamicallySized())
@ -1229,14 +1259,16 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type)
string functionName = string functionName =
"abi_decode_" + "abi_decode_" +
_type.identifier(); _type.identifier() +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
Whiskers w; Whiskers w;
if (_type.isDynamicallySized()) if (_type.isDynamicallySized())
{ {
w = Whiskers(R"( w = Whiskers(R"(
// <readableTypeName> // <readableTypeName>
function <functionName>(offset, end) -> arrayPos, length { function <functionName>(offset, end) -> <+try> success, </try> arrayPos, length {
if iszero(slt(add(offset, 0x1f), end)) { <revertStringOffset> } if iszero(slt(add(offset, 0x1f), end)) { <revertStringOffset> }
length := calldataload(offset) length := calldataload(offset)
if gt(length, 0xffffffffffffffff) { <revertStringLength> } if gt(length, 0xffffffffffffffff) { <revertStringLength> }
@ -1244,6 +1276,7 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type)
if gt(add(arrayPos, mul(length, <stride>)), end) { <revertStringPos> } if gt(add(arrayPos, mul(length, <stride>)), end) { <revertStringPos> }
} }
)"); )");
// TODO modify to return failure
w("revertStringOffset", revertReasonIfDebug("ABI decoding: invalid calldata array offset")); w("revertStringOffset", revertReasonIfDebug("ABI decoding: invalid calldata array offset"));
w("revertStringLength", revertReasonIfDebug("ABI decoding: invalid calldata array length")); w("revertStringLength", revertReasonIfDebug("ABI decoding: invalid calldata array length"));
} }
@ -1251,13 +1284,14 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type)
{ {
w = Whiskers(R"( w = Whiskers(R"(
// <readableTypeName> // <readableTypeName>
function <functionName>(offset, end) -> arrayPos { function <functionName>(offset, end) -> <+try> success, </try> arrayPos {
arrayPos := offset arrayPos := offset
if gt(add(arrayPos, mul(<length>, <stride>)), end) { <revertStringPos> } if gt(add(arrayPos, mul(<length>, <stride>)), end) { <revertStringPos> }
} }
)"); )");
w("length", toCompactHexWithPrefix(_type.length())); w("length", toCompactHexWithPrefix(_type.length()));
} }
// TODO modify to return failure
w("revertStringPos", revertReasonIfDebug("ABI decoding: invalid calldata array stride")); w("revertStringPos", revertReasonIfDebug("ABI decoding: invalid calldata array stride"));
w("functionName", functionName); w("functionName", functionName);
w("readableTypeName", _type.toString(true)); w("readableTypeName", _type.toString(true));
@ -1268,7 +1302,7 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type)
}); });
} }
string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, bool _fromMemory) string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory)
{ {
solAssert(_type.dataStoredIn(DataLocation::Memory), ""); solAssert(_type.dataStoredIn(DataLocation::Memory), "");
solAssert(_type.isByteArray(), ""); solAssert(_type.isByteArray(), "");
@ -1276,11 +1310,12 @@ string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const
string functionName = string functionName =
"abi_decode_available_length_" + "abi_decode_available_length_" +
_type.identifier() + _type.identifier() +
(_fromMemory ? "_fromMemory" : ""); (_fromMemory ? "_fromMemory" : "") +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
Whiskers templ(R"( Whiskers templ(R"(
function <functionName>(src, length, end) -> array { function <functionName>(src, length, end) -> <?try> success, </try> array {
array := <allocate>(<allocationSize>(length)) array := <allocate>(<allocationSize>(length))
mstore(array, length) mstore(array, length)
let dst := add(array, 0x20) let dst := add(array, 0x20)
@ -1288,6 +1323,7 @@ string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const
<copyToMemFun>(src, dst, length) <copyToMemFun>(src, dst, length)
} }
)"); )");
// TODO modify
templ("revertStringLength", revertReasonIfDebug("ABI decoding: invalid byte array length")); templ("revertStringLength", revertReasonIfDebug("ABI decoding: invalid byte array length"));
templ("functionName", functionName); templ("functionName", functionName);
templ("allocate", m_utils.allocationFunction()); templ("allocate", m_utils.allocationFunction());
@ -1297,22 +1333,24 @@ string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const
}); });
} }
string ABIFunctions::abiDecodingFunctionCalldataStruct(StructType const& _type) string ABIFunctions::abiDecodingFunctionCalldataStruct(StructType const& _type, OnError _onError)
{ {
solAssert(_type.dataStoredIn(DataLocation::CallData), ""); solAssert(_type.dataStoredIn(DataLocation::CallData), "");
string functionName = string functionName =
"abi_decode_" + "abi_decode_" +
_type.identifier(); _type.identifier() +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
Whiskers w{R"( Whiskers w{R"(
// <readableTypeName> // <readableTypeName>
function <functionName>(offset, end) -> value { function <functionName>(offset, end) -> <+try> success, </try> value {
if slt(sub(end, offset), <minimumSize>) { <revertString> } if slt(sub(end, offset), <minimumSize>) { <revertString> }
value := offset value := offset
} }
)"}; )"};
// TODO add test // TODO add test
// TODO modfy
w("revertString", revertReasonIfDebug("ABI decoding: struct calldata too short")); w("revertString", revertReasonIfDebug("ABI decoding: struct calldata too short"));
w("functionName", functionName); w("functionName", functionName);
w("readableTypeName", _type.toString(true)); w("readableTypeName", _type.toString(true));
@ -1321,18 +1359,19 @@ string ABIFunctions::abiDecodingFunctionCalldataStruct(StructType const& _type)
}); });
} }
string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fromMemory) string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, OnError _onError, bool _fromMemory)
{ {
solAssert(!_type.dataStoredIn(DataLocation::CallData), ""); solAssert(!_type.dataStoredIn(DataLocation::CallData), "");
string functionName = string functionName =
"abi_decode_" + "abi_decode_" +
_type.identifier() + _type.identifier() +
(_fromMemory ? "_fromMemory" : ""); (_fromMemory ? "_fromMemory" : "") +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
Whiskers templ(R"( Whiskers templ(R"(
// <readableTypeName> // <readableTypeName>
function <functionName>(headStart, end) -> value { function <functionName>(headStart, end) -> <?try> success, </try> value {
if slt(sub(end, headStart), <minimumSize>) { <revertString> } if slt(sub(end, headStart), <minimumSize>) { <revertString> }
value := <allocate>(<memorySize>) value := <allocate>(<memorySize>)
<#members> <#members>
@ -1344,6 +1383,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr
} }
)"); )");
// TODO add test // TODO add test
// TODO modify
templ("revertString", revertReasonIfDebug("ABI decoding: struct data too short")); templ("revertString", revertReasonIfDebug("ABI decoding: struct data too short"));
templ("functionName", functionName); templ("functionName", functionName);
templ("readableTypeName", _type.toString(true)); templ("readableTypeName", _type.toString(true));
@ -1365,10 +1405,12 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr
<!dynamic> <!dynamic>
let offset := <pos> let offset := <pos>
</dynamic> </dynamic>
// TODO outline and leave on error
mstore(add(value, <memoryOffset>), <abiDecode>(add(headStart, offset), end)) mstore(add(value, <memoryOffset>), <abiDecode>(add(headStart, offset), end))
)"); )");
memberTempl("dynamic", decodingType->isDynamicallyEncoded()); memberTempl("dynamic", decodingType->isDynamicallyEncoded());
// TODO add test // TODO add test
// TODO modify
memberTempl("revertString", revertReasonIfDebug("ABI decoding: invalid struct offset")); memberTempl("revertString", revertReasonIfDebug("ABI decoding: invalid struct offset"));
memberTempl("load", _fromMemory ? "mload" : "calldataload"); memberTempl("load", _fromMemory ? "mload" : "calldataload");
memberTempl("pos", to_string(headPos)); memberTempl("pos", to_string(headPos));
@ -1386,7 +1428,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr
}); });
} }
string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type, bool _fromMemory, bool _forUseOnStack) string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type, OnError _onError, bool _fromMemory, bool _forUseOnStack)
{ {
solAssert(_type.kind() == FunctionType::Kind::External, ""); solAssert(_type.kind() == FunctionType::Kind::External, "");
@ -1394,13 +1436,15 @@ string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type,
"abi_decode_" + "abi_decode_" +
_type.identifier() + _type.identifier() +
(_fromMemory ? "_fromMemory" : "") + (_fromMemory ? "_fromMemory" : "") +
(_forUseOnStack ? "_onStack" : ""); (_forUseOnStack ? "_onStack" : "") +
(_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
if (_forUseOnStack) if (_forUseOnStack)
{ {
return Whiskers(R"( return Whiskers(R"(
function <functionName>(offset, end) -> addr, function_selector { function <functionName>(offset, end) -> <?try> failure, </try> addr, function_selector {
// TODO split
addr, function_selector := <splitExtFun>(<decodeFun>(offset, end)) addr, function_selector := <splitExtFun>(<decodeFun>(offset, end))
} }
)") )")
@ -1412,7 +1456,7 @@ string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type,
else else
{ {
return Whiskers(R"( return Whiskers(R"(
function <functionName>(offset, end) -> fun { function <functionName>(offset, end) -> <?try> failure, </try> fun {
fun := <load>(offset) fun := <load>(offset)
<validateExtFun>(fun) <validateExtFun>(fun)
} }
@ -1425,17 +1469,18 @@ string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type,
}); });
} }
string ABIFunctions::calldataAccessFunction(Type const& _type) string ABIFunctions::calldataAccessFunction(Type const& _type, OnError _onError)
{ {
solAssert(_type.isValueType() || _type.dataStoredIn(DataLocation::CallData), ""); solAssert(_type.isValueType() || _type.dataStoredIn(DataLocation::CallData), "");
string functionName = "calldata_access_" + _type.identifier(); string functionName = "calldata_access_" + _type.identifier() + (_onError == OnError::ReturnFalse ? "_try" : "");
return createFunction(functionName, [&]() { return createFunction(functionName, [&]() {
if (_type.isDynamicallyEncoded()) if (_type.isDynamicallyEncoded())
{ {
unsigned int tailSize = _type.calldataEncodedTailSize(); unsigned int tailSize = _type.calldataEncodedTailSize();
solAssert(tailSize > 1, ""); solAssert(tailSize > 1, "");
Whiskers w(R"( Whiskers w(R"(
function <functionName>(base_ref, ptr) -> <return> { function <functionName>(base_ref, ptr) -> <?try> failure, </try> <return> {
let rel_offset_of_tail := calldataload(ptr) let rel_offset_of_tail := calldataload(ptr)
if iszero(slt(rel_offset_of_tail, sub(sub(calldatasize(), base_ref), sub(<neededLength>, 1)))) { <revertStringOffset> } if iszero(slt(rel_offset_of_tail, sub(sub(calldatasize(), base_ref), sub(<neededLength>, 1)))) { <revertStringOffset> }
value := add(rel_offset_of_tail, base_ref) value := add(rel_offset_of_tail, base_ref)
@ -1454,8 +1499,10 @@ string ABIFunctions::calldataAccessFunction(Type const& _type)
)") )")
("calldataStride", toCompactHexWithPrefix(arrayType->calldataStride())) ("calldataStride", toCompactHexWithPrefix(arrayType->calldataStride()))
// TODO add test // TODO add test
// TODO modify
("revertStringLength", revertReasonIfDebug("Invalid calldata access length")) ("revertStringLength", revertReasonIfDebug("Invalid calldata access length"))
// TODO add test // TODO add test
// TODO modify
("revertStringStride", revertReasonIfDebug("Invalid calldata access stride")) ("revertStringStride", revertReasonIfDebug("Invalid calldata access stride"))
.render()); .render());
w("return", "value, length"); w("return", "value, length");
@ -1467,6 +1514,7 @@ string ABIFunctions::calldataAccessFunction(Type const& _type)
} }
w("neededLength", toCompactHexWithPrefix(tailSize)); w("neededLength", toCompactHexWithPrefix(tailSize));
w("functionName", functionName); w("functionName", functionName);
// TODO modify
w("revertStringOffset", revertReasonIfDebug("Invalid calldata access offset")); w("revertStringOffset", revertReasonIfDebug("Invalid calldata access offset"));
return w.render(); return w.render();
} }
@ -1484,6 +1532,7 @@ string ABIFunctions::calldataAccessFunction(Type const& _type)
} }
)") )")
("functionName", functionName) ("functionName", functionName)
// TODO modify
("decodingFunction", decodingFunction) ("decodingFunction", decodingFunction)
.render(); .render();
} }
@ -1494,6 +1543,7 @@ string ABIFunctions::calldataAccessFunction(Type const& _type)
_type.category() == Type::Category::Struct, _type.category() == Type::Category::Struct,
"" ""
); );
// TODO modify
return Whiskers(R"( return Whiskers(R"(
function <functionName>(baseRef, ptr) -> value { function <functionName>(baseRef, ptr) -> value {
value := ptr value := ptr

View File

@ -116,6 +116,8 @@ public:
return tupleEncoderPacked(_givenTypes, _targetTypes, true); return tupleEncoderPacked(_givenTypes, _targetTypes, true);
} }
enum OnError { Revert, ReturnFalse };
/// @returns name of an assembly function to ABI-decode values of @a _types /// @returns name of an assembly function to ABI-decode values of @a _types
/// into memory. If @a _fromMemory is true, decodes from memory instead of /// into memory. If @a _fromMemory is true, decodes from memory instead of
/// from calldata. /// from calldata.
@ -124,7 +126,7 @@ public:
/// Outputs: <value0> <value1> ... <valuen> /// Outputs: <value0> <value1> ... <valuen>
/// The values represent stack slots. If a type occupies more or less than one /// The values represent stack slots. If a type occupies more or less than one
/// stack slot, it takes exactly that number of values. /// stack slot, it takes exactly that number of values.
std::string tupleDecoder(TypePointers const& _types, bool _fromMemory = false); std::string tupleDecoder(TypePointers const& _types, OnError _onError, bool _fromMemory = false);
struct EncodingOptions struct EncodingOptions
{ {
@ -168,12 +170,12 @@ public:
/// Decodes array in case of dynamic arrays with offset pointing to /// Decodes array in case of dynamic arrays with offset pointing to
/// data and length already on stack /// data and length already on stack
/// signature: (dataOffset, length, dataEnd) -> decodedArray /// signature: (dataOffset, length, dataEnd) -> decodedArray
std::string abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, bool _fromMemory); std::string abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory);
/// Internal decoding function that is also used by some copying routines. /// Internal decoding function that is also used by some copying routines.
/// @returns the name of a function that decodes structs. /// @returns the name of a function that decodes structs.
/// signature: (dataStart, dataEnd) -> decodedStruct /// signature: (dataStart, dataEnd) -> decodedStruct
std::string abiDecodingFunctionStruct(StructType const& _type, bool _fromMemory); std::string abiDecodingFunctionStruct(StructType const& _type, OnError _onError, bool _fromMemory);
private: private:
/// Part of @a abiEncodingFunction for array target type and given calldata array. /// Part of @a abiEncodingFunction for array target type and given calldata array.
@ -239,17 +241,17 @@ private:
); );
/// Part of @a abiDecodingFunction for value types. /// Part of @a abiDecodingFunction for value types.
std::string abiDecodingFunctionValueType(Type const& _type, bool _fromMemory); std::string abiDecodingFunctionValueType(Type const& _type, OnError _onError, bool _fromMemory);
/// Part of @a abiDecodingFunction for "regular" array types. /// Part of @a abiDecodingFunction for "regular" array types.
std::string abiDecodingFunctionArray(ArrayType const& _type, bool _fromMemory); std::string abiDecodingFunctionArray(ArrayType const& _type, OnError _onError, bool _fromMemory);
/// Part of @a abiDecodingFunction for calldata array types. /// Part of @a abiDecodingFunction for calldata array types.
std::string abiDecodingFunctionCalldataArray(ArrayType const& _type); std::string abiDecodingFunctionCalldataArray(ArrayType const& _type, OnError _onError);
/// Part of @a abiDecodingFunctionArrayWithAvailableLength /// Part of @a abiDecodingFunctionArrayWithAvailableLength
std::string abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, bool _fromMemory); std::string abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory);
/// Part of @a abiDecodingFunction for calldata struct types. /// Part of @a abiDecodingFunction for calldata struct types.
std::string abiDecodingFunctionCalldataStruct(StructType const& _type); std::string abiDecodingFunctionCalldataStruct(StructType const& _type, OnError _onError);
/// Part of @a abiDecodingFunction for array types. /// Part of @a abiDecodingFunction for array types.
std::string abiDecodingFunctionFunctionType(FunctionType const& _type, bool _fromMemory, bool _forUseOnStack); std::string abiDecodingFunctionFunctionType(FunctionType const& _type, OnError _onError, bool _fromMemory, bool _forUseOnStack);
/// @returns the name of a function that retrieves an element from calldata. /// @returns the name of a function that retrieves an element from calldata.
std::string calldataAccessFunction(Type const& _type); std::string calldataAccessFunction(Type const& _type);

View File

@ -992,6 +992,7 @@ void ContractCompiler::handleCatch(vector<ASTPointer<TryCatchClause>> const& _ca
if (error || panic) if (error || panic)
// Note that this function returns zero on failure, which is not a problem yet, // Note that this function returns zero on failure, which is not a problem yet,
// but will be a problem once we allow user-defined errors. // but will be a problem once we allow user-defined errors.
// TODO
m_context.callYulFunction(m_context.utilFunctions().returnDataSelectorFunction(), 0, 1); m_context.callYulFunction(m_context.utilFunctions().returnDataSelectorFunction(), 0, 1);
// stack: <selector> // stack: <selector>
if (error) if (error)

View File

@ -4129,11 +4129,13 @@ string YulUtilFunctions::returnDataSelectorFunction()
return m_functionCollector.createFunction(functionName, [&]() { return m_functionCollector.createFunction(functionName, [&]() {
return util::Whiskers(R"( return util::Whiskers(R"(
function <functionName>() -> sig { function <functionName>() -> sig, error {
if gt(returndatasize(), 3) { switch gt(returndatasize(), 3)
case 1 {
returndatacopy(0, 0, 4) returndatacopy(0, 0, 4)
sig := <shr224>(mload(0)) sig := <shr224>(mload(0))
} }
default { error := 1 }
} }
)") )")
("functionName", functionName) ("functionName", functionName)
@ -4201,6 +4203,30 @@ string YulUtilFunctions::tryDecodePanicDataFunction()
}); });
} }
string YulUtilFunctions::tryDecodeReturndata(vector<Type const*> const& _types)
{
string const functionName = "try_decode_returndata_";
for (Type const* type: _types)
functionName += type->identifier();
solAssert(m_evmVersion.supportsReturndata(), "");
return m_functionCollector.createFunction(functionName, [&]() {
return util::Whiskers(R"(
function <functionName>() -> success<+values>, <values></+values> {
if lt(returndatasize(), 4) { leave }
let data := <allocate>(sub(returndatasize(), 4))
returndatacopy(data, 4, sub(returndatasize(), 4))
ret := msg
}
)")
("functionName", functionName)
("allocate", allocationFunction())
("finalizeAllocation", finalizeAllocationFunction())
.render();
});
}
string YulUtilFunctions::extractReturndataFunction() string YulUtilFunctions::extractReturndataFunction()
{ {
string const functionName = "extract_returndata"; string const functionName = "extract_returndata";

View File

@ -471,6 +471,12 @@ public:
/// signature: () -> success, value /// signature: () -> success, value
std::string tryDecodePanicDataFunction(); std::string tryDecodePanicDataFunction();
/// @returns the name of a function that tries to abi-decode parameters in a catch clause
/// from the return data.
/// Does not check the return data signature.
/// signature: () -> success, value...
std::string tryDecodeReturndata(std::vector<Type const*> const& _types);
/// Returns a function name that returns a newly allocated `bytes` array that contains the return data. /// Returns a function name that returns a newly allocated `bytes` array that contains the return data.
/// ///
/// If returndatacopy() is not supported by the underlying target, a empty array will be returned instead. /// If returndatacopy() is not supported by the underlying target, a empty array will be returned instead.

View File

@ -2999,49 +2999,57 @@ bool IRGeneratorForStatements::visit(TryStatement const& _tryStatement)
void IRGeneratorForStatements::handleCatch(TryStatement const& _tryStatement) void IRGeneratorForStatements::handleCatch(TryStatement const& _tryStatement)
{ {
string const runFallback = m_context.newYulVariable(); string const runFallback = m_context.newYulVariable();
string const selector = m_context.newYulVariable();
string const shortCalldata = m_context.newYulVariable();
m_code << "let " << runFallback << " := 1\n"; m_code << "let " << runFallback << " := 1\n";
m_code << "let " << selector << ", " << shortCalldata << ", " << " := " << m_utils.returnDataSelectorFunction() << "()\n";
m_code << "if iszero(" << shortCalldata << ") {\n";
// This function returns zero on "short returndata". We have to add a success flag if (_tryStatement.clauses().size() - 1 - (_tryStatement.fallback() ? 1 : 0) > 0)
// once we implement custom error codes. m_code << "switch " << selector << "\n";
if (_tryStatement.errorClause() || _tryStatement.panicClause())
m_code << "switch " << m_utils.returnDataSelectorFunction() << "()\n";
if (TryCatchClause const* errorClause = _tryStatement.errorClause()) for (ASTPointer<TryCatchClause> const& clause: _tryStatement.clauses())
{ {
m_code << "case " << selectorFromSignature32("Error(string)") << " {\n"; if (clause->kind() == TryCatchClause::Kind::Success || clause->kind() == TryCatchClause::Kind::Fallback)
string const dataVariable = m_context.newYulVariable(); continue;
m_code << "let " << dataVariable << " := " << m_utils.tryDecodeErrorMessageFunction() << "()\n";
m_code << "if " << dataVariable << " {\n"; string signature;
m_code << runFallback << " := 0\n"; vector<Type const*> parameterTypes;
if (errorClause->parameters()) if (clause->kind() == TryCatchClause::Kind::Error)
{ {
solAssert(errorClause->parameters()->parameters().size() == 1, ""); signature = "Error(string)";
IRVariable const& var = m_context.addLocalVariable(*errorClause->parameters()->parameters().front()); parameterTypes.push_back(TypeProvider::stringMemory);
define(var) << dataVariable << "\n";
} }
else if (clause->kind() == TryCatchClause::Kind::Panic)
{
signature = "Panic(uint256)";
parameterTypes.push_back(TypeProvider::uint256();
}
else
{
ErrorDefinition const& error = dynamic_cast<ErrorDefinition const&>(
*clause->errorName().annotation().referecedDeclaration
);
signature = error->functionType(true)->externalSignature();
parameterTypes = error->functionType(true)->parameterTypes();
}
solAssert(clause->parameters(), "");
vector<string> variables;
string const success = m_context.newYulVariable();
variables.push_back(success);
for (ASTPointer<VariableDeclaration> const& varDecl: error->parameters()->parameters())
variables += m_context.addLocalVariable(*varDecl).stackSlots();
m_code << "case " << selectorFromSignature32(signature) << " {\n";
m_code << "let " << joinHumanReadable(variables) << " := " << m_utils.tryDecodeReturndata(parameterTypes) << "()\n";
m_code << "if " << success << " {\n";
m_code << runFallback << " := 0\n";
errorClause->accept(*this); errorClause->accept(*this);
m_code << "}\n"; m_code << "}\n";
m_code << "}\n"; m_code << "}\n";
} }
if (TryCatchClause const* panicClause = _tryStatement.panicClause())
{
m_code << "case " << selectorFromSignature32("Panic(uint256)") << " {\n";
string const success = m_context.newYulVariable();
string const code = m_context.newYulVariable();
m_code << "let " << success << ", " << code << " := " << m_utils.tryDecodePanicDataFunction() << "()\n";
m_code << "if " << success << " {\n";
m_code << runFallback << " := 0\n";
if (panicClause->parameters())
{
solAssert(panicClause->parameters()->parameters().size() == 1, "");
IRVariable const& var = m_context.addLocalVariable(*panicClause->parameters()->parameters().front());
define(var) << code << "\n";
}
panicClause->accept(*this);
m_code << "}\n";
m_code << "}\n";
}
m_code << "}\n";
m_code << "if " << runFallback << " {\n"; m_code << "if " << runFallback << " {\n";
if (_tryStatement.fallbackClause()) if (_tryStatement.fallbackClause())
handleCatchFallback(*_tryStatement.fallbackClause()); handleCatchFallback(*_tryStatement.fallbackClause());