IR generation for shifts

This commit is contained in:
Alex Beregszaszi 2020-04-20 22:16:42 +01:00 committed by chriseth
parent 3a93080ca9
commit 327c75bc1b
32 changed files with 236 additions and 18 deletions

View File

@ -299,9 +299,6 @@ string YulUtilFunctions::shiftRightFunction(size_t _numBits)
string YulUtilFunctions::shiftRightFunctionDynamic()
{
// Note that if this is extended with signed shifts,
// the opcodes SAR and SDIV behave differently with regards to rounding!
string const functionName = "shift_right_unsigned_dynamic";
return m_functionCollector.createFunction(functionName, [&]() {
return
@ -321,6 +318,86 @@ string YulUtilFunctions::shiftRightFunctionDynamic()
});
}
string YulUtilFunctions::shiftRightSignedFunctionDynamic()
{
string const functionName = "shift_right_signed_dynamic";
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(bits, value) -> result {
<?hasShifts>
result := sar(bits, value)
<!hasShifts>
let divisor := exp(2, bits)
let xor_mask := sub(0, slt(value, 0))
result := xor(div(xor(value, xor_mask), divisor), xor_mask)
// combined version of
// switch slt(value, 0)
// case 0 { result := div(value, divisor) }
// default { result := not(div(not(value), divisor)) }
</hasShifts>
}
)")
("functionName", functionName)
("hasShifts", m_evmVersion.hasBitwiseShifting())
.render();
});
}
string YulUtilFunctions::typedShiftLeftFunction(Type const& _type, Type const& _amountType)
{
solAssert(_type.category() == Type::Category::FixedBytes || _type.category() == Type::Category::Integer, "");
solAssert(_amountType.category() == Type::Category::Integer, "");
string const functionName = "shift_left_" + _type.identifier() + "_" + _amountType.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(value, bits) -> result {
bits := <cleanAmount>(bits)
<?amountSigned>
if slt(bits, 0) { invalid() }
</amountSigned>
result := <cleanup>(<shift>(bits, value))
}
)")
("functionName", functionName)
("amountSigned", dynamic_cast<IntegerType const&>(_amountType).isSigned())
("cleanAmount", cleanupFunction(_amountType))
("shift", shiftLeftFunctionDynamic())
("cleanup", cleanupFunction(_type))
.render();
});
}
string YulUtilFunctions::typedShiftRightFunction(Type const& _type, Type const& _amountType)
{
solAssert(_type.category() == Type::Category::FixedBytes || _type.category() == Type::Category::Integer, "");
solAssert(_amountType.category() == Type::Category::Integer, "");
IntegerType const* integerType = dynamic_cast<IntegerType const*>(&_type);
bool valueSigned = integerType && integerType->isSigned();
string const functionName = "shift_right_" + _type.identifier() + "_" + _amountType.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
return
Whiskers(R"(
function <functionName>(value, bits) -> result {
bits := <cleanAmount>(bits)
<?amountSigned>
if slt(bits, 0) { invalid() }
</amountSigned>
result := <cleanup>(<shift>(bits, <cleanup>(value)))
}
)")
("functionName", functionName)
("amountSigned", dynamic_cast<IntegerType const&>(_amountType).isSigned())
("cleanAmount", cleanupFunction(_amountType))
("shift", valueSigned ? shiftRightSignedFunctionDynamic() : shiftRightFunctionDynamic())
("cleanup", cleanupFunction(_type))
.render();
});
}
string YulUtilFunctions::updateByteSliceFunction(size_t _numBytes, size_t _shiftBytes)
{
solAssert(_numBytes <= 32, "");

View File

@ -81,6 +81,14 @@ public:
std::string shiftLeftFunctionDynamic();
std::string shiftRightFunction(size_t _numBits);
std::string shiftRightFunctionDynamic();
std::string shiftRightSignedFunctionDynamic();
/// @returns the name of a function that performs a left shift and subsequent cleanup
/// and, if needed, prior cleanup.
/// If the amount to shift by is signed, a check for negativeness is performed.
/// signature: (value, amountToShift) -> result
std::string typedShiftLeftFunction(Type const& _type, Type const& _amountType);
std::string typedShiftRightFunction(Type const& _type, Type const& _amountType);
/// @returns the name of a function which replaces the
/// _numBytes bytes starting at byte position _shiftBytes (counted from the least significant

View File

@ -254,29 +254,53 @@ bool IRGeneratorForStatements::visit(Conditional const& _conditional)
bool IRGeneratorForStatements::visit(Assignment const& _assignment)
{
_assignment.rightHandSide().accept(*this);
Type const* intermediateType = type(_assignment.rightHandSide()).closestTemporaryType(
&type(_assignment.leftHandSide())
);
IRVariable value = convert(_assignment.rightHandSide(), *intermediateType);
Token assignmentOperator = _assignment.assignmentOperator();
Token binaryOperator =
assignmentOperator == Token::Assign ?
assignmentOperator :
TokenTraits::AssignmentToBinaryOp(assignmentOperator);
Type const* rightIntermediateType =
TokenTraits::isShiftOp(binaryOperator) ?
type(_assignment.rightHandSide()).mobileType() :
type(_assignment.rightHandSide()).closestTemporaryType(
&type(_assignment.leftHandSide())
);
solAssert(rightIntermediateType, "");
IRVariable value = convert(_assignment.rightHandSide(), *rightIntermediateType);
_assignment.leftHandSide().accept(*this);
solAssert(!!m_currentLValue, "LValue not retrieved.");
if (_assignment.assignmentOperator() != Token::Assign)
if (assignmentOperator != Token::Assign)
{
solAssert(type(_assignment.leftHandSide()) == *intermediateType, "");
solAssert(intermediateType->isValueType(), "Compound operators only available for value types.");
solAssert(type(_assignment.leftHandSide()).isValueType(), "Compound operators only available for value types.");
solAssert(rightIntermediateType->isValueType(), "Compound operators only available for value types.");
IRVariable leftIntermediate = readFromLValue(*m_currentLValue);
m_code << value.name() << " := " << binaryOperation(
TokenTraits::AssignmentToBinaryOp(_assignment.assignmentOperator()),
*intermediateType,
leftIntermediate.name(),
value.name()
);
if (TokenTraits::isShiftOp(binaryOperator))
{
solAssert(type(_assignment) == leftIntermediate.type(), "");
solAssert(type(_assignment) == type(_assignment.leftHandSide()), "");
define(_assignment) << shiftOperation(binaryOperator, leftIntermediate, value);
writeToLValue(*m_currentLValue, IRVariable(_assignment));
m_currentLValue.reset();
return false;
}
else
{
solAssert(type(_assignment.leftHandSide()) == *rightIntermediateType, "");
m_code << value.name() << " := " << binaryOperation(
binaryOperator,
*rightIntermediateType,
leftIntermediate.name(),
value.name()
);
}
}
writeToLValue(*m_currentLValue, value);
m_currentLValue.reset();
if (*_assignment.annotation().type != *TypeProvider::emptyTuple())
define(_assignment, value);
@ -541,6 +565,12 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
solAssert(false, "Unknown comparison operator.");
define(_binOp) << expr << "\n";
}
else if (TokenTraits::isShiftOp(op))
{
IRVariable left = convert(_binOp.leftExpression(), *commonType);
IRVariable right = convert(_binOp.rightExpression(), *type(_binOp.rightExpression()).mobileType());
define(_binOp) << shiftOperation(_binOp.getOperator(), left, right) << "\n";
}
else
{
string left = expressionAsType(_binOp.leftExpression(), *commonType);
@ -1921,6 +1951,10 @@ string IRGeneratorForStatements::binaryOperation(
string const& _right
)
{
solAssert(
!TokenTraits::isShiftOp(_operator),
"Have to use specific shift operation function for shifts."
);
if (IntegerType const* type = dynamic_cast<IntegerType const*>(&_type))
{
string fun;
@ -1964,6 +1998,31 @@ string IRGeneratorForStatements::binaryOperation(
return {};
}
std::string IRGeneratorForStatements::shiftOperation(
langutil::Token _operator,
IRVariable const& _value,
IRVariable const& _amountToShift
)
{
IntegerType const* amountType = dynamic_cast<IntegerType const*>(&_amountToShift.type());
solAssert(amountType, "");
solAssert(_operator == Token::SHL || _operator == Token::SAR, "");
return
Whiskers(R"(
<shift>(<value>, <amount>)
)")
("shift",
_operator == Token::SHL ?
m_utils.typedShiftLeftFunction(_value.type(), *amountType) :
m_utils.typedShiftRightFunction(_value.type(), *amountType)
)
("value", _value.name())
("amount", _amountToShift.name())
.render();
}
void IRGeneratorForStatements::appendAndOrOperatorCode(BinaryOperation const& _binOp)
{
langutil::Token const op = _binOp.getOperator();

View File

@ -147,6 +147,11 @@ private:
std::string const& _right
);
/// @returns code to perform the given shift operation.
/// The operation itself will be performed in the type of the value,
/// while the amount to shift can have its own type.
std::string shiftOperation(langutil::Token _op, IRVariable const& _value, IRVariable const& _shiftAmount);
/// Assigns the value of @a _value to the lvalue @a _lvalue.
void writeToLValue(IRLValue const& _lvalue, IRVariable const& _value);
/// @returns a fresh IR variable containing the value of the lvalue @a _lvalue.

View File

@ -7,5 +7,7 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f() -> 0x0

View File

@ -5,5 +5,7 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f() -> 0x4200

View File

@ -5,5 +5,7 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f() -> 0x42

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,uint256): 0x4266, 0x0 -> 0x4266
// f(uint256,uint256): 0x4266, 0x8 -> 0x426600

View File

@ -5,6 +5,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,uint256): 0x4266, 0x0 -> 0x4266
// f(uint256,uint256): 0x4266, 0x8 -> 0x426600

View File

@ -5,6 +5,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,uint8): 0x4266, 0x0 -> 0x4266
// f(uint256,uint8): 0x4266, 0x8 -> 0x426600

View File

@ -6,6 +6,7 @@ contract C {
return y << x;
}
}
// ====
// compileViaYul: also
// ----
// f() -> 0

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint32,uint32): 0x4266, 0x0 -> 0x4266
// f(uint32,uint32): 0x4266, 0x8 -> 0x426600

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint8,uint8): 0x66, 0x0 -> 0x66
// f(uint8,uint8): 0x66, 0x8 -> 0

