Optimization for exponentiation when the base is a literal

This commit is contained in:
Harikrishnan Mulackal 2020-09-15 22:42:48 +02:00 committed by hrkrshnn
parent 7a86a61b08
commit 418aa01c5b
5 changed files with 232 additions and 4 deletions

View File

@ -635,6 +635,119 @@ string YulUtilFunctions::overflowCheckedIntExpFunction(
});
}
string YulUtilFunctions::overflowCheckedIntLiteralExpFunction(
RationalNumberType const& _baseType,
IntegerType const& _exponentType,
IntegerType const& _commonType
)
{
solAssert(!_exponentType.isSigned(), "");
solAssert(_baseType.isNegative() == _commonType.isSigned(), "");
solAssert(_commonType.numBits() == 256, "");
string functionName = "checked_exp_" + _baseType.richIdentifier() + "_" + _exponentType.identifier();
return m_functionCollector.createFunction(functionName, [&]()
{
// Converts a bigint number into u256 (negative numbers represented in two's complement form.)
// We assume that `_v` fits in 256 bits.
auto bigint2u = [&](bigint const& _v) -> u256
{
if (_v < 0)
return s2u(s256(_v));
return u256(_v);
};
// Calculates the upperbound for exponentiation, that is, calculate `b`, such that
// _base**b <= _maxValue and _base**(b + 1) > _maxValue
auto findExponentUpperbound = [](bigint const _base, bigint const _maxValue) -> unsigned
{
// There is no overflow for these cases
if (_base == 0 || _base == -1 || _base == 1)
return 0;
unsigned first = 0;
unsigned last = 255;
unsigned middle;
while (first < last)
{
middle = (first + last) / 2;
if (
// The condition on msb is a shortcut that avoids computing large powers in
// arbitrary precision.
boost::multiprecision::msb(_base) * middle <= boost::multiprecision::msb(_maxValue) &&
boost::multiprecision::pow(_base, middle) <= _maxValue
)
{
if (boost::multiprecision::pow(_base, middle + 1) > _maxValue)
return middle;
else
first = middle + 1;
}
else
last = middle;
}
return last;
};
bigint baseValue = _baseType.isNegative() ?
u2s(_baseType.literalValue(nullptr)) :
_baseType.literalValue(nullptr);
bool needsOverflowCheck = !((baseValue == 0) || (baseValue == -1) || (baseValue == 1));
unsigned exponentUpperbound;
if (_baseType.isNegative())
{
// Only checks for underflow. The only case where this can be a problem is when, for a
// negative base, say `b`, and an even exponent, say `e`, `b**e = 2**255` (which is an
// overflow.) But this never happens because, `255 = 3*5*17`, and therefore there is no even
// number `e` such that `b**e = 2**255`.
exponentUpperbound = findExponentUpperbound(abs(baseValue), abs(_commonType.minValue()));
bigint power = boost::multiprecision::pow(baseValue, exponentUpperbound);
bigint overflowedPower = boost::multiprecision::pow(baseValue, exponentUpperbound + 1);
if (needsOverflowCheck)
solAssert(
(power <= _commonType.maxValue()) && (power >= _commonType.minValue()) &&
!((overflowedPower <= _commonType.maxValue()) && (overflowedPower >= _commonType.minValue())),
"Incorrect exponent upper bound calculated."
);
}
else
{
exponentUpperbound = findExponentUpperbound(baseValue, _commonType.maxValue());
if (needsOverflowCheck)
solAssert(
boost::multiprecision::pow(baseValue, exponentUpperbound) <= _commonType.maxValue() &&
boost::multiprecision::pow(baseValue, exponentUpperbound + 1) > _commonType.maxValue(),
"Incorrect exponent upper bound calculated."
);
}
return Whiskers(R"(
function <functionName>(exponent) -> power {
exponent := <exponentCleanupFunction>(exponent)
<?needsOverflowCheck>
if gt(exponent, <exponentUpperbound>) { <panic>() }
</needsOverflowCheck>
power := exp(<base>, exponent)
}
)")
("functionName", functionName)
("exponentCleanupFunction", cleanupFunction(_exponentType))
("needsOverflowCheck", needsOverflowCheck)
("exponentUpperbound", to_string(exponentUpperbound))
("panic", panicFunction())
("base", bigint2u(baseValue).str())
.render();
});
}
string YulUtilFunctions::overflowCheckedUnsignedExpFunction()
{
// Checks for the "small number specialization" below.

View File

@ -129,6 +129,14 @@ public:
/// signature: (base, exponent) -> power
std::string overflowCheckedIntExpFunction(IntegerType const& _type, IntegerType const& _exponentType);
/// @returns the name of the exponentiation function, specialized for literal base.
/// signature: exponent -> power
std::string overflowCheckedIntLiteralExpFunction(
RationalNumberType const& _baseType,
IntegerType const& _exponentType,
IntegerType const& _commonType
);
/// Generic unsigned checked exponentiation function.
/// Reverts if the result is larger than max.
/// signature: (base, exponent, max) -> power

View File

@ -695,17 +695,34 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
solAssert(false, "Unknown comparison operator.");
define(_binOp) << expr << "\n";
}
else if (TokenTraits::isShiftOp(op) || op == Token::Exp)
else if (op == Token::Exp)
{
IRVariable left = convert(_binOp.leftExpression(), *commonType);
IRVariable right = convert(_binOp.rightExpression(), *type(_binOp.rightExpression()).mobileType());
if (op == Token::Exp)
if (auto rationalNumberType = dynamic_cast<RationalNumberType const*>(_binOp.leftExpression().annotation().type))
{
solAssert(rationalNumberType->integerType(), "Invalid literal as the base for exponentiation.");
solAssert(dynamic_cast<IntegerType const*>(commonType), "");
define(_binOp) << m_utils.overflowCheckedIntLiteralExpFunction(
*rationalNumberType,
dynamic_cast<IntegerType const&>(right.type()),
dynamic_cast<IntegerType const&>(*commonType)
) << "(" << right.name() << ")\n";
}
else
define(_binOp) << m_utils.overflowCheckedIntExpFunction(
dynamic_cast<IntegerType const&>(left.type()),
dynamic_cast<IntegerType const&>(right.type())
) << "(" << left.name() << ", " << right.name() << ")\n";
else
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
}
else if (TokenTraits::isShiftOp(op))
{
IRVariable left = convert(_binOp.leftExpression(), *commonType);
IRVariable right = convert(_binOp.rightExpression(), *type(_binOp.rightExpression()).mobileType());
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
}
else
{

View File

@ -0,0 +1,49 @@
contract C {
function exp_2(uint y) public returns (uint) {
return 2**y;
}
function exp_minus_2(uint y) public returns (int) {
return (-2)**y;
}
function exp_uint_max(uint y) public returns (uint) {
return (2**256 - 1)**y;
}
function exp_int_max(uint y) public returns (int) {
return ((-2)**255)**y;
}
function exp_5(uint y) public returns (uint) {
return 5**y;
}
function exp_minus_5(uint y) public returns (int) {
return (-5)**y;
}
function exp_256(uint y) public returns (uint) {
return 256**y;
}
function exp_minus_256(uint y) public returns (int) {
return (-256)**y;
}
}
// ====
// compileViaYul: true
// ----
// exp_2(uint256): 255 -> 57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_2(uint256): 256 -> FAILURE
// exp_minus_2(uint256): 255 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_minus_2(uint256): 256 -> FAILURE
// exp_uint_max(uint256): 1 -> 115792089237316195423570985008687907853269984665640564039457584007913129639935
// exp_uint_max(uint256): 2 -> FAILURE
// exp_int_max(uint256): 1 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_int_max(uint256): 2 -> FAILURE
// exp_5(uint256): 110 -> 77037197775489434122239117703397092741524065928615527809597551822662353515625
// exp_5(uint256): 111 -> FAILURE
// exp_minus_5(uint256): 109 -> -15407439555097886824447823540679418548304813185723105561919510364532470703125
// exp_minus_5(uint256): 110 -> FAILURE
// exp_256(uint256): 31 -> 452312848583266388373324160190187140051835877600158453279131187530910662656
// exp_256(uint256): 32 -> FAILURE
// exp_minus_256(uint256): 31 -> -452312848583266388373324160190187140051835877600158453279131187530910662656
// exp_minus_256(uint256): 32 -> FAILURE

View File

@ -0,0 +1,41 @@
contract C {
function exp_2(uint y) public returns (uint) {
return 2**y;
}
function exp_minus_2(uint y) public returns (int) {
return (-2)**y;
}
function exp_uint_max(uint y) public returns (uint) {
return (2**256 - 1)**y;
}
function exp_int_max(uint y) public returns (int) {
return ((-2)**255)**y;
}
function exp_5(uint y) public returns (uint) {
return 5**y;
}
function exp_minus_5(uint y) public returns (int) {
return (-5)**y;
}
function exp_256(uint y) public returns (uint) {
return 256**y;
}
function exp_minus_256(uint y) public returns (int) {
return (-256)**y;
}
}
// ====
// compileViaYul: also
// ----
// exp_2(uint256): 255 -> 57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_minus_2(uint256): 255 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_uint_max(uint256): 1 -> 115792089237316195423570985008687907853269984665640564039457584007913129639935
// exp_int_max(uint256): 1 -> -57896044618658097711785492504343953926634992332820282019728792003956564819968
// exp_5(uint256): 110 -> 77037197775489434122239117703397092741524065928615527809597551822662353515625
// exp_minus_5(uint256): 109 -> -15407439555097886824447823540679418548304813185723105561919510364532470703125
// exp_256(uint256): 31 -> 452312848583266388373324160190187140051835877600158453279131187530910662656
// exp_minus_256(uint256): 31 -> -452312848583266388373324160190187140051835877600158453279131187530910662656