diff --git a/libsolidity/codegen/YulUtilFunctions.cpp b/libsolidity/codegen/YulUtilFunctions.cpp index aaceffa03..45cbfd6b4 100644 --- a/libsolidity/codegen/YulUtilFunctions.cpp +++ b/libsolidity/codegen/YulUtilFunctions.cpp @@ -346,6 +346,34 @@ string YulUtilFunctions::overflowCheckedUIntAddFunction(size_t _bits) }); } +string YulUtilFunctions::overflowCheckedUIntMulFunction(size_t _bits) +{ + solAssert(0 < _bits && _bits <= 256 && _bits % 8 == 0, ""); + string functionName = "checked_mul_uint_" + to_string(_bits); + return m_functionCollector->createFunction(functionName, [&]() { + return + // - The current overflow check *before* the multiplication could + // be replaced by the following check *after* the multiplication: + // if and(iszero(iszero(x)), iszero(eq(div(product, x), y))) { revert(0, 0) } + // - The case the x equals 0 could be treated separately and directly return zero. + Whiskers(R"( + function (x, y) -> product { + if and(iszero(iszero(x)), lt(div(, x), y)) { revert(0, 0) } + + product := mulmod(x, y, ) + + product := mul(x, y) + + } + )") + ("shortType", _bits < 256) + ("functionName", functionName) + ("powerOfTwo", toCompactHexWithPrefix(u256(1) << _bits)) + ("mask", toCompactHexWithPrefix((u256(1) << _bits) - 1)) + .render(); + }); +} + string YulUtilFunctions::overflowCheckedUIntSubFunction() { string functionName = "checked_sub_uint"; diff --git a/libsolidity/codegen/YulUtilFunctions.h b/libsolidity/codegen/YulUtilFunctions.h index 54b9c2546..25740db99 100644 --- a/libsolidity/codegen/YulUtilFunctions.h +++ b/libsolidity/codegen/YulUtilFunctions.h @@ -86,6 +86,8 @@ public: std::string overflowCheckedUIntAddFunction(size_t _bits); + std::string overflowCheckedUIntMulFunction(size_t _bits); + /// @returns computes the difference between two values. /// Assumes the input to be in range for the type. std::string overflowCheckedUIntSubFunction(); diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index c2d727a03..76ca6b4aa 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -304,6 +304,8 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp) fun = m_utils.overflowCheckedUIntAddFunction(type->numBits()); else if (_binOp.getOperator() == Token::Sub) fun = m_utils.overflowCheckedUIntSubFunction(); + else if (_binOp.getOperator() == Token::Mul) + fun = m_utils.overflowCheckedUIntMulFunction(type->numBits()); else solUnimplementedAssert(false, ""); defineExpression(_binOp) << fun << "(" << left << ", " << right << ")\n"; diff --git a/test/libsolidity/semanticTests/viaYul/detect_mul_overflow.yul b/test/libsolidity/semanticTests/viaYul/detect_mul_overflow.yul new file mode 100644 index 000000000..c49758686 --- /dev/null +++ b/test/libsolidity/semanticTests/viaYul/detect_mul_overflow.yul @@ -0,0 +1,37 @@ +contract C { + function f(uint a, uint b) public pure returns (uint x) { + x = a * b; + } + function g(uint8 a, uint8 b) public pure returns (uint8 x) { + x = a * b; + } +} +// ==== +// compileViaYul: true +// ---- +// f(uint256,uint256): 5, 6 -> 30 +// f(uint256,uint256): -1, 1 -> -1 +// f(uint256,uint256): -1, 2 -> FAILURE +// f(uint256,uint256): 0x8000000000000000000000000000000000000000000000000000000000000000, 2 -> FAILURE +// f(uint256,uint256): 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 2 -> -2 +// f(uint256,uint256): 2, 0x8000000000000000000000000000000000000000000000000000000000000000 -> FAILURE +// f(uint256,uint256): 2, 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> -2 +// f(uint256,uint256): 0x0100000000000000000000000000000000, 0x0100000000000000000000000000000000 -> FAILURE +// f(uint256,uint256): 0x0100000000000000000000000000000000, 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000000000000000000000000000 +// f(uint256,uint256): 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 0x0100000000000000000000000000000000 -> 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000000000000000000000000000 +// f(uint256,uint256): 0x0100000000000000000000000000000001, 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> -1 +// f(uint256,uint256): 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 0x0100000000000000000000000000000001 -> -1 +// f(uint256,uint256): 0x0100000000000000000000000000000002, 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> FAILURE +// f(uint256,uint256): 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 0x0100000000000000000000000000000002 -> FAILURE +// f(uint256,uint256): -1, 0 -> 0 +// f(uint256,uint256): 0, -1 -> 0 +// g(uint8,uint8): 5, 6 -> 30 +// g(uint8,uint8): 0x80, 2 -> FAILURE +// g(uint8,uint8): 0x7F, 2 -> 254 +// g(uint8,uint8): 2, 0x7F -> 254 +// g(uint8,uint8): 0x10, 0x10 -> FAILURE +// g(uint8,uint8): 0x0F, 0x11 -> 0xFF +// g(uint8,uint8): 0x0F, 0x12 -> FAILURE +// g(uint8,uint8): 0x12, 0x0F -> FAILURE +// g(uint8,uint8): 0xFF, 0 -> 0 +// g(uint8,uint8): 0, 0xFF -> 0