View File

@ -8,6 +8,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int256,int256): 1, -1 -> FAILURE
// g(int256,int256): 1, -1 -> FAILURE

View File

@ -10,6 +10,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int256,int256): 1, -1 -> FAILURE
// g(int256,int256): 1, -1 -> FAILURE

View File

@ -8,6 +8,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// leftU(uint8,uint8): 255, 8 -> 0
// leftU(uint8,uint8): 255, 1 -> 254

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,uint256): 0x4266, 0x0 -> 0x4266
// f(uint256,uint256): 0x4266, 0x8 -> 0x42

View File

@ -5,6 +5,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,uint256): 0x4266, 0x0 -> 0x4266
// f(uint256,uint256): 0x4266, 0x8 -> 0x42

View File

@ -5,6 +5,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int256,int256): 0x4266, 0x0 -> 0x4266
// f(int256,int256): 0x4266, 0x8 -> 0x42

View File

@ -18,6 +18,8 @@ contract C {
return a >> b;
}
}
// ====
// compileViaYul: also
// ----
// f(int8,uint8): 0x00, 0x03 -> 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe
// f(int8,uint8): 0x00, 0x04 -> 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff

View File

@ -10,6 +10,8 @@ contract C {
return a >> b;
}
}
// ====
// compileViaYul: also
// ----
// f(uint8,uint8): 0x00, 0x04 -> 0x0f
// f(uint8,uint8): 0x00, 0x1004 -> FAILURE

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int256,int256): -4266, 0 -> -4266
// f(int256,int256): -4266, 1 -> -2133

