[Sol->Yul] Implementing Byte array push() and pop()

This commit is contained in:
Djordje Mijovic 2020-05-05 19:53:17 +02:00
parent a8ca8f75ff
commit d235d0c166
9 changed files with 227 additions and 28 deletions

View File

@ -602,6 +602,25 @@ string YulUtilFunctions::overflowCheckedIntSubFunction(IntegerType const& _type)
});
}
string YulUtilFunctions::extractByteArrayLengthFunction()
{
string functionName = "extract_byte_array_length";
return m_functionCollector.createFunction(functionName, [&]() {
Whiskers w(R"(
function <functionName>(data) -> length {
// Retrieve length both for in-place strings and off-place strings:
// Computes (x & (0x100 * (ISZERO (x & 1)) - 1)) / 2
// i.e. for short strings (x & 1 == 0) it does (x & 0xff) / 2 and for long strings it
// computes (x & (-1)) / 2, which is equivalent to just x / 2.
let mask := sub(mul(0x100, iszero(and(data, 1))), 1)
length := div(and(data, mask), 2)
}
)");
w("functionName", functionName);
return w.render();
});
}
string YulUtilFunctions::arrayLengthFunction(ArrayType const& _type)
{
string functionName = "array_length_" + _type.identifier();
@ -615,12 +634,7 @@ string YulUtilFunctions::arrayLengthFunction(ArrayType const& _type)
<?storage>
length := sload(value)
<?byteArray>
// Retrieve length both for in-place strings and off-place strings:
// Computes (x & (0x100 * (ISZERO (x & 1)) - 1)) / 2
// i.e. for short strings (x & 1 == 0) it does (x & 0xff) / 2 and for long strings it
// computes (x & (-1)) / 2, which is equivalent to just x / 2.
let mask := sub(mul(0x100, iszero(and(length, 1))), 1)
length := div(and(length, mask), 2)
length := <extractByteArrayLength>(length)
</byteArray>
</storage>
<!dynamic>
@ -634,7 +648,12 @@ string YulUtilFunctions::arrayLengthFunction(ArrayType const& _type)
w("length", toCompactHexWithPrefix(_type.length()));
w("memory", _type.location() == DataLocation::Memory);
w("storage", _type.location() == DataLocation::Storage);
w("byteArray", _type.isByteArray());
if (_type.location() == DataLocation::Storage)
{
w("byteArray", _type.isByteArray());
if (_type.isByteArray())
w("extractByteArrayLength", extractByteArrayLengthFunction());
}
if (_type.isDynamicallySized())
solAssert(
_type.location() != DataLocation::CallData,
@ -689,8 +708,9 @@ string YulUtilFunctions::storageArrayPopFunction(ArrayType const& _type)
{
solAssert(_type.location() == DataLocation::Storage, "");
solAssert(_type.isDynamicallySized(), "");
solUnimplementedAssert(!_type.isByteArray(), "Byte Arrays not yet implemented!");
solUnimplementedAssert(_type.baseType()->storageBytes() <= 32, "Base type is not yet implemented.");
if (_type.isByteArray())
return storageByteArrayPopFunction(_type);
string functionName = "array_pop_" + _type.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
@ -699,10 +719,8 @@ string YulUtilFunctions::storageArrayPopFunction(ArrayType const& _type)
let oldLen := <fetchLength>(array)
if iszero(oldLen) { invalid() }
let newLen := sub(oldLen, 1)
let slot, offset := <indexAccess>(array, newLen)
<setToZero>(slot, offset)
sstore(array, newLen)
})")
("functionName", functionName)
@ -713,29 +731,115 @@ string YulUtilFunctions::storageArrayPopFunction(ArrayType const& _type)
});
}
string YulUtilFunctions::storageByteArrayPopFunction(ArrayType const& _type)
{
solAssert(_type.location() == DataLocation::Storage, "");
solAssert(_type.isDynamicallySized(), "");
solAssert(_type.isByteArray(), "");
string functionName = "byte_array_pop_" + _type.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
return Whiskers(R"(
function <functionName>(array) {
let data := sload(array)
let oldLen := <extractByteArrayLength>(data)
if iszero(oldLen) { invalid() }
switch eq(oldLen, 32)
case 1 {
// Here we have a special case where array transitions to shorter than 32
// So we need to copy data
let copyFromSlot := <dataAreaFunction>(array)
data := sload(copyFromSlot)
sstore(copyFromSlot, 0)
// New length is 31, encoded to 31 * 2 = 62
data := or(and(data, not(0xff)), 62)
}
default {
data := sub(data, 2)
let newLen := sub(oldLen, 1)
switch lt(oldLen, 32)
case 1 {
// set last element to zero
let mask := not(<shl>(mul(8, sub(31, newLen)), 0xff))
data := and(data, mask)
}
default {
let slot, offset := <indexAccess>(array, newLen)
<setToZero>(slot, offset)
}
}
sstore(array, data)
})")
("functionName", functionName)
("extractByteArrayLength", extractByteArrayLengthFunction())
("dataAreaFunction", arrayDataAreaFunction(_type))
("indexAccess", storageArrayIndexAccessFunction(_type))
("setToZero", storageSetToZeroFunction(*_type.baseType()))
("shl", shiftLeftFunctionDynamic())
.render();
});
}
string YulUtilFunctions::storageArrayPushFunction(ArrayType const& _type)
{
solAssert(_type.location() == DataLocation::Storage, "");
solAssert(_type.isDynamicallySized(), "");
solUnimplementedAssert(!_type.isByteArray(), "Byte Arrays not yet implemented!");
solUnimplementedAssert(_type.baseType()->storageBytes() <= 32, "Base type is not yet implemented.");
string functionName = "array_push_" + _type.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
return Whiskers(R"(
function <functionName>(array, value) {
let oldLen := <fetchLength>(array)
if iszero(lt(oldLen, <maxArrayLength>)) { invalid() }
sstore(array, add(oldLen, 1))
<?isByteArray>
let data := sload(array)
let oldLen := <extractByteArrayLength>(data)
if iszero(lt(oldLen, <maxArrayLength>)) { invalid() }
let slot, offset := <indexAccess>(array, oldLen)
<storeValue>(slot, offset, value)
switch gt(oldLen, 31)
case 0 {
value := byte(0, value)
switch oldLen
case 31 {
// Here we have special case when array switches from short array to long array
// We need to copy data
let dataArea := <dataAreaFunction>(array)
data := and(data, not(0xff))
sstore(dataArea, or(and(0xff, value), data))
// New length is 32, encoded as (32 * 2 + 1)
sstore(array, 65)
}
default {
data := add(data, 2)
let shiftBits := mul(8, sub(31, oldLen))
let valueShifted := <shl>(shiftBits, and(0xff, value))
let mask := <shl>(shiftBits, 0xff)
data := or(and(data, not(mask)), valueShifted)
sstore(array, data)
}
}
default {
sstore(array, add(data, 2))
let slot, offset := <indexAccess>(array, oldLen)
<storeValue>(slot, offset, value)
}
<!isByteArray>
let oldLen := sload(array)
if iszero(lt(oldLen, <maxArrayLength>)) { invalid() }
sstore(array, add(oldLen, 1))
let slot, offset := <indexAccess>(array, oldLen)
<storeValue>(slot, offset, value)
</isByteArray>
})")
("functionName", functionName)
("fetchLength", arrayLengthFunction(_type))
("extractByteArrayLength", _type.isByteArray() ? extractByteArrayLengthFunction() : "")
("dataAreaFunction", arrayDataAreaFunction(_type))
("isByteArray", _type.isByteArray())
("indexAccess", storageArrayIndexAccessFunction(_type))
("storeValue", updateStorageValueFunction(*_type.baseType()))
("maxArrayLength", (u256(1) << 64).str())
("shl", shiftLeftFunctionDynamic())
("shr", shiftRightFunction(248))
.render();
});
}
@ -947,21 +1051,33 @@ string YulUtilFunctions::arrayDataAreaFunction(ArrayType const& _type)
string YulUtilFunctions::storageArrayIndexAccessFunction(ArrayType const& _type)
{
solUnimplementedAssert(_type.baseType()->storageBytes() > 16, "");
string functionName = "storage_array_index_access_" + _type.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
return Whiskers(R"(
function <functionName>(array, index) -> slot, offset {
if iszero(lt(index, <arrayLen>(array))) {
invalid()
}
let arrayLength := <arrayLen>(array)
if iszero(lt(index, arrayLength)) { invalid() }
let data := <dataAreaFunc>(array)
<?multipleItemsPerSlot>
<?isBytesArray>
offset := sub(31, mod(index, 0x20))
switch lt(arrayLength, 0x20)
case 0 {
let dataArea := <dataAreaFunc>(array)
slot := add(dataArea, div(index, 0x20))
}
default {
slot := array
}
<!isBytesArray>
let itemsPerSlot := div(0x20, <storageBytes>)
let dataArea := <dataAreaFunc>(array)
slot := add(dataArea, div(index, itemsPerSlot))
offset := mod(index, itemsPerSlot)
</isBytesArray>
<!multipleItemsPerSlot>
slot := add(data, mul(index, <storageSize>))
let dataArea := <dataAreaFunc>(array)
slot := add(dataArea, mul(index, <storageSize>))
offset := 0
</multipleItemsPerSlot>
}
@ -970,7 +1086,9 @@ string YulUtilFunctions::storageArrayIndexAccessFunction(ArrayType const& _type)
("arrayLen", arrayLengthFunction(_type))
("dataAreaFunc", arrayDataAreaFunction(_type))
("multipleItemsPerSlot", _type.baseType()->storageBytes() <= 16)
("isBytesArray", _type.isByteArray())
("storageSize", _type.baseType()->storageSize().str())
("storageBytes", toString(_type.baseType()->storageBytes()))
.render();
});
}

