Properly handle fixed-byte-like types.

This commit is contained in:
chriseth 2021-09-28 14:21:39 +02:00
parent 6109b5c3a1
commit da5c5928fe
4 changed files with 20 additions and 14 deletions

View File

@ -2537,6 +2537,7 @@ Type const& UserDefinedValueType::underlyingType() const
{ {
Type const* type = m_definition.underlyingType()->annotation().type; Type const* type = m_definition.underlyingType()->annotation().type;
solAssert(type, ""); solAssert(type, "");
solAssert(type->category() != Category::UserDefinedValueType, "");
return *type; return *type;
} }

View File

@ -1552,10 +1552,13 @@ void CompilerUtils::storeStringData(bytesConstRef _data)
unsigned CompilerUtils::loadFromMemoryHelper(Type const& _type, bool _fromCalldata, bool _padToWords) unsigned CompilerUtils::loadFromMemoryHelper(Type const& _type, bool _fromCalldata, bool _padToWords)
{ {
solAssert(_type.isValueType(), ""); solAssert(_type.isValueType(), "");
Type const* type = &_type;
if (auto const* userDefined = dynamic_cast<UserDefinedValueType const*>(type))
type = &userDefined->underlyingType();
unsigned numBytes = _type.calldataEncodedSize(_padToWords); unsigned numBytes = type->calldataEncodedSize(_padToWords);
bool isExternalFunctionType = false; bool isExternalFunctionType = false;
if (auto const* funType = dynamic_cast<FunctionType const*>(&_type)) if (auto const* funType = dynamic_cast<FunctionType const*>(type))
if (funType->kind() == FunctionType::Kind::External) if (funType->kind() == FunctionType::Kind::External)
isExternalFunctionType = true; isExternalFunctionType = true;
if (numBytes == 0) if (numBytes == 0)
@ -1570,21 +1573,20 @@ unsigned CompilerUtils::loadFromMemoryHelper(Type const& _type, bool _fromCallda
splitExternalFunctionType(true); splitExternalFunctionType(true);
else if (numBytes != 32) else if (numBytes != 32)
{ {
bool leftAligned = _type.category() == Type::Category::FixedBytes;
// add leading or trailing zeros by dividing/multiplying depending on alignment // add leading or trailing zeros by dividing/multiplying depending on alignment
unsigned shiftFactor = (32 - numBytes) * 8; unsigned shiftFactor = (32 - numBytes) * 8;
rightShiftNumberOnStack(shiftFactor); rightShiftNumberOnStack(shiftFactor);
if (leftAligned) if (type->leftAligned())
{ {
leftShiftNumberOnStack(shiftFactor); leftShiftNumberOnStack(shiftFactor);
cleanupNeeded = false; cleanupNeeded = false;
} }
else if (IntegerType const* intType = dynamic_cast<IntegerType const*>(&_type)) else if (IntegerType const* intType = dynamic_cast<IntegerType const*>(type))
if (!intType->isSigned()) if (!intType->isSigned())
cleanupNeeded = false; cleanupNeeded = false;
} }
if (_fromCalldata) if (_fromCalldata)
convertType(_type, _type, cleanupNeeded, false, true); convertType(_type, *type, cleanupNeeded, false, true);
return numBytes; return numBytes;
} }
@ -1639,12 +1641,10 @@ unsigned CompilerUtils::prepareMemoryStore(Type const& _type, bool _padToWords,
"Memory store of more than 32 bytes requested (Type: " + _type.toString(true) + ")." "Memory store of more than 32 bytes requested (Type: " + _type.toString(true) + ")."
); );
bool leftAligned = _type.category() == Type::Category::FixedBytes;
if (_cleanup) if (_cleanup)
convertType(_type, _type, true); convertType(_type, _type, true);
if (numBytes != 32 && !leftAligned && !_padToWords) if (numBytes != 32 && !_type.leftAligned() && !_padToWords)
// shift the value accordingly before storing // shift the value accordingly before storing
leftShiftNumberOnStack((32 - numBytes) * 8); leftShiftNumberOnStack((32 - numBytes) * 8);

