From 7aa7e708f9e889cc0b0529654b0d2ae4bdaa153c Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 3 Feb 2021 08:06:35 +0100 Subject: [PATCH] Codegen for catch. --- libsolidity/codegen/ABIFunctions.cpp | 144 ++++++++++++------ libsolidity/codegen/ABIFunctions.h | 20 +-- libsolidity/codegen/ContractCompiler.cpp | 1 + libsolidity/codegen/YulUtilFunctions.cpp | 30 +++- libsolidity/codegen/YulUtilFunctions.h | 6 + .../codegen/ir/IRGeneratorForStatements.cpp | 72 +++++---- 6 files changed, 183 insertions(+), 90 deletions(-) diff --git a/libsolidity/codegen/ABIFunctions.cpp b/libsolidity/codegen/ABIFunctions.cpp index d2512881a..46cb392b8 100644 --- a/libsolidity/codegen/ABIFunctions.cpp +++ b/libsolidity/codegen/ABIFunctions.cpp @@ -190,13 +190,15 @@ string ABIFunctions::tupleEncoderPacked( return templ.render(); }); } -string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory) +string ABIFunctions::tupleDecoder(TypePointers const& _types, OnError _onError, bool _fromMemory) { string functionName = string("abi_decode_tuple_"); for (auto const& t: _types) functionName += t->identifier(); if (_fromMemory) functionName += "_fromMemory"; + if (_onError == OnError::ReturnFalse) + functionName += "_try"; return createFunction(functionName, [&]() { TypePointers decodingTypes; @@ -204,17 +206,25 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory) decodingTypes.emplace_back(t->decodingType()); Whiskers templ(R"( - function (headStart, dataEnd) { + function (headStart, dataEnd) <+returnParams> -> { if slt(sub(dataEnd, headStart), ) { } + success := 1 } )"); templ("functionName", functionName); - templ("revertString", revertReasonIfDebug("ABI decoding: tuple data too short")); + templ("try", _onError == OnError::ReturnFalse); + if (_onError == OnError::ReturnFalse) + // TODO rename + templ("revertString", "success := 0 leave"); + else + templ("revertString", revertReasonIfDebug("ABI decoding: tuple data too short")); templ("minimumSize", to_string(headSize(decodingTypes))); string decodeElements; vector valueReturnParams; + if (_onError == OnError::ReturnFalse) + valueReturnParams.emplace_back("success"); size_t headPos = 0; size_t stackPos = 0; for (size_t i = 0; i < _types.size(); ++i) @@ -225,6 +235,8 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory) solAssert(sizeOnStack == decodingTypes[i]->sizeOnStack(), ""); solAssert(sizeOnStack > 0, ""); vector valueNamesLocal; + if (_onError == OnError::ReturnFalse) + valueNamesLocal.push_back("success"); for (size_t j = 0; j < sizeOnStack; j++) { valueNamesLocal.emplace_back("value" + to_string(stackPos)); @@ -240,15 +252,22 @@ string ABIFunctions::tupleDecoder(TypePointers const& _types, bool _fromMemory) let offset := := (add(headStart, offset), dataEnd) + if iszero(success) { leave } } )"); elementTempl("dynamic", decodingTypes[i]->isDynamicallyEncoded()); // TODO add test - elementTempl("revertString", revertReasonIfDebug("ABI decoding: invalid tuple offset")); + if (_onError == OnError::ReturnFalse) + // TODO rename + elementTempl("revertString", "success := 0 leave"); + else + elementTempl("revertString", revertReasonIfDebug("ABI decoding: invalid tuple offset")); elementTempl("load", _fromMemory ? "mload" : "calldataload"); + // TODO assign success elementTempl("values", boost::algorithm::join(valueNamesLocal, ", ")); elementTempl("pos", to_string(headPos)); elementTempl("abiDecode", abiDecodingFunction(*_types[i], _fromMemory, true)); + elementTempl("try", _onError == OnError::ReturnFalse); decodeElements += elementTempl.render(); headPos += decodingTypes[i]->calldataHeadSize(); } @@ -1062,7 +1081,7 @@ string ABIFunctions::abiEncodingFunctionFunctionType( }); } -string ABIFunctions::abiDecodingFunction(Type const& _type, bool _fromMemory, bool _forUseOnStack) +string ABIFunctions::abiDecodingFunction(Type const& _type, OnError _onError, bool _fromMemory, bool _forUseOnStack) { // The decoding function has to perform bounds checks unless it decodes a value type. // Conversely, bounds checks have to be performed before the decoding function @@ -1076,28 +1095,28 @@ string ABIFunctions::abiDecodingFunction(Type const& _type, bool _fromMemory, bo if (arrayType->dataStoredIn(DataLocation::CallData)) { solAssert(!_fromMemory, ""); - return abiDecodingFunctionCalldataArray(*arrayType); + return abiDecodingFunctionCalldataArray(*arrayType, _onError); } else - return abiDecodingFunctionArray(*arrayType, _fromMemory); + return abiDecodingFunctionArray(*arrayType, _onError, _fromMemory); } else if (auto const* structType = dynamic_cast(decodingType)) { if (structType->dataStoredIn(DataLocation::CallData)) { solAssert(!_fromMemory, ""); - return abiDecodingFunctionCalldataStruct(*structType); + return abiDecodingFunctionCalldataStruct(*structType, _onError); } else - return abiDecodingFunctionStruct(*structType, _fromMemory); + return abiDecodingFunctionStruct(*structType, _onError, _fromMemory); } else if (auto const* functionType = dynamic_cast(decodingType)) - return abiDecodingFunctionFunctionType(*functionType, _fromMemory, _forUseOnStack); + return abiDecodingFunctionFunctionType(*functionType, _onError, _fromMemory, _forUseOnStack); else - return abiDecodingFunctionValueType(_type, _fromMemory); + return abiDecodingFunctionValueType(_type, _onError, _fromMemory); } -string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, bool _fromMemory) +string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, OnError _onError, bool _fromMemory) { TypePointer decodingType = _type.decodingType(); solAssert(decodingType, ""); @@ -1109,12 +1128,13 @@ string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, bool _fromM string functionName = "abi_decode_" + _type.identifier() + - (_fromMemory ? "_fromMemory" : ""); + (_fromMemory ? "_fromMemory" : "") + + (_onError == OnError::ReturnFalse ? "_try" : ""); return createFunction(functionName, [&]() { Whiskers templ(R"( - function (offset, end) -> value { + function (offset, end) -> success, value { value := (offset) - (value) + success := (value) } )"); templ("functionName", functionName); @@ -1122,58 +1142,65 @@ string ABIFunctions::abiDecodingFunctionValueType(Type const& _type, bool _fromM // Validation should use the type and not decodingType, because e.g. // the decoding type of an enum is a plain int. templ("validator", m_utils.validatorFunction(_type, true)); + // TODO extend the validator to retrun success return templ.render(); }); } -string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, bool _fromMemory) +string ABIFunctions::abiDecodingFunctionArray(ArrayType const& _type, OnError _onError, bool _fromMemory) { solAssert(_type.dataStoredIn(DataLocation::Memory), ""); string functionName = "abi_decode_" + _type.identifier() + - (_fromMemory ? "_fromMemory" : ""); + (_fromMemory ? "_fromMemory" : "") + + (_onError == OnError::ReturnFalse ? "_try" : ""); return createFunction(functionName, [&]() { string load = _fromMemory ? "mload" : "calldataload"; Whiskers templ( R"( // - function (offset, end) -> array { + function (offset, end) -> success, array { if iszero(slt(add(offset, 0x1f), end)) { } let length := - array := (, length, end) + success, array := (, length, end) } )" ); // TODO add test - templ("revertString", revertReasonIfDebug("ABI decoding: invalid calldata array offset")); + if (_onError == OnError::ReturnFalse) + // TODO rename + templ("revertString", "success := 0 leave"); + else + templ("revertString", revertReasonIfDebug("ABI decoding: invalid calldata array offset")); templ("functionName", functionName); templ("readableTypeName", _type.toString(true)); templ("retrieveLength", _type.isDynamicallySized() ? (load + "(offset)") : toCompactHexWithPrefix(_type.length())); templ("offset", _type.isDynamicallySized() ? "add(offset, 0x20)" : "offset"); - templ("abiDecodeAvailableLen", abiDecodingFunctionArrayAvailableLength(_type, _fromMemory)); + templ("abiDecodeAvailableLen", abiDecodingFunctionArrayAvailableLength(_type, _onError, _fromMemory)); return templ.render(); }); } -string ABIFunctions::abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, bool _fromMemory) +string ABIFunctions::abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory) { solAssert(_type.dataStoredIn(DataLocation::Memory), ""); if (_type.isByteArray()) - return abiDecodingFunctionByteArrayAvailableLength(_type, _fromMemory); + return abiDecodingFunctionByteArrayAvailableLength(_type, _onError, _fromMemory); string functionName = "abi_decode_available_length_" + _type.identifier() + - (_fromMemory ? "_fromMemory" : ""); + (_fromMemory ? "_fromMemory" : "") + + (_onError == OnError::ReturnFalse ? "_try" : ""); return createFunction(functionName, [&]() { Whiskers templ(R"( // - function (offset, length, end) -> array { + function (offset, length, end) -> success, array { array := ((length)) let dst := array @@ -1209,17 +1236,20 @@ string ABIFunctions::abiDecodingFunctionArrayAvailableLength(ArrayType const& _t templ("staticBoundsCheck", "if gt(add(src, mul(length, " + calldataStride + ")), end) { " + - revertReasonIfDebug("ABI decoding: invalid calldata array stride") + + ( + _onError == OnError::ReturnFalse ? "success := 0 leave" : + revertReasonIfDebug("ABI decoding: invalid calldata array stride") + ) + " }" ); templ("retrieveElementPos", "src"); } - templ("decodingFun", abiDecodingFunction(*_type.baseType(), _fromMemory, false)); + templ("decodingFun", abiDecodingFunction(*_type.baseType(), _onError, _fromMemory, false)); return templ.render(); }); } -string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) +string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type, OnError _onError) { solAssert(_type.dataStoredIn(DataLocation::CallData), ""); if (!_type.isDynamicallySized()) @@ -1229,14 +1259,16 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) string functionName = "abi_decode_" + - _type.identifier(); + _type.identifier() + + (_onError == OnError::ReturnFalse ? "_try" : ""); + return createFunction(functionName, [&]() { Whiskers w; if (_type.isDynamicallySized()) { w = Whiskers(R"( // - function (offset, end) -> arrayPos, length { + function (offset, end) -> <+try> success, arrayPos, length { if iszero(slt(add(offset, 0x1f), end)) { } length := calldataload(offset) if gt(length, 0xffffffffffffffff) { } @@ -1244,6 +1276,7 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) if gt(add(arrayPos, mul(length, )), end) { } } )"); + // TODO modify to return failure w("revertStringOffset", revertReasonIfDebug("ABI decoding: invalid calldata array offset")); w("revertStringLength", revertReasonIfDebug("ABI decoding: invalid calldata array length")); } @@ -1251,13 +1284,14 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) { w = Whiskers(R"( // - function (offset, end) -> arrayPos { + function (offset, end) -> <+try> success, arrayPos { arrayPos := offset if gt(add(arrayPos, mul(, )), end) { } } )"); w("length", toCompactHexWithPrefix(_type.length())); } + // TODO modify to return failure w("revertStringPos", revertReasonIfDebug("ABI decoding: invalid calldata array stride")); w("functionName", functionName); w("readableTypeName", _type.toString(true)); @@ -1268,7 +1302,7 @@ string ABIFunctions::abiDecodingFunctionCalldataArray(ArrayType const& _type) }); } -string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, bool _fromMemory) +string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory) { solAssert(_type.dataStoredIn(DataLocation::Memory), ""); solAssert(_type.isByteArray(), ""); @@ -1276,11 +1310,12 @@ string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const string functionName = "abi_decode_available_length_" + _type.identifier() + - (_fromMemory ? "_fromMemory" : ""); + (_fromMemory ? "_fromMemory" : "") + + (_onError == OnError::ReturnFalse ? "_try" : ""); return createFunction(functionName, [&]() { Whiskers templ(R"( - function (src, length, end) -> array { + function (src, length, end) -> success, array { array := ((length)) mstore(array, length) let dst := add(array, 0x20) @@ -1288,6 +1323,7 @@ string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const (src, dst, length) } )"); + // TODO modify templ("revertStringLength", revertReasonIfDebug("ABI decoding: invalid byte array length")); templ("functionName", functionName); templ("allocate", m_utils.allocationFunction()); @@ -1297,22 +1333,24 @@ string ABIFunctions::abiDecodingFunctionByteArrayAvailableLength(ArrayType const }); } -string ABIFunctions::abiDecodingFunctionCalldataStruct(StructType const& _type) +string ABIFunctions::abiDecodingFunctionCalldataStruct(StructType const& _type, OnError _onError) { solAssert(_type.dataStoredIn(DataLocation::CallData), ""); string functionName = "abi_decode_" + - _type.identifier(); + _type.identifier() + + (_onError == OnError::ReturnFalse ? "_try" : ""); return createFunction(functionName, [&]() { Whiskers w{R"( // - function (offset, end) -> value { + function (offset, end) -> <+try> success, value { if slt(sub(end, offset), ) { } value := offset } )"}; // TODO add test + // TODO modfy w("revertString", revertReasonIfDebug("ABI decoding: struct calldata too short")); w("functionName", functionName); w("readableTypeName", _type.toString(true)); @@ -1321,18 +1359,19 @@ string ABIFunctions::abiDecodingFunctionCalldataStruct(StructType const& _type) }); } -string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fromMemory) +string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, OnError _onError, bool _fromMemory) { solAssert(!_type.dataStoredIn(DataLocation::CallData), ""); string functionName = "abi_decode_" + _type.identifier() + - (_fromMemory ? "_fromMemory" : ""); + (_fromMemory ? "_fromMemory" : "") + + (_onError == OnError::ReturnFalse ? "_try" : ""); return createFunction(functionName, [&]() { Whiskers templ(R"( // - function (headStart, end) -> value { + function (headStart, end) -> success, value { if slt(sub(end, headStart), ) { } value := () <#members> @@ -1344,6 +1383,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr } )"); // TODO add test + // TODO modify templ("revertString", revertReasonIfDebug("ABI decoding: struct data too short")); templ("functionName", functionName); templ("readableTypeName", _type.toString(true)); @@ -1365,10 +1405,12 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr let offset := +// TODO outline and leave on error mstore(add(value, ), (add(headStart, offset), end)) )"); memberTempl("dynamic", decodingType->isDynamicallyEncoded()); // TODO add test + // TODO modify memberTempl("revertString", revertReasonIfDebug("ABI decoding: invalid struct offset")); memberTempl("load", _fromMemory ? "mload" : "calldataload"); memberTempl("pos", to_string(headPos)); @@ -1386,7 +1428,7 @@ string ABIFunctions::abiDecodingFunctionStruct(StructType const& _type, bool _fr }); } -string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type, bool _fromMemory, bool _forUseOnStack) +string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type, OnError _onError, bool _fromMemory, bool _forUseOnStack) { solAssert(_type.kind() == FunctionType::Kind::External, ""); @@ -1394,13 +1436,15 @@ string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type, "abi_decode_" + _type.identifier() + (_fromMemory ? "_fromMemory" : "") + - (_forUseOnStack ? "_onStack" : ""); + (_forUseOnStack ? "_onStack" : "") + + (_onError == OnError::ReturnFalse ? "_try" : ""); return createFunction(functionName, [&]() { if (_forUseOnStack) { return Whiskers(R"( - function (offset, end) -> addr, function_selector { + function (offset, end) -> failure, addr, function_selector { +// TODO split addr, function_selector := ((offset, end)) } )") @@ -1412,7 +1456,7 @@ string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type, else { return Whiskers(R"( - function (offset, end) -> fun { + function (offset, end) -> failure, fun { fun := (offset) (fun) } @@ -1425,17 +1469,18 @@ string ABIFunctions::abiDecodingFunctionFunctionType(FunctionType const& _type, }); } -string ABIFunctions::calldataAccessFunction(Type const& _type) +string ABIFunctions::calldataAccessFunction(Type const& _type, OnError _onError) { solAssert(_type.isValueType() || _type.dataStoredIn(DataLocation::CallData), ""); - string functionName = "calldata_access_" + _type.identifier(); + string functionName = "calldata_access_" + _type.identifier() + (_onError == OnError::ReturnFalse ? "_try" : ""); + return createFunction(functionName, [&]() { if (_type.isDynamicallyEncoded()) { unsigned int tailSize = _type.calldataEncodedTailSize(); solAssert(tailSize > 1, ""); Whiskers w(R"( - function (base_ref, ptr) -> { + function (base_ref, ptr) -> failure, { let rel_offset_of_tail := calldataload(ptr) if iszero(slt(rel_offset_of_tail, sub(sub(calldatasize(), base_ref), sub(, 1)))) { } value := add(rel_offset_of_tail, base_ref) @@ -1454,8 +1499,10 @@ string ABIFunctions::calldataAccessFunction(Type const& _type) )") ("calldataStride", toCompactHexWithPrefix(arrayType->calldataStride())) // TODO add test + // TODO modify ("revertStringLength", revertReasonIfDebug("Invalid calldata access length")) // TODO add test + // TODO modify ("revertStringStride", revertReasonIfDebug("Invalid calldata access stride")) .render()); w("return", "value, length"); @@ -1467,6 +1514,7 @@ string ABIFunctions::calldataAccessFunction(Type const& _type) } w("neededLength", toCompactHexWithPrefix(tailSize)); w("functionName", functionName); + // TODO modify w("revertStringOffset", revertReasonIfDebug("Invalid calldata access offset")); return w.render(); } @@ -1484,6 +1532,7 @@ string ABIFunctions::calldataAccessFunction(Type const& _type) } )") ("functionName", functionName) + // TODO modify ("decodingFunction", decodingFunction) .render(); } @@ -1494,6 +1543,7 @@ string ABIFunctions::calldataAccessFunction(Type const& _type) _type.category() == Type::Category::Struct, "" ); + // TODO modify return Whiskers(R"( function (baseRef, ptr) -> value { value := ptr diff --git a/libsolidity/codegen/ABIFunctions.h b/libsolidity/codegen/ABIFunctions.h index 81ddcac00..52e7254e4 100644 --- a/libsolidity/codegen/ABIFunctions.h +++ b/libsolidity/codegen/ABIFunctions.h @@ -116,6 +116,8 @@ public: return tupleEncoderPacked(_givenTypes, _targetTypes, true); } + enum OnError { Revert, ReturnFalse }; + /// @returns name of an assembly function to ABI-decode values of @a _types /// into memory. If @a _fromMemory is true, decodes from memory instead of /// from calldata. @@ -124,7 +126,7 @@ public: /// Outputs: ... /// The values represent stack slots. If a type occupies more or less than one /// stack slot, it takes exactly that number of values. - std::string tupleDecoder(TypePointers const& _types, bool _fromMemory = false); + std::string tupleDecoder(TypePointers const& _types, OnError _onError, bool _fromMemory = false); struct EncodingOptions { @@ -168,12 +170,12 @@ public: /// Decodes array in case of dynamic arrays with offset pointing to /// data and length already on stack /// signature: (dataOffset, length, dataEnd) -> decodedArray - std::string abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, bool _fromMemory); + std::string abiDecodingFunctionArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory); /// Internal decoding function that is also used by some copying routines. /// @returns the name of a function that decodes structs. /// signature: (dataStart, dataEnd) -> decodedStruct - std::string abiDecodingFunctionStruct(StructType const& _type, bool _fromMemory); + std::string abiDecodingFunctionStruct(StructType const& _type, OnError _onError, bool _fromMemory); private: /// Part of @a abiEncodingFunction for array target type and given calldata array. @@ -239,17 +241,17 @@ private: ); /// Part of @a abiDecodingFunction for value types. - std::string abiDecodingFunctionValueType(Type const& _type, bool _fromMemory); + std::string abiDecodingFunctionValueType(Type const& _type, OnError _onError, bool _fromMemory); /// Part of @a abiDecodingFunction for "regular" array types. - std::string abiDecodingFunctionArray(ArrayType const& _type, bool _fromMemory); + std::string abiDecodingFunctionArray(ArrayType const& _type, OnError _onError, bool _fromMemory); /// Part of @a abiDecodingFunction for calldata array types. - std::string abiDecodingFunctionCalldataArray(ArrayType const& _type); + std::string abiDecodingFunctionCalldataArray(ArrayType const& _type, OnError _onError); /// Part of @a abiDecodingFunctionArrayWithAvailableLength - std::string abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, bool _fromMemory); + std::string abiDecodingFunctionByteArrayAvailableLength(ArrayType const& _type, OnError _onError, bool _fromMemory); /// Part of @a abiDecodingFunction for calldata struct types. - std::string abiDecodingFunctionCalldataStruct(StructType const& _type); + std::string abiDecodingFunctionCalldataStruct(StructType const& _type, OnError _onError); /// Part of @a abiDecodingFunction for array types. - std::string abiDecodingFunctionFunctionType(FunctionType const& _type, bool _fromMemory, bool _forUseOnStack); + std::string abiDecodingFunctionFunctionType(FunctionType const& _type, OnError _onError, bool _fromMemory, bool _forUseOnStack); /// @returns the name of a function that retrieves an element from calldata. std::string calldataAccessFunction(Type const& _type); diff --git a/libsolidity/codegen/ContractCompiler.cpp b/libsolidity/codegen/ContractCompiler.cpp index 410eae4c4..14fcd40d7 100644 --- a/libsolidity/codegen/ContractCompiler.cpp +++ b/libsolidity/codegen/ContractCompiler.cpp @@ -992,6 +992,7 @@ void ContractCompiler::handleCatch(vector> const& _ca if (error || panic) // Note that this function returns zero on failure, which is not a problem yet, // but will be a problem once we allow user-defined errors. + // TODO m_context.callYulFunction(m_context.utilFunctions().returnDataSelectorFunction(), 0, 1); // stack: if (error) diff --git a/libsolidity/codegen/YulUtilFunctions.cpp b/libsolidity/codegen/YulUtilFunctions.cpp index 508b74497..269eaf57d 100644 --- a/libsolidity/codegen/YulUtilFunctions.cpp +++ b/libsolidity/codegen/YulUtilFunctions.cpp @@ -4129,11 +4129,13 @@ string YulUtilFunctions::returnDataSelectorFunction() return m_functionCollector.createFunction(functionName, [&]() { return util::Whiskers(R"( - function () -> sig { - if gt(returndatasize(), 3) { + function () -> sig, error { + switch gt(returndatasize(), 3) + case 1 { returndatacopy(0, 0, 4) sig := (mload(0)) } + default { error := 1 } } )") ("functionName", functionName) @@ -4201,6 +4203,30 @@ string YulUtilFunctions::tryDecodePanicDataFunction() }); } +string YulUtilFunctions::tryDecodeReturndata(vector const& _types) +{ + string const functionName = "try_decode_returndata_"; + for (Type const* type: _types) + functionName += type->identifier(); + solAssert(m_evmVersion.supportsReturndata(), ""); + + return m_functionCollector.createFunction(functionName, [&]() { + return util::Whiskers(R"( + function () -> success<+values>, { + if lt(returndatasize(), 4) { leave } + let data := (sub(returndatasize(), 4)) + returndatacopy(data, 4, sub(returndatasize(), 4)) + + ret := msg + } + )") + ("functionName", functionName) + ("allocate", allocationFunction()) + ("finalizeAllocation", finalizeAllocationFunction()) + .render(); + }); +} + string YulUtilFunctions::extractReturndataFunction() { string const functionName = "extract_returndata"; diff --git a/libsolidity/codegen/YulUtilFunctions.h b/libsolidity/codegen/YulUtilFunctions.h index 074c348a0..dacf0fb85 100644 --- a/libsolidity/codegen/YulUtilFunctions.h +++ b/libsolidity/codegen/YulUtilFunctions.h @@ -471,6 +471,12 @@ public: /// signature: () -> success, value std::string tryDecodePanicDataFunction(); + /// @returns the name of a function that tries to abi-decode parameters in a catch clause + /// from the return data. + /// Does not check the return data signature. + /// signature: () -> success, value... + std::string tryDecodeReturndata(std::vector const& _types); + /// Returns a function name that returns a newly allocated `bytes` array that contains the return data. /// /// If returndatacopy() is not supported by the underlying target, a empty array will be returned instead. diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index b397af546..7ba81f811 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -2999,49 +2999,57 @@ bool IRGeneratorForStatements::visit(TryStatement const& _tryStatement) void IRGeneratorForStatements::handleCatch(TryStatement const& _tryStatement) { string const runFallback = m_context.newYulVariable(); + string const selector = m_context.newYulVariable(); + string const shortCalldata = m_context.newYulVariable(); m_code << "let " << runFallback << " := 1\n"; + m_code << "let " << selector << ", " << shortCalldata << ", " << " := " << m_utils.returnDataSelectorFunction() << "()\n"; + m_code << "if iszero(" << shortCalldata << ") {\n"; - // This function returns zero on "short returndata". We have to add a success flag - // once we implement custom error codes. - if (_tryStatement.errorClause() || _tryStatement.panicClause()) - m_code << "switch " << m_utils.returnDataSelectorFunction() << "()\n"; + if (_tryStatement.clauses().size() - 1 - (_tryStatement.fallback() ? 1 : 0) > 0) + m_code << "switch " << selector << "\n"; - if (TryCatchClause const* errorClause = _tryStatement.errorClause()) + for (ASTPointer const& clause: _tryStatement.clauses()) { - m_code << "case " << selectorFromSignature32("Error(string)") << " {\n"; - string const dataVariable = m_context.newYulVariable(); - m_code << "let " << dataVariable << " := " << m_utils.tryDecodeErrorMessageFunction() << "()\n"; - m_code << "if " << dataVariable << " {\n"; - m_code << runFallback << " := 0\n"; - if (errorClause->parameters()) + if (clause->kind() == TryCatchClause::Kind::Success || clause->kind() == TryCatchClause::Kind::Fallback) + continue; + + string signature; + vector parameterTypes; + if (clause->kind() == TryCatchClause::Kind::Error) { - solAssert(errorClause->parameters()->parameters().size() == 1, ""); - IRVariable const& var = m_context.addLocalVariable(*errorClause->parameters()->parameters().front()); - define(var) << dataVariable << "\n"; + signature = "Error(string)"; + parameterTypes.push_back(TypeProvider::stringMemory); } + else if (clause->kind() == TryCatchClause::Kind::Panic) + { + signature = "Panic(uint256)"; + parameterTypes.push_back(TypeProvider::uint256(); + } + else + { + ErrorDefinition const& error = dynamic_cast( + *clause->errorName().annotation().referecedDeclaration + ); + signature = error->functionType(true)->externalSignature(); + parameterTypes = error->functionType(true)->parameterTypes(); + } + solAssert(clause->parameters(), ""); + vector variables; + string const success = m_context.newYulVariable(); + variables.push_back(success); + for (ASTPointer const& varDecl: error->parameters()->parameters()) + variables += m_context.addLocalVariable(*varDecl).stackSlots(); + + m_code << "case " << selectorFromSignature32(signature) << " {\n"; + m_code << "let " << joinHumanReadable(variables) << " := " << m_utils.tryDecodeReturndata(parameterTypes) << "()\n"; + m_code << "if " << success << " {\n"; + m_code << runFallback << " := 0\n"; errorClause->accept(*this); m_code << "}\n"; m_code << "}\n"; } - if (TryCatchClause const* panicClause = _tryStatement.panicClause()) - { - m_code << "case " << selectorFromSignature32("Panic(uint256)") << " {\n"; - string const success = m_context.newYulVariable(); - string const code = m_context.newYulVariable(); - m_code << "let " << success << ", " << code << " := " << m_utils.tryDecodePanicDataFunction() << "()\n"; - m_code << "if " << success << " {\n"; - m_code << runFallback << " := 0\n"; - if (panicClause->parameters()) - { - solAssert(panicClause->parameters()->parameters().size() == 1, ""); - IRVariable const& var = m_context.addLocalVariable(*panicClause->parameters()->parameters().front()); - define(var) << code << "\n"; - } - panicClause->accept(*this); - m_code << "}\n"; - m_code << "}\n"; - } + m_code << "}\n"; m_code << "if " << runFallback << " {\n"; if (_tryStatement.fallbackClause()) handleCatchFallback(*_tryStatement.fallbackClause());