View File

@ -364,6 +364,14 @@ private:
/// use exactly one variable to hold the value.
std::string conversionFunctionSpecial(Type const& _from, Type const& _to);
/// @returns function name that extracts and returns byte array length
/// signature: (data) -> length
std::string extractByteArrayLengthFunction();
/// @returns the name of a function that reduces the size of a storage byte array by one element
/// signature: (byteArray)
std::string storageByteArrayPopFunction(ArrayType const& _type);
std::string readFromMemoryOrCalldata(Type const& _type, bool _fromCalldata);
langutil::EVMVersion m_evmVersion;

View File

@ -963,10 +963,12 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
")\n";
break;
}
case FunctionType::Kind::ByteArrayPush:
case FunctionType::Kind::ArrayPush:
{
auto const& memberAccessExpression = dynamic_cast<MemberAccess const&>(_functionCall.expression()).expression();
ArrayType const& arrayType = dynamic_cast<ArrayType const&>(*memberAccessExpression.annotation().type);
if (arguments.empty())
{
auto slotName = m_context.newYulVariable();

View File

@ -12,6 +12,7 @@ contract c {
l = data.length;
}
}
// ====
// compileViaYul: also
// ----
// test() -> 2, 1, 1

View File

@ -9,5 +9,7 @@ contract c {
return true;
}
}
// ====
// compileViaYul: also
// ----
// test() -> FAILURE

