diff --git a/libsolidity/codegen/YulUtilFunctions.cpp b/libsolidity/codegen/YulUtilFunctions.cpp index 737edd696..9c93441ef 100644 --- a/libsolidity/codegen/YulUtilFunctions.cpp +++ b/libsolidity/codegen/YulUtilFunctions.cpp @@ -409,31 +409,35 @@ string YulUtilFunctions::overflowCheckedIntAddFunction(IntegerType const& _type) }); } -string YulUtilFunctions::overflowCheckedUIntMulFunction(size_t _bits) +string YulUtilFunctions::overflowCheckedIntMulFunction(IntegerType const& _type) { - solAssert(0 < _bits && _bits <= 256 && _bits % 8 == 0, ""); - string functionName = "checked_mul_uint_" + to_string(_bits); + string functionName = "checked_mul_" + _type.identifier(); return m_functionCollector->createFunction(functionName, [&]() { return - // - The current overflow check *before* the multiplication could - // be replaced by the following check *after* the multiplication: - // if and(iszero(iszero(x)), iszero(eq(div(product, x), y))) { revert(0, 0) } - // - The case the x equals 0 could be treated separately and directly return zero. + // Multiplication by zero could be treated separately and directly return zero. Whiskers(R"( function (x, y) -> product { - if and(iszero(iszero(x)), lt(div(, x), y)) { revert(0, 0) } - - product := mulmod(x, y, ) - - product := mul(x, y) - + + // overflow, if x > 0, y > 0 and x > (maxValue / y) + if and(and(sgt(x, 0), sgt(y, 0)), gt(x, div(, y))) { revert(0, 0) } + // underflow, if x > 0, y < 0 and y < (minValue / x) + if and(and(sgt(x, 0), slt(y, 0)), slt(y, sdiv(, x))) { revert(0, 0) } + // underflow, if x < 0, y > 0 and x < (minValue / y) + if and(and(slt(x, 0), sgt(y, 0)), slt(x, sdiv(, y))) { revert(0, 0) } + // overflow, if x < 0, y < 0 and x < (maxValue / y) + if and(and(slt(x, 0), slt(y, 0)), slt(x, sdiv(, y))) { revert(0, 0) } + + // overflow, if x != 0 and y > (maxValue / x) + if and(iszero(iszero(x)), gt(y, div(, x))) { revert(0, 0) } + + product := mul(x, y) } )") - ("shortType", _bits < 256) - ("functionName", functionName) - ("powerOfTwo", toCompactHexWithPrefix(u256(1) << _bits)) - ("mask", toCompactHexWithPrefix((u256(1) << _bits) - 1)) - .render(); + ("functionName", functionName) + ("signed", _type.isSigned()) + ("maxValue", toCompactHexWithPrefix(u256(_type.maxValue()))) + ("minValue", toCompactHexWithPrefix(u256(_type.minValue()))) + .render(); }); } @@ -620,7 +624,7 @@ string YulUtilFunctions::arrayConvertLengthToSize(ArrayType const& _type) ("multiSlot", baseType.storageSize() > 1) ("itemsPerSlot", to_string(32 / baseStorageBytes)) ("storageSize", baseType.storageSize().str()) - ("mul", overflowCheckedUIntMulFunction(TypeProvider::uint256()->numBits())) + ("mul", overflowCheckedIntMulFunction(*TypeProvider::uint256())) .render(); } case DataLocation::CallData: // fallthrough @@ -636,7 +640,7 @@ string YulUtilFunctions::arrayConvertLengthToSize(ArrayType const& _type) ("functionName", functionName) ("elementSize", _type.location() == DataLocation::Memory ? baseType.memoryHeadSize() : baseType.calldataEncodedSize()) ("byteArray", _type.isByteArray()) - ("mul", overflowCheckedUIntMulFunction(TypeProvider::uint256()->numBits())) + ("mul", overflowCheckedIntMulFunction(*TypeProvider::uint256())) .render(); default: solAssert(false, ""); diff --git a/libsolidity/codegen/YulUtilFunctions.h b/libsolidity/codegen/YulUtilFunctions.h index 1fd4e02e1..ccfa29cc7 100644 --- a/libsolidity/codegen/YulUtilFunctions.h +++ b/libsolidity/codegen/YulUtilFunctions.h @@ -96,7 +96,7 @@ public: std::string overflowCheckedIntAddFunction(IntegerType const& _type); /// signature: (x, y) -> product - std::string overflowCheckedUIntMulFunction(size_t _bits); + std::string overflowCheckedIntMulFunction(IntegerType const& _type); /// @returns name of function to perform division on integers. /// Checks for division by zero and the special case of diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index 6919004b6..de1fe3ac4 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -1136,8 +1136,7 @@ string IRGeneratorForStatements::binaryOperation( fun = m_utils.overflowCheckedIntSubFunction(*type); break; case Token::Mul: - if (!type->isSigned()) - fun = m_utils.overflowCheckedUIntMulFunction(type->numBits()); + fun = m_utils.overflowCheckedIntMulFunction(*type); break; case Token::Div: fun = m_utils.overflowCheckedIntDivFunction(*type); diff --git a/test/libsolidity/semanticTests/viaYul/detect_mul_overflow.yul b/test/libsolidity/semanticTests/viaYul/detect_mul_overflow.sol similarity index 100% rename from test/libsolidity/semanticTests/viaYul/detect_mul_overflow.yul rename to test/libsolidity/semanticTests/viaYul/detect_mul_overflow.sol diff --git a/test/libsolidity/semanticTests/viaYul/detect_mul_overflow_signed.sol b/test/libsolidity/semanticTests/viaYul/detect_mul_overflow_signed.sol new file mode 100644 index 000000000..930a24ed7 --- /dev/null +++ b/test/libsolidity/semanticTests/viaYul/detect_mul_overflow_signed.sol @@ -0,0 +1,58 @@ +contract C { + function f(int a, int b) public pure returns (int x) { + x = a * b; + } + function g(int8 a, int8 b) public pure returns (int8 x) { + x = a * b; + } +} +// ==== +// compileViaYul: true +// ---- +// f(int256,int256): 5, 6 -> 30 +// f(int256,int256): -1, 1 -> -1 +// f(int256,int256): -1, 2 -> -2 +// # positive, positive # +// f(int256,int256): 0x3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 2 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE +// f(int256,int256): 0x4000000000000000000000000000000000000000000000000000000000000000, 2 -> FAILURE +// f(int256,int256): 2, 0x3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE +// f(int256,int256): 2, 0x4000000000000000000000000000000000000000000000000000000000000000 -> FAILURE +// # positive, negative # +// f(int256,int256): 0x4000000000000000000000000000000000000000000000000000000000000000, -2 -> 0x8000000000000000000000000000000000000000000000000000000000000000 +// f(int256,int256): 0x4000000000000000000000000000000000000000000000000000000000000001, -2 -> FAILURE +// f(int256,int256): 2, 0xC000000000000000000000000000000000000000000000000000000000000000 -> 0x8000000000000000000000000000000000000000000000000000000000000000 +// f(int256,int256): 2, 0xBFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> FAILURE +// # negative, positive # +// f(int256,int256): -2, 0x4000000000000000000000000000000000000000000000000000000000000000 -> 0x8000000000000000000000000000000000000000000000000000000000000000 +// f(int256,int256): -2, 0x4000000000000000000000000000000000000000000000000000000000000001 -> FAILURE +// f(int256,int256): 0xC000000000000000000000000000000000000000000000000000000000000000, 2 -> 0x8000000000000000000000000000000000000000000000000000000000000000 +// f(int256,int256): 0xBFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 2 -> FAILURE +// # negative, negative # +// f(int256,int256): 0xC000000000000000000000000000000000000000000000000000000000000001, -2 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE +// f(int256,int256): 0xC000000000000000000000000000000000000000000000000000000000000000, -2 -> FAILURE +// f(int256,int256): -2, 0xC000000000000000000000000000000000000000000000000000000000000001 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE +// f(int256,int256): -2, 0xC000000000000000000000000000000000000000000000000000000000000000 -> FAILURE +// # small type # +// g(int8,int8): 5, 6 -> 30 +// g(int8,int8): -1, 1 -> -1 +// g(int8,int8): -1, 2 -> -2 +// # positive, positive # +// g(int8,int8): 63, 2 -> 126 +// g(int8,int8): 64, 2 -> FAILURE +// g(int8,int8): 2, 63 -> 126 +// g(int8,int8): 2, 64 -> FAILURE +// # positive, negative # +// g(int8,int8): 64, -2 -> -128 +// g(int8,int8): 65, -2 -> FAILURE +// g(int8,int8): 2, -64 -> -128 +// g(int8,int8): 2, -65 -> FAILURE +// # negative, positive # +// g(int8,int8): -2, 64 -> -128 +// g(int8,int8): -2, 65 -> FAILURE +// g(int8,int8): -64, 2 -> -128 +// g(int8,int8): -65, 2 -> FAILURE +// # negative, negative # +// g(int8,int8): -63, -2 -> 126 +// g(int8,int8): -64, -2 -> FAILURE +// g(int8,int8): -2, -63 -> 126 +// g(int8,int8): -2, -64 -> FAILURE