View File

@ -5,6 +5,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int256,int256): -4266, 0 -> -4266
// f(int256,int256): -4266, 1 -> -2133

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int16,int16): -4266, 0 -> -4266
// f(int16,int16): -4266, 1 -> -2133

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int32,int32): -4266, 0 -> -4266
// f(int32,int32): -4266, 1 -> -2133

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(int8,int8): -66, 0 -> -66
// f(int8,int8): -66, 1 -> -33

View File

@ -6,6 +6,8 @@ contract C {
return a >> b;
}
}
// ====
// compileViaYul: also
// ----
// f(int16,int16): 0xff99, 0x00 -> FAILURE
// f(int16,int16): 0xff99, 0x01 -> FAILURE

View File

@ -6,6 +6,8 @@ contract C {
return a >> b;
}
}
// ====
// compileViaYul: also
// ----
// f(int32,int32): 0xffffff99, 0x00 -> FAILURE
// f(int32,int32): 0xffffff99, 0x01 -> FAILURE

View File

@ -6,6 +6,8 @@ contract C {
return a >> b;
}
}
// ====
// compileViaYul: also
// ----
// f(int8,int8): 0x99, 0x00 -> FAILURE
// f(int8,int8): 0x99, 0x01 -> FAILURE

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint32,uint32): 0x4266, 0x0 -> 0x4266
// f(uint32,uint32): 0x4266, 0x8 -> 0x42

View File

@ -4,6 +4,8 @@ contract C {
}
}
// ====
// compileViaYul: also
// ----
// f(uint8,uint8): 0x66, 0x0 -> 0x66
// f(uint8,uint8): 0x66, 0x8 -> 0x0

View File

@ -0,0 +1,16 @@
contract C {
function f(uint256 a, int8 b) public returns (uint256) {
assembly { b := 0xff }
return a << b;
}
function g(uint256 a, int8 b) public returns (uint256) {
assembly { b := 0xff }
return a >> b;
}
}
// ====
// compileViaYul: also
// ----
// f(uint256,int8): 0x1234, 0x0 -> FAILURE
// g(uint256,int8): 0x1234, 0x0 -> FAILURE