Merge pull request #9816 from ethereum/exp-base-literals

[Sol->Yul] Optimization for exponentiation when the base is a literal
This commit is contained in:
chriseth 2020-10-12 19:34:38 +02:00 committed by GitHub
commit 4b342a7cad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 557 additions and 4 deletions

View File

@ -635,6 +635,119 @@ string YulUtilFunctions::overflowCheckedIntExpFunction(
});
}
string YulUtilFunctions::overflowCheckedIntLiteralExpFunction(
RationalNumberType const& _baseType,
IntegerType const& _exponentType,
IntegerType const& _commonType
)
{
solAssert(!_exponentType.isSigned(), "");
solAssert(_baseType.isNegative() == _commonType.isSigned(), "");
solAssert(_commonType.numBits() == 256, "");
string functionName = "checked_exp_" + _baseType.richIdentifier() + "_" + _exponentType.identifier();
return m_functionCollector.createFunction(functionName, [&]()
{
// Converts a bigint number into u256 (negative numbers represented in two's complement form.)
// We assume that `_v` fits in 256 bits.
auto bigint2u = [&](bigint const& _v) -> u256
{
if (_v < 0)
return s2u(s256(_v));
return u256(_v);
};
// Calculates the upperbound for exponentiation, that is, calculate `b`, such that
// _base**b <= _maxValue and _base**(b + 1) > _maxValue
auto findExponentUpperbound = [](bigint const _base, bigint const _maxValue) -> unsigned
{
// There is no overflow for these cases
if (_base == 0 || _base == -1 || _base == 1)
return 0;
unsigned first = 0;
unsigned last = 255;
unsigned middle;
while (first < last)
{
middle = (first + last) / 2;
if (
// The condition on msb is a shortcut that avoids computing large powers in
// arbitrary precision.
boost::multiprecision::msb(_base) * middle <= boost::multiprecision::msb(_maxValue) &&
boost::multiprecision::pow(_base, middle) <= _maxValue
)
{
if (boost::multiprecision::pow(_base, middle + 1) > _maxValue)
return middle;
else
first = middle + 1;
}
else
last = middle;
}
return last;
};
bigint baseValue = _baseType.isNegative() ?
u2s(_baseType.literalValue(nullptr)) :
_baseType.literalValue(nullptr);
bool needsOverflowCheck = !((baseValue == 0) || (baseValue == -1) || (baseValue == 1));
unsigned exponentUpperbound;
if (_baseType.isNegative())
{
// Only checks for underflow. The only case where this can be a problem is when, for a
// negative base, say `b`, and an even exponent, say `e`, `b**e = 2**255` (which is an
// overflow.) But this never happens because, `255 = 3*5*17`, and therefore there is no even
// number `e` such that `b**e = 2**255`.
exponentUpperbound = findExponentUpperbound(abs(baseValue), abs(_commonType.minValue()));
bigint power = boost::multiprecision::pow(baseValue, exponentUpperbound);
bigint overflowedPower = boost::multiprecision::pow(baseValue, exponentUpperbound + 1);
if (needsOverflowCheck)
solAssert(
(power <= _commonType.maxValue()) && (power >= _commonType.minValue()) &&
!((overflowedPower <= _commonType.maxValue()) && (overflowedPower >= _commonType.minValue())),
"Incorrect exponent upper bound calculated."
);
}
else
{
exponentUpperbound = findExponentUpperbound(baseValue, _commonType.maxValue());
if (needsOverflowCheck)
solAssert(
boost::multiprecision::pow(baseValue, exponentUpperbound) <= _commonType.maxValue() &&
boost::multiprecision::pow(baseValue, exponentUpperbound + 1) > _commonType.maxValue(),
"Incorrect exponent upper bound calculated."
);
}
return Whiskers(R"(
function <functionName>(exponent) -> power {
exponent := <exponentCleanupFunction>(exponent)
<?needsOverflowCheck>
if gt(exponent, <exponentUpperbound>) { <panic>() }
</needsOverflowCheck>
power := exp(<base>, exponent)
}
)")
("functionName", functionName)
("exponentCleanupFunction", cleanupFunction(_exponentType))
("needsOverflowCheck", needsOverflowCheck)
("exponentUpperbound", to_string(exponentUpperbound))
("panic", panicFunction())
("base", bigint2u(baseValue).str())
.render();
});
}
string YulUtilFunctions::overflowCheckedUnsignedExpFunction()
{
// Checks for the "small number specialization" below.

View File

@ -128,6 +128,14 @@ public:
/// signature: (base, exponent) -> power
std::string overflowCheckedIntExpFunction(IntegerType const& _type, IntegerType const& _exponentType);
/// @returns the name of the exponentiation function, specialized for literal base.
/// signature: exponent -> power
std::string overflowCheckedIntLiteralExpFunction(
RationalNumberType const& _baseType,
IntegerType const& _exponentType,
IntegerType const& _commonType
);
/// Generic unsigned checked exponentiation function.
/// Reverts if the result is larger than max.
/// signature: (base, exponent, max) -> power

View File

@ -695,17 +695,34 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
solAssert(false, "Unknown comparison operator.");
define(_binOp) << expr << "\n";
}
else if (TokenTraits::isShiftOp(op) || op == Token::Exp)
else if (op == Token::Exp)
{
IRVariable left = convert(_binOp.leftExpression(), *commonType);
IRVariable right = convert(_binOp.rightExpression(), *type(_binOp.rightExpression()).mobileType());
if (op == Token::Exp)
if (auto rationalNumberType = dynamic_cast<RationalNumberType const*>(_binOp.leftExpression().annotation().type))
{
solAssert(rationalNumberType->integerType(), "Invalid literal as the base for exponentiation.");
solAssert(dynamic_cast<IntegerType const*>(commonType), "");
define(_binOp) << m_utils.overflowCheckedIntLiteralExpFunction(
*rationalNumberType,
dynamic_cast<IntegerType const&>(right.type()),
dynamic_cast<IntegerType const&>(*commonType)
) << "(" << right.name() << ")\n";
}
else
define(_binOp) << m_utils.overflowCheckedIntExpFunction(
dynamic_cast<IntegerType const&>(left.type()),
dynamic_cast<IntegerType const&>(right.type())
) << "(" << left.name() << ", " << right.name() << ")\n";
else
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
}
else if (TokenTraits::isShiftOp(op))
{
IRVariable left = convert(_binOp.leftExpression(), *commonType);
IRVariable right = convert(_binOp.rightExpression(), *type(_binOp.rightExpression()).mobileType());
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
}
else
{

View File

@ -0,0 +1 @@
--ir

View File

@ -0,0 +1 @@

View File

@ -0,0 +1 @@
0

View File

@ -0,0 +1,18 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity > 0.7.1;
contract C {
function f(uint a, uint b, uint c, uint d) public pure returns (uint, int, uint, uint) {
uint w = 2**a;
int x = (-2)**b;
uint y = 10**c;
uint z = (2**256 -1 )**d;
// Special cases: 0, 1, and -1
w = (0)**a;
x = (-1)**b;
y = 1**c;
return (w, x, y, z);
}
}

View File

@ -0,0 +1,304 @@
IR:
/*******************************************************
* WARNING *
* Solidity to Yul compilation is still EXPERIMENTAL *
* It can result in LOSS OF FUNDS or worse *
* !USE AT YOUR OWN RISK! *
*******************************************************/
object "C_80" {
code {
mstore(64, memoryguard(128))
if callvalue() { revert(0, 0) }
constructor_C_80()
codecopy(0, dataoffset("C_80_deployed"), datasize("C_80_deployed"))
return(0, datasize("C_80_deployed"))
function constructor_C_80() {
}
}
object "C_80_deployed" {
code {
mstore(64, memoryguard(128))
if iszero(lt(calldatasize(), 4))
{
let selector := shift_right_224_unsigned(calldataload(0))
switch selector
case 0x70cb9605
{
// f(uint256,uint256,uint256,uint256)
if callvalue() { revert(0, 0) }
let param_0, param_1, param_2, param_3 := abi_decode_tuple_t_uint256t_uint256t_uint256t_uint256(4, calldatasize())
let ret_0, ret_1, ret_2, ret_3 := fun_f_79(param_0, param_1, param_2, param_3)
let memPos := allocateMemory(0)
let memEnd := abi_encode_tuple_t_uint256_t_int256_t_uint256_t_uint256__to_t_uint256_t_int256_t_uint256_t_uint256__fromStack(memPos , ret_0, ret_1, ret_2, ret_3)
return(memPos, sub(memEnd, memPos))
}
default {}
}
if iszero(calldatasize()) { }
revert(0, 0)
function abi_decode_t_uint256(offset, end) -> value {
value := calldataload(offset)
validator_revert_t_uint256(value)
}
function abi_decode_tuple_t_uint256t_uint256t_uint256t_uint256(headStart, dataEnd) -> value0, value1, value2, value3 {
if slt(sub(dataEnd, headStart), 128) { revert(0, 0) }
{
let offset := 0
value0 := abi_decode_t_uint256(add(headStart, offset), dataEnd)
}
{
let offset := 32
value1 := abi_decode_t_uint256(add(headStart, offset), dataEnd)
}
{
let offset := 64
value2 := abi_decode_t_uint256(add(headStart, offset), dataEnd)
}
{
let offset := 96
value3 := abi_decode_t_uint256(add(headStart, offset), dataEnd)
}
}
function abi_encode_t_int256_to_t_int256_fromStack(value, pos) {
mstore(pos, cleanup_t_int256(value))
}
function abi_encode_t_uint256_to_t_uint256_fromStack(value, pos) {
mstore(pos, cleanup_t_uint256(value))
}
function abi_encode_tuple_t_uint256_t_int256_t_uint256_t_uint256__to_t_uint256_t_int256_t_uint256_t_uint256__fromStack(headStart , value0, value1, value2, value3) -> tail {
tail := add(headStart, 128)
abi_encode_t_uint256_to_t_uint256_fromStack(value0, add(headStart, 0))
abi_encode_t_int256_to_t_int256_fromStack(value1, add(headStart, 32))
abi_encode_t_uint256_to_t_uint256_fromStack(value2, add(headStart, 64))
abi_encode_t_uint256_to_t_uint256_fromStack(value3, add(headStart, 96))
}
function allocateMemory(size) -> memPtr {
memPtr := mload(64)
let newFreePtr := add(memPtr, size)
// protect against overflow
if or(gt(newFreePtr, 0xffffffffffffffff), lt(newFreePtr, memPtr)) { panic_error() }
mstore(64, newFreePtr)
}
function checked_exp_t_rational_0_by_1_t_uint256(exponent) -> power {
exponent := cleanup_t_uint256(exponent)
power := exp(0, exponent)
}
function checked_exp_t_rational_10_by_1_t_uint256(exponent) -> power {
exponent := cleanup_t_uint256(exponent)
if gt(exponent, 77) { panic_error() }
power := exp(10, exponent)
}
function checked_exp_t_rational_115792089237316195423570985008687907853269984665640564039457584007913129639935_by_1_t_uint256(exponent) -> power {
exponent := cleanup_t_uint256(exponent)
if gt(exponent, 1) { panic_error() }
power := exp(115792089237316195423570985008687907853269984665640564039457584007913129639935, exponent)
}
function checked_exp_t_rational_1_by_1_t_uint256(exponent) -> power {
exponent := cleanup_t_uint256(exponent)
power := exp(1, exponent)
}
function checked_exp_t_rational_2_by_1_t_uint256(exponent) -> power {
exponent := cleanup_t_uint256(exponent)
if gt(exponent, 255) { panic_error() }
power := exp(2, exponent)
}
function checked_exp_t_rational_minus_1_by_1_t_uint256(exponent) -> power {
exponent := cleanup_t_uint256(exponent)
power := exp(115792089237316195423570985008687907853269984665640564039457584007913129639935, exponent)
}
function checked_exp_t_rational_minus_2_by_1_t_uint256(exponent) -> power {
exponent := cleanup_t_uint256(exponent)
if gt(exponent, 255) { panic_error() }
power := exp(115792089237316195423570985008687907853269984665640564039457584007913129639934, exponent)
}
function cleanup_t_int256(value) -> cleaned {
cleaned := value
}
function cleanup_t_uint256(value) -> cleaned {
cleaned := value
}
function convert_t_rational_0_by_1_to_t_uint256(value) -> converted {
converted := cleanup_t_uint256(value)
}
function convert_t_rational_10_by_1_to_t_uint256(value) -> converted {
converted := cleanup_t_uint256(value)
}
function convert_t_rational_115792089237316195423570985008687907853269984665640564039457584007913129639935_by_1_to_t_uint256(value) -> converted {
converted := cleanup_t_uint256(value)
}
function convert_t_rational_1_by_1_to_t_uint256(value) -> converted {
converted := cleanup_t_uint256(value)
}
function convert_t_rational_2_by_1_to_t_uint256(value) -> converted {
converted := cleanup_t_uint256(value)
}
function convert_t_rational_minus_1_by_1_to_t_int256(value) -> converted {
converted := cleanup_t_int256(value)
}
function convert_t_rational_minus_2_by_1_to_t_int256(value) -> converted {
converted := cleanup_t_int256(value)
}
function fun_f_79(vloc_a_3, vloc_b_5, vloc_c_7, vloc_d_9) -> vloc__12, vloc__14, vloc__16, vloc__18 {
let zero_value_for_type_t_uint256_1 := zero_value_for_split_t_uint256()
vloc__12 := zero_value_for_type_t_uint256_1
let zero_value_for_type_t_int256_2 := zero_value_for_split_t_int256()
vloc__14 := zero_value_for_type_t_int256_2
let zero_value_for_type_t_uint256_3 := zero_value_for_split_t_uint256()
vloc__16 := zero_value_for_type_t_uint256_3
let zero_value_for_type_t_uint256_4 := zero_value_for_split_t_uint256()
vloc__18 := zero_value_for_type_t_uint256_4
let expr_22 := 0x02
let _5 := vloc_a_3
let expr_23 := _5
let _6 := convert_t_rational_2_by_1_to_t_uint256(expr_22)
let expr_24 := checked_exp_t_rational_2_by_1_t_uint256(expr_23)
let vloc_w_21 := expr_24
let expr_28 := 0x02
let expr_29 := 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe
let expr_30 := expr_29
let _7 := vloc_b_5
let expr_31 := _7
let _8 := convert_t_rational_minus_2_by_1_to_t_int256(expr_30)
let expr_32 := checked_exp_t_rational_minus_2_by_1_t_uint256(expr_31)
let vloc_x_27 := expr_32
let expr_36 := 0x0a
let _9 := vloc_c_7
let expr_37 := _9
let _10 := convert_t_rational_10_by_1_to_t_uint256(expr_36)
let expr_38 := checked_exp_t_rational_10_by_1_t_uint256(expr_37)
let vloc_y_35 := expr_38
let expr_46 := 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
let expr_47 := expr_46
let _11 := vloc_d_9
let expr_48 := _11
let _12 := convert_t_rational_115792089237316195423570985008687907853269984665640564039457584007913129639935_by_1_to_t_uint256(expr_47)
let expr_49 := checked_exp_t_rational_115792089237316195423570985008687907853269984665640564039457584007913129639935_by_1_t_uint256(expr_48)
let vloc_z_41 := expr_49
let expr_52 := 0x00
let expr_53 := expr_52
let _13 := vloc_a_3
let expr_54 := _13
let _14 := convert_t_rational_0_by_1_to_t_uint256(expr_53)
let expr_55 := checked_exp_t_rational_0_by_1_t_uint256(expr_54)
vloc_w_21 := expr_55
let expr_56 := expr_55
let expr_59 := 0x01
let expr_60 := 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
let expr_61 := expr_60
let _15 := vloc_b_5
let expr_62 := _15
let _16 := convert_t_rational_minus_1_by_1_to_t_int256(expr_61)
let expr_63 := checked_exp_t_rational_minus_1_by_1_t_uint256(expr_62)
vloc_x_27 := expr_63
let expr_64 := expr_63
let expr_67 := 0x01
let _17 := vloc_c_7
let expr_68 := _17
let _18 := convert_t_rational_1_by_1_to_t_uint256(expr_67)
let expr_69 := checked_exp_t_rational_1_by_1_t_uint256(expr_68)
vloc_y_35 := expr_69
let expr_70 := expr_69
let _19 := vloc_w_21
let expr_72 := _19
let expr_76_component_1 := expr_72
let _20 := vloc_x_27
let expr_73 := _20
let expr_76_component_2 := expr_73
let _21 := vloc_y_35
let expr_74 := _21
let expr_76_component_3 := expr_74
let _22 := vloc_z_41
let expr_75 := _22
let expr_76_component_4 := expr_75
vloc__12 := expr_76_component_1
vloc__14 := expr_76_component_2
vloc__16 := expr_76_component_3
vloc__18 := expr_76_component_4
leave
}
function panic_error() {
invalid()
}
function shift_right_224_unsigned(value) -> newValue {
newValue :=
shr(224, value)
}
function validator_revert_t_uint256(value) {
if iszero(eq(value, cleanup_t_uint256(value))) { revert(0, 0) }
}
function zero_value_for_split_t_int256() -> ret {
ret := 0
}
function zero_value_for_split_t_uint256() -> ret {
ret := 0
}
}
}
}

View File

@ -0,0 +1,49 @@
contract C {
function exp_2(uint y) public returns (uint) {
return 2**y;
}
function exp_minus_2(uint y) public returns (int) {
return (-2)**y;
}
function exp_uint_max(uint y) public returns (uint) {
return (2**256 - 1)**y;
}
function exp_int_max(uint y) public returns (int) {
return ((-2)**255)**y;
}
function exp_5(uint y) public returns (uint) {
return 5**y;
}
function exp_minus_5(uint y) public returns (int) {
return (-5)**y;
}
function exp_256(uint y) public returns (uint) {
return 256**y;
}
function exp_minus_256(uint y) public returns (int) {
return (-256)**y;
}
}
// ====
// compileViaYul: true
// ----
// exp_2(uint256): 255 -> 57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_2(uint256): 256 -> FAILURE
// exp_minus_2(uint256): 255 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_minus_2(uint256): 256 -> FAILURE
// exp_uint_max(uint256): 1 -> 115792089237316195423570985008687907853269984665640564039457584007913129639935
// exp_uint_max(uint256): 2 -> FAILURE
// exp_int_max(uint256): 1 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_int_max(uint256): 2 -> FAILURE
// exp_5(uint256): 110 -> 77037197775489434122239117703397092741524065928615527809597551822662353515625
// exp_5(uint256): 111 -> FAILURE
// exp_minus_5(uint256): 109 -> -15407439555097886824447823540679418548304813185723105561919510364532470703125
// exp_minus_5(uint256): 110 -> FAILURE
// exp_256(uint256): 31 -> 452312848583266388373324160190187140051835877600158453279131187530910662656
// exp_256(uint256): 32 -> FAILURE
// exp_minus_256(uint256): 31 -> -452312848583266388373324160190187140051835877600158453279131187530910662656
// exp_minus_256(uint256): 32 -> FAILURE

View File

@ -0,0 +1,41 @@
contract C {
function exp_2(uint y) public returns (uint) {
return 2**y;
}
function exp_minus_2(uint y) public returns (int) {
return (-2)**y;
}
function exp_uint_max(uint y) public returns (uint) {
return (2**256 - 1)**y;
}
function exp_int_max(uint y) public returns (int) {
return ((-2)**255)**y;
}
function exp_5(uint y) public returns (uint) {
return 5**y;
}
function exp_minus_5(uint y) public returns (int) {
return (-5)**y;
}
function exp_256(uint y) public returns (uint) {
return 256**y;
}
function exp_minus_256(uint y) public returns (int) {
return (-256)**y;
}
}
// ====
// compileViaYul: also
// ----
// exp_2(uint256): 255 -> 57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_minus_2(uint256): 255 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_uint_max(uint256): 1 -> 115792089237316195423570985008687907853269984665640564039457584007913129639935
// exp_int_max(uint256): 1 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_5(uint256): 110 -> 77037197775489434122239117703397092741524065928615527809597551822662353515625
// exp_minus_5(uint256): 109 -> -15407439555097886824447823540679418548304813185723105561919510364532470703125
// exp_256(uint256): 31 -> 452312848583266388373324160190187140051835877600158453279131187530910662656
// exp_minus_256(uint256): 31 -> -452312848583266388373324160190187140051835877600158453279131187530910662656