View File

@ -13,6 +13,7 @@ contract c {
if (l != 0x03) return true;
}
}
// ====
// compileViaYul: also
// ----
// test() -> false

View File

@ -13,6 +13,7 @@ contract c {
return 0;
}
}
// ====
// compileViaYul: also
// ----
// test() -> 0

View File

@ -0,0 +1,45 @@
contract c {
bytes data;
function test_short() public returns (uint256 r) {
assembly {
sstore(data_slot, 0)
}
for (uint8 i = 0; i < 15; i++) {
data.push(bytes1(i));
}
assembly {
r := sload(data_slot)
}
}
function test_long() public returns (uint256 r) {
assembly {
sstore(data_slot, 0)
}
for (uint8 i = 0; i < 33; i++) {
data.push(bytes1(i));
}
assembly {
r := sload(data_slot)
}
}
function test_pop() public returns (uint256 r) {
assembly {
sstore(data_slot, 0)
}
for (uint8 i = 0; i < 32; i++) {
data.push(bytes1(i));
}
data.pop();
assembly {
r := sload(data_slot)
}
}
}
// ====
// compileViaYul: also
// ----
// test_short() -> 1780731860627700044960722568376587075150542249149356309979516913770823710
// test_long() -> 67
// test_pop() -> 1780731860627700044960722568376592200742329637303199754547598369979440702

View File

@ -0,0 +1,21 @@
// Tests transition between short and long encoding both ways
contract c {
bytes data;
function test() public returns (uint256) {
for (uint8 i = 0; i < 33; i++) {
data.push(bytes1(i));
}
for (uint8 i = 0; i < data.length; i++)
if (data[i] != bytes1(i)) return i;
data.pop();
data.pop();
for (uint8 i = 0; i < data.length; i++)
if (data[i] != bytes1(i)) return i;
return 0;
}
}
// ====
// compileViaYul: also
// ----
// test() -> 0