Implement checked exponentiation.

This commit is contained in:
chriseth 2020-07-22 16:52:28 +02:00
parent 660ef792ab
commit c34e349572
9 changed files with 325 additions and 2 deletions

View File

@ -597,6 +597,139 @@ string YulUtilFunctions::overflowCheckedIntSubFunction(IntegerType const& _type)
});
}
string YulUtilFunctions::overflowCheckedIntExpFunction(
IntegerType const& _type,
IntegerType const& _exponentType
)
{
solAssert(!_exponentType.isSigned(), "");
string functionName = "checked_exp_" + _type.identifier() + "_" + _exponentType.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(base, exponent) -> power {
base := <baseCleanupFunction>(base)
exponent := <exponentCleanupFunction>(exponent)
<?signed>
power := <exp>(base, exponent, <minValue>, <maxValue>)
<!signed>
power := <exp>(base, exponent, <maxValue>)
</signed>
}
)")
("functionName", functionName)
("signed", _type.isSigned())
("exp", _type.isSigned() ? overflowCheckedSignedExpFunction() : overflowCheckedUnsignedExpFunction())
("maxValue", toCompactHexWithPrefix(_type.max()))
("minValue", toCompactHexWithPrefix(_type.min()))
("baseCleanupFunction", cleanupFunction(_type))
("exponentCleanupFunction", cleanupFunction(_exponentType))
.render();
});
}
string YulUtilFunctions::overflowCheckedUnsignedExpFunction()
{
string functionName = "checked_exp_unsigned";
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(base, exponent, max) -> power {
// This function currently cannot be inlined because of the
// "leave" statements. We have to improve the optimizer.
// Note that 0**0 == 1
if iszero(exponent) { power := 1 leave }
if iszero(base) { power := 0 leave }
power := 1
for { } gt(exponent, 1) {}
{
// overflow check for base * base
if gt(base, div(max, base)) { revert(0, 0) }
if and(exponent, 1)
{
// no check needed here because base >= power
power := mul(power, base)
}
base := mul(base, base)
exponent := <shr_1>(exponent)
}
if gt(power, div(max, base)) { revert(0, 0) }
power := mul(power, base)
}
)")
("functionName", functionName)
("shr_1", shiftRightFunction(1))
.render();
});
}
string YulUtilFunctions::overflowCheckedSignedExpFunction()
{
string functionName = "checked_exp_signed";
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(base, exponent, min, max) -> power {
// Currently, `leave` avoids this function being inlined.
// We have to improve the optimizer.
// Note that 0**0 == 1
switch exponent
case 0 { power := 1 leave }
case 1 { power := base leave }
if iszero(base) { power := 0 leave }
power := 1
// We pull out the first iteration because it is the only one in which
// base can be negative.
// Exponent is at least 2 here.
// overflow check for base * base
switch sgt(base, 0)
case 1 { if gt(base, div(max, base)) { revert(0, 0) } }
case 0 { if slt(base, sdiv(max, base)) { revert(0, 0) } }
if and(exponent, 1)
{
power := base
}
base := mul(base, base)
exponent := <shr_1>(exponent)
// Below this point, base is always positive.
for { } gt(exponent, 1) {}
{
// overflow check for base * base
if gt(base, div(max, base)) { revert(0, 0) }
if and(exponent, 1)
{
// No checks for power := mul(power, base) needed, because the check
// for base * base above is sufficient, since:
// |power| <= base (proof by induction) and thus:
// |power * base| <= base * base <= max <= |min|
power := mul(power, base)
}
base := mul(base, base)
exponent := <shr_1>(exponent)
}
if and(sgt(power, 0), gt(power, div(max, base))) { revert(0, 0) }
if and(slt(power, 0), slt(power, sdiv(min, base))) { revert(0, 0) }
power := mul(power, base)
}
)")
("functionName", functionName)
("shr_1", shiftRightFunction(1))
.render();
});
}
string YulUtilFunctions::extractByteArrayLengthFunction()
{
string functionName = "extract_byte_array_length";

View File

@ -125,6 +125,21 @@ public:
/// signature: (x, y) -> diff
std::string overflowCheckedIntSubFunction(IntegerType const& _type);
/// @returns the name of the exponentiation function.
/// signature: (base, exponent) -> power
std::string overflowCheckedIntExpFunction(IntegerType const& _type, IntegerType const& _exponentType);
/// Generic unsigned checked exponentiation function.
/// Reverts if the result is larger than max.
/// signature: (base, exponent, max) -> power
std::string overflowCheckedUnsignedExpFunction();
/// Generic signed checked exponentiation function.
/// Reverts if the result is smaller than min or larger than max.
/// The code relies on max <= |min| and min < 0.
/// signature: (base, exponent, min, max) -> power
std::string overflowCheckedSignedExpFunction();
/// @returns the name of a function that fetches the length of the given
/// array
/// signature: (array) -> length

View File

@ -277,6 +277,7 @@ bool IRGeneratorForStatements::visit(Assignment const& _assignment)
solAssert(type(_assignment.leftHandSide()).isValueType(), "Compound operators only available for value types.");
solAssert(rightIntermediateType->isValueType(), "Compound operators only available for value types.");
IRVariable leftIntermediate = readFromLValue(*m_currentLValue);
solAssert(binaryOperator != Token::Exp, "");
if (TokenTraits::isShiftOp(binaryOperator))
{
solAssert(type(_assignment) == leftIntermediate.type(), "");
@ -593,11 +594,17 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
solAssert(false, "Unknown comparison operator.");
define(_binOp) << expr << "\n";
}
else if (TokenTraits::isShiftOp(op))
else if (TokenTraits::isShiftOp(op) || op == Token::Exp)
{
IRVariable left = convert(_binOp.leftExpression(), *commonType);
IRVariable right = convert(_binOp.rightExpression(), *type(_binOp.rightExpression()).mobileType());
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
if (op == Token::Exp)
define(_binOp) << m_utils.overflowCheckedIntExpFunction(
dynamic_cast<IntegerType const&>(left.type()),
dynamic_cast<IntegerType const&>(right.type())
) << "(" << left.name() << ", " << right.name() << ")\n";
else
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
}
else
{

View File

@ -10,5 +10,7 @@ contract test {
return (x**y1, x**y2);
}
}
// ====
// compileViaYul: also
// ----
// f() -> 9, -27

View File

@ -0,0 +1,19 @@
contract C {
function f(uint x, uint y) public returns (uint) {
return x**y;
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,uint256): 0, 0 -> 1
// f(uint256,uint256): 0, 1 -> 0x00
// f(uint256,uint256): 0, 2 -> 0x00
// f(uint256,uint256): 1, 0 -> 1
// f(uint256,uint256): 1, 1 -> 1
// f(uint256,uint256): 1, 2 -> 1
// f(uint256,uint256): 2, 0 -> 1
// f(uint256,uint256): 2, 1 -> 2
// f(uint256,uint256): 2, 2 -> 4
// f(uint256,uint256): 7, 63 -> 174251498233690814305510551794710260107945042018748343
// f(uint256,uint256): 128, 2 -> 0x4000

View File

@ -0,0 +1,28 @@
contract C {
function f(int x, uint y) public returns (int) {
return x**y;
}
}
// ====
// compileViaYul: also
// ----
// f(int256,uint256): 0, 0 -> 1
// f(int256,uint256): 0, 1 -> 0x00
// f(int256,uint256): 0, 2 -> 0x00
// f(int256,uint256): 1, 0 -> 1
// f(int256,uint256): 1, 1 -> 1
// f(int256,uint256): 1, 2 -> 1
// f(int256,uint256): 2, 0 -> 1
// f(int256,uint256): 2, 1 -> 2
// f(int256,uint256): 2, 2 -> 4
// f(int256,uint256): 7, 63 -> 174251498233690814305510551794710260107945042018748343
// f(int256,uint256): 128, 2 -> 0x4000
// f(int256,uint256): -1, 0 -> 1
// f(int256,uint256): -1, 1 -> -1
// f(int256,uint256): -1, 2 -> 1
// f(int256,uint256): -2, 0 -> 1
// f(int256,uint256): -2, 1 -> -2
// f(int256,uint256): -2, 2 -> 4
// f(int256,uint256): -7, 63 -> -174251498233690814305510551794710260107945042018748343
// f(int256,uint256): -128, 2 -> 0x4000
// f(int256,uint256): -1, 115792089237316195423570985008687907853269984665640564039457584007913129639935 -> -1

View File

@ -0,0 +1,38 @@
contract C {
function f(int8 x, uint y) public returns (int) {
return x**y;
}
function g(int256 x, uint y) public returns (int) {
return x**y;
}
}
// ====
// compileViaYul: true
// ----
// f(int8,uint256): 2, 6 -> 64
// f(int8,uint256): 2, 7 -> FAILURE
// f(int8,uint256): 2, 8 -> FAILURE
// f(int8,uint256): -2, 6 -> 64
// f(int8,uint256): -2, 7 -> -128
// f(int8,uint256): -2, 8 -> FAILURE
// f(int8,uint256): 6, 3 -> FAILURE
// f(int8,uint256): 7, 2 -> 0x31
// f(int8,uint256): 7, 3 -> FAILURE
// f(int8,uint256): -7, 2 -> 0x31
// f(int8,uint256): -7, 3 -> FAILURE
// f(int8,uint256): -7, 4 -> FAILURE
// f(int8,uint256): 127, 31 -> FAILURE
// f(int8,uint256): 127, 131 -> FAILURE
// f(int8,uint256): -128, 0 -> 1
// f(int8,uint256): -128, 1 -> -128
// f(int8,uint256): -128, 31 -> FAILURE
// f(int8,uint256): -128, 131 -> FAILURE
// f(int8,uint256): -11, 2 -> 121
// f(int8,uint256): -12, 2 -> FAILURE
// f(int8,uint256): 12, 2 -> FAILURE
// f(int8,uint256): -5, 3 -> -125
// f(int8,uint256): -6, 3 -> FAILURE
// g(int256,uint256): -7, 90 -> 11450477594321044359340126713545146077054004823284978858214566372120240027249
// g(int256,uint256): -7, 91 -> FAILURE
// g(int256,uint256): -63, 42 -> 3735107253208426854890677539053540390278853997836851167913009474475553834369
// g(int256,uint256): -63, 43 -> FAILURE

View File

@ -0,0 +1,31 @@
contract C {
function f(uint8 x, uint8 y) public returns (uint) {
return x**y;
}
function g(uint x, uint y) public returns (uint) {
return x**y;
}
}
// ====
// compileViaYul: true
// ----
// f(uint8,uint8): 2, 7 -> 0x80
// f(uint8,uint8): 2, 8 -> FAILURE
// f(uint8,uint8): 15, 2 -> 225
// f(uint8,uint8): 6, 3 -> 0xd8
// f(uint8,uint8): 7, 2 -> 0x31
// f(uint8,uint8): 7, 3 -> FAILURE
// f(uint8,uint8): 7, 4 -> FAILURE
// f(uint8,uint8): 255, 31 -> FAILURE
// f(uint8,uint8): 255, 131 -> FAILURE
// g(uint256,uint256): 0x200000000000000000000000000000000, 1 -> 0x0200000000000000000000000000000000
// g(uint256,uint256): 0x100000000000000000000000000000010, 2 -> FAILURE
// g(uint256,uint256): 0x200000000000000000000000000000000, 2 -> FAILURE
// g(uint256,uint256): 0x200000000000000000000000000000000, 3 -> FAILURE
// g(uint256,uint256): 255, 31 -> 400631961586894742455537928461950192806830589109049416147172451019287109375
// g(uint256,uint256): 255, 32 -> -13630939032658036097408813250890608687528184442832962921928608997994916749311
// g(uint256,uint256): 255, 33 -> FAILURE
// g(uint256,uint256): 255, 131 -> FAILURE
// g(uint256,uint256): 258, 31 -> 575719427506838823084316385994930914701079543089399988096291424922125729792
// g(uint256,uint256): 258, 37 -> FAILURE
// g(uint256,uint256): 258, 131 -> FAILURE

View File

@ -0,0 +1,50 @@
contract C {
function f(uint8 x, uint8 y) public returns (uint) {
return x**y;
}
function g(uint x, uint y) public returns (uint) {
return x**y;
}
}
// ====
// compileViaYul: also
// ----
// f(uint8,uint8): 0, 0 -> 1
// f(uint8,uint8): 0, 1 -> 0x00
// f(uint8,uint8): 0, 2 -> 0x00
// f(uint8,uint8): 0, 3 -> 0x00
// f(uint8,uint8): 1, 0 -> 1
// f(uint8,uint8): 1, 1 -> 1
// f(uint8,uint8): 1, 2 -> 1
// f(uint8,uint8): 1, 3 -> 1
// f(uint8,uint8): 2, 0 -> 1
// f(uint8,uint8): 2, 1 -> 2
// f(uint8,uint8): 2, 2 -> 4
// f(uint8,uint8): 2, 3 -> 8
// f(uint8,uint8): 3, 0 -> 1
// f(uint8,uint8): 3, 1 -> 3
// f(uint8,uint8): 3, 2 -> 9
// f(uint8,uint8): 3, 3 -> 0x1b
// f(uint8,uint8): 10, 0 -> 1
// f(uint8,uint8): 10, 1 -> 0x0a
// f(uint8,uint8): 10, 2 -> 100
// g(uint256,uint256): 0, 0 -> 1
// g(uint256,uint256): 0, 1 -> 0x00
// g(uint256,uint256): 0, 2 -> 0x00
// g(uint256,uint256): 0, 3 -> 0x00
// g(uint256,uint256): 1, 0 -> 1
// g(uint256,uint256): 1, 1 -> 1
// g(uint256,uint256): 1, 2 -> 1
// g(uint256,uint256): 1, 3 -> 1
// g(uint256,uint256): 2, 0 -> 1
// g(uint256,uint256): 2, 1 -> 2
// g(uint256,uint256): 2, 2 -> 4
// g(uint256,uint256): 2, 3 -> 8
// g(uint256,uint256): 3, 0 -> 1
// g(uint256,uint256): 3, 1 -> 3
// g(uint256,uint256): 3, 2 -> 9
// g(uint256,uint256): 3, 3 -> 0x1b
// g(uint256,uint256): 10, 10 -> 10000000000
// g(uint256,uint256): 10, 77 -> -15792089237316195423570985008687907853269984665640564039457584007913129639936
// g(uint256,uint256): 256, 2 -> 0x010000
// g(uint256,uint256): 256, 31 -> 0x0100000000000000000000000000000000000000000000000000000000000000