View File

@ -113,6 +113,7 @@ void MemoryItem::storeValue(Type const& _sourceType, SourceLocation const&, bool
if (!m_padded) if (!m_padded)
{ {
solAssert(m_dataType->calldataEncodedSize(false) == 1, "Invalid non-padded type."); solAssert(m_dataType->calldataEncodedSize(false) == 1, "Invalid non-padded type.");
solAssert(m_dataType->category() != Type::Category::UserDefinedValueType, "");
if (m_dataType->category() == Type::Category::FixedBytes) if (m_dataType->category() == Type::Category::FixedBytes)
m_context << u256(0) << Instruction::BYTE; m_context << u256(0) << Instruction::BYTE;
m_context << Instruction::SWAP1 << Instruction::MSTORE8; m_context << Instruction::SWAP1 << Instruction::MSTORE8;
@ -233,7 +234,7 @@ void StorageItem::retrieveValue(SourceLocation const&, bool _remove) const
if (m_dataType->category() == Type::Category::FixedPoint) if (m_dataType->category() == Type::Category::FixedPoint)
// implementation should be very similar to the integer case. // implementation should be very similar to the integer case.
solUnimplemented("Not yet implemented - FixedPointType."); solUnimplemented("Not yet implemented - FixedPointType.");
if (m_dataType->category() == Type::Category::FixedBytes) if (m_dataType->leftAligned())
{ {
CompilerUtils(m_context).leftShiftNumberOnStack(256 - 8 * m_dataType->storageBytes()); CompilerUtils(m_context).leftShiftNumberOnStack(256 - 8 * m_dataType->storageBytes());
cleaned = true; cleaned = true;
@ -329,10 +330,13 @@ void StorageItem::storeValue(Type const& _sourceType, SourceLocation const& _loc
Instruction::AND; Instruction::AND;
} }
} }
else if (m_dataType->category() == Type::Category::FixedBytes) else if (m_dataType->leftAligned())
{ {
solAssert(_sourceType.category() == Type::Category::FixedBytes, "source not fixed bytes"); solAssert(_sourceType.category() == Type::Category::FixedBytes || (
CompilerUtils(m_context).rightShiftNumberOnStack(256 - 8 * dynamic_cast<FixedBytesType const&>(*m_dataType).numBytes()); _sourceType.encodingType() &&
_sourceType.encodingType()->category() == Type::Category::FixedBytes
), "source not fixed bytes");
CompilerUtils(m_context).rightShiftNumberOnStack(256 - 8 * m_dataType->storageBytes());
} }
else else
{ {

View File

@ -2965,7 +2965,7 @@ string YulUtilFunctions::prepareStoreFunction(Type const& _type)
} }
)"); )");
templ("functionName", functionName); templ("functionName", functionName);
if (_type.category() == Type::Category::FixedBytes) if (_type.leftAligned())
templ("actualPrepare", shiftRightFunction(256 - 8 * _type.storageBytes()) + "(value)"); templ("actualPrepare", shiftRightFunction(256 - 8 * _type.storageBytes()) + "(value)");
else else
templ("actualPrepare", "value"); templ("actualPrepare", "value");
@ -3304,6 +3304,7 @@ string YulUtilFunctions::conversionFunction(Type const& _from, Type const& _to)
bodyTemplate("cleanOutput", cleanupFunction(_to)); bodyTemplate("cleanOutput", cleanupFunction(_to));
string convert; string convert;
solAssert(_to.category() != Type::Category::UserDefinedValueType, "");
if (auto const* toFixedBytes = dynamic_cast<FixedBytesType const*>(&_to)) if (auto const* toFixedBytes = dynamic_cast<FixedBytesType const*>(&_to))
convert = shiftLeftFunction(256 - toFixedBytes->numBytes() * 8); convert = shiftLeftFunction(256 - toFixedBytes->numBytes() * 8);
else if (dynamic_cast<FixedPointType const*>(&_to)) else if (dynamic_cast<FixedPointType const*>(&_to))