Merge pull request #9122 from ethereum/optimizeShiftBytes

Optimize combination of byte and shl.
This commit is contained in:
Daniel Kirchner 2020-07-08 21:41:49 +02:00 committed by GitHub
commit 29bad26dee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 155 additions and 1 deletions

2
.gitignore vendored
View File

@ -35,7 +35,7 @@ build/
build*/ build*/
emscripten_build/ emscripten_build/
docs/_build docs/_build
docs/utils/__pycache__ __pycache__
docs/utils/*.pyc docs/utils/*.pyc
/deps/downloads/ /deps/downloads/
deps/install deps/install

View File

@ -6,6 +6,9 @@ Compiler Features:
Bugfixes: Bugfixes:
* Type Checker: Fix overload resolution in combination with ``{value: ...}``. * Type Checker: Fix overload resolution in combination with ``{value: ...}``.
Compiler Features:
* Optimizer: Add rule to remove shifts inside the byte opcode.
### 0.6.11 (2020-07-07) ### 0.6.11 (2020-07-07)

View File

@ -571,6 +571,20 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleListPart7(
feasibilityFunction feasibilityFunction
}); });
rules.push_back({
Builtins::BYTE(A, Builtins::SHL(B, X)),
[=]() -> Pattern { return Builtins::BYTE(A.d() + B.d() / 8, X); },
false,
[=] { return B.d() % 8 == 0 && A.d() <= 32 && B.d() <= 256; }
});
rules.push_back({
Builtins::BYTE(A, Builtins::SHR(B, X)),
[=]() -> Pattern { return A.d() < B.d() / 8 ? Word(0) : Builtins::BYTE(A.d() - B.d() / 8, X); },
false,
[=] { return B.d() % 8 == 0 && A.d() < Pattern::WordSize / 8 && B.d() <= Pattern::WordSize; }
});
return rules; return rules;
} }

View File

@ -0,0 +1,29 @@
from rule import Rule
from opcodes import *
"""
byte(A, shl(B, X))
given B % 8 == 0 && A <= 32 && B <= 256
->
byte(A + B / 8, X)
"""
rule = Rule()
n_bits = 256
# Input vars
X = BitVec('X', n_bits)
A = BitVec('A', n_bits)
B = BitVec('B', n_bits)
# Non optimized result
nonopt = BYTE(A, SHL(B, X))
# Optimized result
opt = BYTE(A + B / 8, X)
rule.require(B % 8 == 0)
rule.require(ULE(A, 32))
rule.require(ULE(B, 256))
rule.check(nonopt, opt)

View File

@ -0,0 +1,30 @@
from rule import Rule
from opcodes import *
"""
byte(A, shr(B, X))
given B % 8 == 0 && A < n_bits/8 && B <= n_bits && A >= B / 8
->
byte(A - B / 8, X)
"""
rule = Rule()
n_bits = 256
# Input vars
X = BitVec('X', n_bits)
A = BitVec('A', n_bits)
B = BitVec('B', n_bits)
# Non optimized result
nonopt = BYTE(A, SHR(B, X))
# Optimized result
opt = BYTE(A - B / 8, X)
rule.require(B % 8 == 0)
rule.require(ULT(A, n_bits/8))
rule.require(ULE(B, n_bits))
rule.require(UGE(A, DIV(B,8)))
rule.check(nonopt, opt)

View File

@ -0,0 +1,30 @@
from rule import Rule
from opcodes import *
"""
byte(A, shr(B, X))
given B % 8 == 0 && A < n_bits/8 && B <= n_bits && A < B / 8
->
0
"""
rule = Rule()
n_bits = 256
# Input vars
X = BitVec('X', n_bits)
A = BitVec('A', n_bits)
B = BitVec('B', n_bits)
# Non optimized result
nonopt = BYTE(A, SHR(B, X))
# Optimized result
opt = 0
rule.require(B % 8 == 0)
rule.require(ULT(A, n_bits/8))
rule.require(ULE(B, n_bits))
rule.require(ULT(A, DIV(B,8)))
rule.check(nonopt, opt)

View File

@ -56,3 +56,7 @@ def SHR(x, y):
def SAR(x, y): def SAR(x, y):
return y >> x return y >> x
def BYTE(i, x):
bit = (i + 1) * 8
return If(UGT(bit, x.size()), BitVecVal(0, x.size()), (LShR(x, (x.size() - bit))) & 0xff)

View File

@ -0,0 +1,44 @@
// This tests the optimizer rule
// byte(A, shl(B, X))
// ->
// byte(A + B / 8, X)
// given A <= 32 && B % 8 == 0 && B <= 256
//
// and the respective rule about shr
contract C {
function f(uint a) public returns (uint, uint, uint) {
uint x = a << (256 - 8);
assembly {
x := byte(0, x)
}
uint y = a << 8;
assembly {
y := byte(30, y)
}
uint z = a << 16;
assembly {
z := byte(1, z)
}
return (x, y, z);
}
function g(uint a) public returns (uint, uint, uint) {
uint x = a >> (256 - 16);
assembly {
x := byte(31, x)
}
uint y = a >> 8;
assembly {
y := byte(4, y)
}
uint z = a >> 16;
assembly {
z := byte(7, z)
}
return (x, y, z);
}
}
// ====
// compileViaYul: also
// ----
// f(uint256): 0x0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f -> 0x1f, 0x1f, 3
// g(uint256): 0x0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f -> 1, 3, 5