Merge pull request #6929 from ethereum/solYulBinaryOps

[Sol -> Yul] Checked signed arithmetic and modulo.
This commit is contained in:
chriseth 2019-06-20 13:02:30 +02:00 committed by GitHub
commit 848959fff0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 318 additions and 56 deletions

View File

@ -379,63 +379,70 @@ string YulUtilFunctions::roundUpFunction()
});
}
string YulUtilFunctions::overflowCheckedUIntAddFunction(size_t _bits)
string YulUtilFunctions::overflowCheckedIntAddFunction(IntegerType const& _type)
{
solAssert(0 < _bits && _bits <= 256 && _bits % 8 == 0, "");
string functionName = "checked_add_uint_" + to_string(_bits);
string functionName = "checked_add_" + _type.identifier();
// TODO: Consider to add a special case for unsigned 256-bit integers
// and use the following instead:
// sum := add(x, y) if lt(sum, x) { revert(0, 0) }
return m_functionCollector->createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(x, y) -> sum {
<?shortType>
let mask := <mask>
sum := add(and(x, mask), and(y, mask))
if and(sum, not(mask)) { revert(0, 0) }
<!shortType>
sum := add(x, y)
if lt(sum, x) { revert(0, 0) }
</shortType>
<?signed>
// overflow, if x >= 0 and y > (maxValue - x)
if and(iszero(slt(x, 0)), sgt(y, sub(<maxValue>, x))) { revert(0, 0) }
// underflow, if x < 0 and y < (minValue - x)
if and(slt(x, 0), slt(y, sub(<minValue>, x))) { revert(0, 0) }
<!signed>
// overflow, if x > (maxValue - y)
if gt(x, sub(<maxValue>, y)) { revert(0, 0) }
</signed>
sum := add(x, y)
}
)")
("shortType", _bits < 256)
("functionName", functionName)
("mask", toCompactHexWithPrefix((u256(1) << _bits) - 1))
("signed", _type.isSigned())
("maxValue", toCompactHexWithPrefix(u256(_type.maxValue())))
("minValue", toCompactHexWithPrefix(u256(_type.minValue())))
.render();
});
}
string YulUtilFunctions::overflowCheckedUIntMulFunction(size_t _bits)
string YulUtilFunctions::overflowCheckedIntMulFunction(IntegerType const& _type)
{
solAssert(0 < _bits && _bits <= 256 && _bits % 8 == 0, "");
string functionName = "checked_mul_uint_" + to_string(_bits);
string functionName = "checked_mul_" + _type.identifier();
return m_functionCollector->createFunction(functionName, [&]() {
return
// - The current overflow check *before* the multiplication could
// be replaced by the following check *after* the multiplication:
// if and(iszero(iszero(x)), iszero(eq(div(product, x), y))) { revert(0, 0) }
// - The case the x equals 0 could be treated separately and directly return zero.
// Multiplication by zero could be treated separately and directly return zero.
Whiskers(R"(
function <functionName>(x, y) -> product {
if and(iszero(iszero(x)), lt(div(<mask>, x), y)) { revert(0, 0) }
<?shortType>
product := mulmod(x, y, <powerOfTwo>)
<!shortType>
product := mul(x, y)
</shortType>
<?signed>
// overflow, if x > 0, y > 0 and x > (maxValue / y)
if and(and(sgt(x, 0), sgt(y, 0)), gt(x, div(<maxValue>, y))) { revert(0, 0) }
// underflow, if x > 0, y < 0 and y < (minValue / x)
if and(and(sgt(x, 0), slt(y, 0)), slt(y, sdiv(<minValue>, x))) { revert(0, 0) }
// underflow, if x < 0, y > 0 and x < (minValue / y)
if and(and(slt(x, 0), sgt(y, 0)), slt(x, sdiv(<minValue>, y))) { revert(0, 0) }
// overflow, if x < 0, y < 0 and x < (maxValue / y)
if and(and(slt(x, 0), slt(y, 0)), slt(x, sdiv(<maxValue>, y))) { revert(0, 0) }
<!signed>
// overflow, if x != 0 and y > (maxValue / x)
if and(iszero(iszero(x)), gt(y, div(<maxValue>, x))) { revert(0, 0) }
</signed>
product := mul(x, y)
}
)")
("shortType", _bits < 256)
("functionName", functionName)
("powerOfTwo", toCompactHexWithPrefix(u256(1) << _bits))
("mask", toCompactHexWithPrefix((u256(1) << _bits) - 1))
.render();
("functionName", functionName)
("signed", _type.isSigned())
("maxValue", toCompactHexWithPrefix(u256(_type.maxValue())))
("minValue", toCompactHexWithPrefix(u256(_type.minValue())))
.render();
});
}
string YulUtilFunctions::overflowCheckedIntDivFunction(IntegerType const& _type)
{
unsigned bits = _type.numBits();
solAssert(0 < bits && bits <= 256 && bits % 8 == 0, "");
string functionName = "checked_div_" + _type.identifier();
return m_functionCollector->createFunction(functionName, [&]() {
return
@ -443,7 +450,7 @@ string YulUtilFunctions::overflowCheckedIntDivFunction(IntegerType const& _type)
function <functionName>(x, y) -> r {
if iszero(y) { revert(0, 0) }
<?signed>
// x / -1 == x
// overflow for minVal / -1
if and(
eq(x, <minVal>),
eq(y, sub(0, 1))
@ -452,25 +459,52 @@ string YulUtilFunctions::overflowCheckedIntDivFunction(IntegerType const& _type)
r := <?signed>s</signed>div(x, y)
}
)")
("functionName", functionName)
("signed", _type.isSigned())
("minVal", (0 - (u256(1) << (bits - 1))).str())
.render();
("functionName", functionName)
("signed", _type.isSigned())
("minVal", toCompactHexWithPrefix(u256(_type.minValue())))
.render();
});
}
string YulUtilFunctions::overflowCheckedUIntSubFunction()
string YulUtilFunctions::checkedIntModFunction(IntegerType const& _type)
{
string functionName = "checked_sub_uint";
string functionName = "checked_mod_" + _type.identifier();
return m_functionCollector->createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(x, y) -> r {
if iszero(y) { revert(0, 0) }
r := <?signed>s</signed>mod(x, y)
}
)")
("functionName", functionName)
("signed", _type.isSigned())
.render();
});
}
string YulUtilFunctions::overflowCheckedIntSubFunction(IntegerType const& _type)
{
string functionName = "checked_sub_" + _type.identifier();
return m_functionCollector->createFunction(functionName, [&] {
return
Whiskers(R"(
function <functionName>(x, y) -> diff {
if lt(x, y) { revert(0, 0) }
<?signed>
// underflow, if y >= 0 and x < (minValue + y)
if and(iszero(slt(y, 0)), slt(x, add(<minValue>, y))) { revert(0, 0) }
// overflow, if y < 0 and x > (maxValue + y)
if and(slt(y, 0), sgt(x, add(<maxValue>, y))) { revert(0, 0) }
<!signed>
if lt(x, y) { revert(0, 0) }
</signed>
diff := sub(x, y)
}
)")
("functionName", functionName)
("signed", _type.isSigned())
("maxValue", toCompactHexWithPrefix(u256(_type.maxValue())))
("minValue", toCompactHexWithPrefix(u256(_type.minValue())))
.render();
});
}
@ -607,7 +641,7 @@ string YulUtilFunctions::arrayConvertLengthToSize(ArrayType const& _type)
("multiSlot", baseType.storageSize() > 1)
("itemsPerSlot", to_string(32 / baseStorageBytes))
("storageSize", baseType.storageSize().str())
("mul", overflowCheckedUIntMulFunction(TypeProvider::uint256()->numBits()))
("mul", overflowCheckedIntMulFunction(*TypeProvider::uint256()))
.render();
}
case DataLocation::CallData: // fallthrough
@ -623,7 +657,7 @@ string YulUtilFunctions::arrayConvertLengthToSize(ArrayType const& _type)
("functionName", functionName)
("elementSize", _type.location() == DataLocation::Memory ? baseType.memoryHeadSize() : baseType.calldataEncodedSize())
("byteArray", _type.isByteArray())
("mul", overflowCheckedUIntMulFunction(TypeProvider::uint256()->numBits()))
("mul", overflowCheckedIntMulFunction(*TypeProvider::uint256()))
.render();
default:
solAssert(false, "");

