[Sol - Yul] Implement checked multiplication.

This commit is contained in:
Daniel Kirchner 2019-05-20 16:42:27 +02:00
parent e08f521b7e
commit b6f4d4e9bc
4 changed files with 69 additions and 0 deletions

View File

@ -346,6 +346,34 @@ string YulUtilFunctions::overflowCheckedUIntAddFunction(size_t _bits)
});
}
string YulUtilFunctions::overflowCheckedUIntMulFunction(size_t _bits)
{
solAssert(0 < _bits && _bits <= 256 && _bits % 8 == 0, "");
string functionName = "checked_mul_uint_" + to_string(_bits);
return m_functionCollector->createFunction(functionName, [&]() {
return
// - The current overflow check *before* the multiplication could
// be replaced by the following check *after* the multiplication:
// if and(iszero(iszero(x)), iszero(eq(div(product, x), y))) { revert(0, 0) }
// - The case the x equals 0 could be treated separately and directly return zero.
Whiskers(R"(
function <functionName>(x, y) -> product {
if and(iszero(iszero(x)), lt(div(<mask>, x), y)) { revert(0, 0) }
<?shortType>
product := mulmod(x, y, <powerOfTwo>)
<!shortType>
product := mul(x, y)
</shortType>
}
)")
("shortType", _bits < 256)
("functionName", functionName)
("powerOfTwo", toCompactHexWithPrefix(u256(1) << _bits))
("mask", toCompactHexWithPrefix((u256(1) << _bits) - 1))
.render();
});
}
string YulUtilFunctions::overflowCheckedUIntSubFunction()
{
string functionName = "checked_sub_uint";

View File

@ -86,6 +86,8 @@ public:
std::string overflowCheckedUIntAddFunction(size_t _bits);
std::string overflowCheckedUIntMulFunction(size_t _bits);
/// @returns computes the difference between two values.
/// Assumes the input to be in range for the type.
std::string overflowCheckedUIntSubFunction();

View File

@ -304,6 +304,8 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
fun = m_utils.overflowCheckedUIntAddFunction(type->numBits());
else if (_binOp.getOperator() == Token::Sub)
fun = m_utils.overflowCheckedUIntSubFunction();
else if (_binOp.getOperator() == Token::Mul)
fun = m_utils.overflowCheckedUIntMulFunction(type->numBits());
else
solUnimplementedAssert(false, "");
defineExpression(_binOp) << fun << "(" << left << ", " << right << ")\n";

View File

@ -0,0 +1,37 @@
contract C {
function f(uint a, uint b) public pure returns (uint x) {
x = a * b;
}
function g(uint8 a, uint8 b) public pure returns (uint8 x) {
x = a * b;
}
}
// ====
// compileViaYul: true
// ----
// f(uint256,uint256): 5, 6 -> 30
// f(uint256,uint256): -1, 1 -> -1
// f(uint256,uint256): -1, 2 -> FAILURE
// f(uint256,uint256): 0x8000000000000000000000000000000000000000000000000000000000000000, 2 -> FAILURE
// f(uint256,uint256): 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 2 -> -2
// f(uint256,uint256): 2, 0x8000000000000000000000000000000000000000000000000000000000000000 -> FAILURE
// f(uint256,uint256): 2, 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> -2
// f(uint256,uint256): 0x0100000000000000000000000000000000, 0x0100000000000000000000000000000000 -> FAILURE
// f(uint256,uint256): 0x0100000000000000000000000000000000, 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000000000000000000000000000
// f(uint256,uint256): 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 0x0100000000000000000000000000000000 -> 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000000000000000000000000000
// f(uint256,uint256): 0x0100000000000000000000000000000001, 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> -1
// f(uint256,uint256): 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 0x0100000000000000000000000000000001 -> -1
// f(uint256,uint256): 0x0100000000000000000000000000000002, 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> FAILURE
// f(uint256,uint256): 0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 0x0100000000000000000000000000000002 -> FAILURE
// f(uint256,uint256): -1, 0 -> 0
// f(uint256,uint256): 0, -1 -> 0
// g(uint8,uint8): 5, 6 -> 30
// g(uint8,uint8): 0x80, 2 -> FAILURE
// g(uint8,uint8): 0x7F, 2 -> 254
// g(uint8,uint8): 2, 0x7F -> 254
// g(uint8,uint8): 0x10, 0x10 -> FAILURE
// g(uint8,uint8): 0x0F, 0x11 -> 0xFF
// g(uint8,uint8): 0x0F, 0x12 -> FAILURE
// g(uint8,uint8): 0x12, 0x0F -> FAILURE
// g(uint8,uint8): 0xFF, 0 -> 0
// g(uint8,uint8): 0, 0xFF -> 0