View File

@ -93,20 +93,24 @@ public:
std::string roundUpFunction();
/// signature: (x, y) -> sum
std::string overflowCheckedUIntAddFunction(size_t _bits);
std::string overflowCheckedIntAddFunction(IntegerType const& _type);
/// signature: (x, y) -> product
std::string overflowCheckedUIntMulFunction(size_t _bits);
std::string overflowCheckedIntMulFunction(IntegerType const& _type);
/// @returns name of function to perform division on integers.
/// Checks for division by zero and the special case of
/// signed division of the smallest number by -1.
std::string overflowCheckedIntDivFunction(IntegerType const& _type);
/// @returns name of function to perform modulo on integers.
/// Reverts for modulo by zero.
std::string checkedIntModFunction(IntegerType const& _type);
/// @returns computes the difference between two values.
/// Assumes the input to be in range for the type.
/// signature: (x, y) -> diff
std::string overflowCheckedUIntSubFunction();
std::string overflowCheckedIntSubFunction(IntegerType const& _type);
/// @returns the name of a function that fetches the length of the given
/// array

View File

@ -1126,18 +1126,28 @@ string IRGeneratorForStatements::binaryOperation(
if (IntegerType const* type = dynamic_cast<IntegerType const*>(&_type))
{
string fun;
// TODO: Only division is implemented for signed integers for now.
if (!type->isSigned())
// TODO: Implement all operations for signed and unsigned types.
switch (_operator)
{
if (_operator == Token::Add)
fun = m_utils.overflowCheckedUIntAddFunction(type->numBits());
else if (_operator == Token::Sub)
fun = m_utils.overflowCheckedUIntSubFunction();
else if (_operator == Token::Mul)
fun = m_utils.overflowCheckedUIntMulFunction(type->numBits());
case Token::Add:
fun = m_utils.overflowCheckedIntAddFunction(*type);
break;
case Token::Sub:
fun = m_utils.overflowCheckedIntSubFunction(*type);
break;
case Token::Mul:
fun = m_utils.overflowCheckedIntMulFunction(*type);
break;
case Token::Div:
fun = m_utils.overflowCheckedIntDivFunction(*type);
break;
case Token::Mod:
fun = m_utils.checkedIntModFunction(*type);
break;
default:
break;
}
if (_operator == Token::Div)
fun = m_utils.overflowCheckedIntDivFunction(*type);
solUnimplementedAssert(!fun.empty(), "");
return fun + "(" + _left + ", " + _right + ")\n";
}

View File

@ -0,0 +1,37 @@
contract C {
function f(int a, int b) public pure returns (int x) {
x = a + b;
}
function g(int8 a, int8 b) public pure returns (int8 x) {
x = a + b;
}
}
// ====
// compileViaYul: true
// ----
// f(int256,int256): 5, 6 -> 11
// f(int256,int256): -2, 1 -> -1
// f(int256,int256): -2, 2 -> 0
// f(int256,int256): 2, -2 -> 0
// f(int256,int256): -5, -6 -> -11
// f(int256,int256): 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0, 0x0F -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
// f(int256,int256): 0x0F, 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
// f(int256,int256): 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 1 -> FAILURE
// f(int256,int256): 1, 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> FAILURE
// f(int256,int256): 0x8000000000000000000000000000000000000000000000000000000000000001, -1 -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): -1, 0x8000000000000000000000000000000000000000000000000000000000000001 -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): 0x8000000000000000000000000000000000000000000000000000000000000000, -1 -> FAILURE
// f(int256,int256): -1, 0x8000000000000000000000000000000000000000000000000000000000000000 -> FAILURE
// g(int8,int8): 5, 6 -> 11
// g(int8,int8): -2, 1 -> -1
// g(int8,int8): -2, 2 -> 0
// g(int8,int8): 2, -2 -> 0
// g(int8,int8): -5, -6 -> -11
// g(int8,int8): 126, 1 -> 127
// g(int8,int8): 1, 126 -> 127
// g(int8,int8): 127, 1 -> FAILURE
// g(int8,int8): 1, 127 -> FAILURE
// g(int8,int8): -127, -1 -> -128
// g(int8,int8): -1, -127 -> -128
// g(int8,int8): -127, -2 -> FAILURE
// g(int8,int8): -2, -127 -> FAILURE

View File

@ -0,0 +1,25 @@
contract C {
function f(uint a, uint b) public pure returns (uint x) {
x = a % b;
}
function g(uint8 a, uint8 b) public pure returns (uint8 x) {
x = a % b;
}
}
// ====
// compileViaYul: only
// ----
// f(uint256,uint256): 10, 3 -> 1
// f(uint256,uint256): 10, 2 -> 0
// f(uint256,uint256): 11, 2 -> 1
// f(uint256,uint256): 2, 2 -> 0
// f(uint256,uint256): 1, 0 -> FAILURE
// f(uint256,uint256): 0, 0 -> FAILURE
// f(uint256,uint256): 0, 1 -> 0
// g(uint8,uint8): 10, 3 -> 1
// g(uint8,uint8): 10, 2 -> 0
// g(uint8,uint8): 11, 2 -> 1
// g(uint8,uint8): 2, 2 -> 0
// g(uint8,uint8): 1, 0 -> FAILURE
// g(uint8,uint8): 0, 0 -> FAILURE
// g(uint8,uint8): 0, 1 -> 0

View File

@ -0,0 +1,35 @@
contract C {
function f(int a, int b) public pure returns (int x) {
x = a % b;
}
function g(int8 a, int8 b) public pure returns (int8 x) {
x = a % b;
}
}
// ====
// compileViaYul: only
// ----
// f(int256,int256): 10, 3 -> 1
// f(int256,int256): 10, 2 -> 0
// f(int256,int256): 11, 2 -> 1
// f(int256,int256): -10, 3 -> -1
// f(int256,int256): 10, -3 -> 1
// f(int256,int256): -10, -3 -> -1
// f(int256,int256): 2, 2 -> 0
// f(int256,int256): 1, 0 -> FAILURE
// f(int256,int256): -1, 0 -> FAILURE
// f(int256,int256): 0, 0 -> FAILURE
// f(int256,int256): 0, 1 -> 0
// f(int256,int256): 0, -1 -> 0
// g(int8,int8): 10, 3 -> 1
// g(int8,int8): 10, 2 -> 0
// g(int8,int8): 11, 2 -> 1
// g(int8,int8): -10, 3 -> -1
// g(int8,int8): 10, -3 -> 1
// g(int8,int8): -10, -3 -> -1
// g(int8,int8): 2, 2 -> 0
// g(int8,int8): 1, 0 -> FAILURE
// g(int8,int8): -1, 0 -> FAILURE
// g(int8,int8): 0, 0 -> FAILURE
// g(int8,int8): 0, 1 -> 0
// g(int8,int8): 0, -1 -> 0

View File

@ -0,0 +1,58 @@
contract C {
function f(int a, int b) public pure returns (int x) {
x = a * b;
}
function g(int8 a, int8 b) public pure returns (int8 x) {
x = a * b;
}
}
// ====
// compileViaYul: true
// ----
// f(int256,int256): 5, 6 -> 30
// f(int256,int256): -1, 1 -> -1
// f(int256,int256): -1, 2 -> -2
// # positive, positive #
// f(int256,int256): 0x3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 2 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE
// f(int256,int256): 0x4000000000000000000000000000000000000000000000000000000000000000, 2 -> FAILURE
// f(int256,int256): 2, 0x3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE
// f(int256,int256): 2, 0x4000000000000000000000000000000000000000000000000000000000000000 -> FAILURE
// # positive, negative #
// f(int256,int256): 0x4000000000000000000000000000000000000000000000000000000000000000, -2 -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): 0x4000000000000000000000000000000000000000000000000000000000000001, -2 -> FAILURE
// f(int256,int256): 2, 0xC000000000000000000000000000000000000000000000000000000000000000 -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): 2, 0xBFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> FAILURE
// # negative, positive #
// f(int256,int256): -2, 0x4000000000000000000000000000000000000000000000000000000000000000 -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): -2, 0x4000000000000000000000000000000000000000000000000000000000000001 -> FAILURE
// f(int256,int256): 0xC000000000000000000000000000000000000000000000000000000000000000, 2 -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): 0xBFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, 2 -> FAILURE
// # negative, negative #
// f(int256,int256): 0xC000000000000000000000000000000000000000000000000000000000000001, -2 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE
// f(int256,int256): 0xC000000000000000000000000000000000000000000000000000000000000000, -2 -> FAILURE
// f(int256,int256): -2, 0xC000000000000000000000000000000000000000000000000000000000000001 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE
// f(int256,int256): -2, 0xC000000000000000000000000000000000000000000000000000000000000000 -> FAILURE
// # small type #
// g(int8,int8): 5, 6 -> 30
// g(int8,int8): -1, 1 -> -1
// g(int8,int8): -1, 2 -> -2
// # positive, positive #
// g(int8,int8): 63, 2 -> 126
// g(int8,int8): 64, 2 -> FAILURE
// g(int8,int8): 2, 63 -> 126
// g(int8,int8): 2, 64 -> FAILURE
// # positive, negative #
// g(int8,int8): 64, -2 -> -128
// g(int8,int8): 65, -2 -> FAILURE
// g(int8,int8): 2, -64 -> -128
// g(int8,int8): 2, -65 -> FAILURE
// # negative, positive #
// g(int8,int8): -2, 64 -> -128
// g(int8,int8): -2, 65 -> FAILURE
// g(int8,int8): -64, 2 -> -128
// g(int8,int8): -65, 2 -> FAILURE
// # negative, negative #
// g(int8,int8): -63, -2 -> 126
// g(int8,int8): -64, -2 -> FAILURE
// g(int8,int8): -2, -63 -> 126
// g(int8,int8): -2, -64 -> FAILURE

View File

@ -0,0 +1,17 @@
contract C {
function f(uint a, uint b) public pure returns (uint x) {
x = a - b;
}
function g(uint8 a, uint8 b) public pure returns (uint8 x) {
x = a - b;
}
}
// ====
// compileViaYul: true
// ----
// f(uint256,uint256): 6, 5 -> 1
// f(uint256,uint256): 6, 6 -> 0
// f(uint256,uint256): 5, 6 -> FAILURE
// g(uint8,uint8): 6, 5 -> 1
// g(uint8,uint8): 6, 6 -> 0
// g(uint8,uint8): 5, 6 -> FAILURE

View File

@ -0,0 +1,42 @@
contract C {
function f(int a, int b) public pure returns (int x) {
x = a - b;
}
function g(int8 a, int8 b) public pure returns (int8 x) {
x = a - b;
}
}
// ====
// compileViaYul: true
// ----
// f(int256,int256): 5, 6 -> -1
// f(int256,int256): -2, 1 -> -3
// f(int256,int256): -2, 2 -> -4
// f(int256,int256): 2, -2 -> 4
// f(int256,int256): 2, 2 -> 0
// f(int256,int256): -5, -6 -> 1
// f(int256,int256): 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0, -15 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
// f(int256,int256): 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0, -16 -> FAILURE
// f(int256,int256): 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, -1 -> FAILURE
// f(int256,int256): 15, 0x8000000000000000000000000000000000000000000000000000000000000010 -> 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
// f(int256,int256): 16, 0x8000000000000000000000000000000000000000000000000000000000000010 -> FAILURE
// f(int256,int256): 1, 0x8000000000000000000000000000000000000000000000000000000000000000 -> FAILURE
// f(int256,int256): -1, 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): -2, 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF -> FAILURE
// f(int256,int256): 0x8000000000000000000000000000000000000000000000000000000000000001, 1 -> 0x8000000000000000000000000000000000000000000000000000000000000000
// f(int256,int256): 0x8000000000000000000000000000000000000000000000000000000000000001, 2 -> FAILURE
// f(int256,int256): 0x8000000000000000000000000000000000000000000000000000000000000000, 1 -> FAILURE
// g(int8,int8): 5, 6 -> -1
// g(int8,int8): -2, 1 -> -3
// g(int8,int8): -2, 2 -> -4
// g(int8,int8): 2, -2 -> 4
// g(int8,int8): 2, 2 -> 0
// g(int8,int8): -5, -6 -> 1
// g(int8,int8): 126, -1 -> 127
// g(int8,int8): 1, -126 -> 127
// g(int8,int8): 127, -1 -> FAILURE
// g(int8,int8): 1, -127 -> FAILURE
// g(int8,int8): -127, 1 -> -128
// g(int8,int8): -1, 127 -> -128
// g(int8,int8): -127, 2 -> FAILURE
// g(int8,int8): -2, 127 -> FAILURE