Yul: Implement memory struct allocation

This commit is contained in:
Mathias Baumann 2020-06-24 18:14:29 +02:00
parent 3d602b3190
commit 50373ac1b0
11 changed files with 150 additions and 11 deletions

View File

@ -1646,13 +1646,30 @@ string YulUtilFunctions::allocateAndInitializeMemoryArrayFunction(ArrayType cons
}); });
} }
string YulUtilFunctions::allocateAndInitializeMemoryStructFunction(StructType const& _type) string YulUtilFunctions::allocateMemoryStructFunction(StructType const& _type)
{ {
string functionName = "allocate_and_initialize_memory_struct_" + _type.identifier(); string functionName = "allocate_memory_struct_" + _type.identifier();
return m_functionCollector.createFunction(functionName, [&]() { return m_functionCollector.createFunction(functionName, [&]() {
Whiskers templ(R"( Whiskers templ(R"(
function <functionName>() -> memPtr { function <functionName>() -> memPtr {
memPtr := <alloc>(<allocSize>) memPtr := <alloc>(<allocSize>)
}
)");
templ("functionName", functionName);
templ("alloc", allocationFunction());
templ("allocSize", _type.memoryDataSize().str());
return templ.render();
});
}
string YulUtilFunctions::allocateAndInitializeMemoryStructFunction(StructType const& _type)
{
string functionName = "allocate_and_zero_memory_struct_" + _type.identifier();
return m_functionCollector.createFunction(functionName, [&]() {
Whiskers templ(R"(
function <functionName>() -> memPtr {
memPtr := <allocStruct>()
let offset := memPtr let offset := memPtr
<#member> <#member>
mstore(offset, <zeroValue>()) mstore(offset, <zeroValue>())
@ -1661,10 +1678,9 @@ string YulUtilFunctions::allocateAndInitializeMemoryStructFunction(StructType co
} }
)"); )");
templ("functionName", functionName); templ("functionName", functionName);
templ("alloc", allocationFunction()); templ("allocStruct", allocateMemoryStructFunction(_type));
TypePointers const& members = _type.memoryMemberTypes(); TypePointers const& members = _type.memoryMemberTypes();
templ("allocSize", _type.memoryDataSize().str());
vector<map<string, string>> memberParams(members.size()); vector<map<string, string>> memberParams(members.size());
for (size_t i = 0; i < members.size(); ++i) for (size_t i = 0; i < members.size(); ++i)

View File

@ -286,8 +286,13 @@ public:
/// signature: (length) -> memPtr /// signature: (length) -> memPtr
std::string allocateAndInitializeMemoryArrayFunction(ArrayType const& _type); std::string allocateAndInitializeMemoryArrayFunction(ArrayType const& _type);
/// @returns the name of a function that allocates a memory struct (no
/// initialization takes place).
/// signature: () -> memPtr
std::string allocateMemoryStructFunction(StructType const& _type);
/// @returns the name of a function that allocates and zeroes a memory struct. /// @returns the name of a function that allocates and zeroes a memory struct.
/// signature: (members) -> memPtr /// signature: () -> memPtr
std::string allocateAndInitializeMemoryStructFunction(StructType const& _type); std::string allocateAndInitializeMemoryStructFunction(StructType const& _type);
/// @returns the name of the function that converts a value of type @a _from /// @returns the name of the function that converts a value of type @a _from

View File

@ -600,22 +600,30 @@ bool IRGeneratorForStatements::visit(FunctionCall const& _functionCall)
void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall) void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
{ {
solUnimplementedAssert( solUnimplementedAssert(
_functionCall.annotation().kind == FunctionCallKind::FunctionCall || _functionCall.annotation().kind != FunctionCallKind::Unset,
_functionCall.annotation().kind == FunctionCallKind::TypeConversion,
"This type of function call is not yet implemented" "This type of function call is not yet implemented"
); );
Type const& funcType = type(_functionCall.expression());
if (_functionCall.annotation().kind == FunctionCallKind::TypeConversion) if (_functionCall.annotation().kind == FunctionCallKind::TypeConversion)
{ {
solAssert(funcType.category() == Type::Category::TypeType, "Expected category to be TypeType"); solAssert(
_functionCall.expression().annotation().type->category() == Type::Category::TypeType,
"Expected category to be TypeType"
);
solAssert(_functionCall.arguments().size() == 1, "Expected one argument for type conversion"); solAssert(_functionCall.arguments().size() == 1, "Expected one argument for type conversion");
define(_functionCall, *_functionCall.arguments().front()); define(_functionCall, *_functionCall.arguments().front());
return; return;
} }
FunctionTypePointer functionType = dynamic_cast<FunctionType const*>(&funcType); FunctionTypePointer functionType = nullptr;
if (_functionCall.annotation().kind == FunctionCallKind::StructConstructorCall)
{
auto const& type = dynamic_cast<TypeType const&>(*_functionCall.expression().annotation().type);
auto const& structType = dynamic_cast<StructType const&>(*type.actualType());
functionType = structType.constructorType();
}
else
functionType = dynamic_cast<FunctionType const*>(_functionCall.expression().annotation().type);
TypePointers parameterTypes = functionType->parameterTypes(); TypePointers parameterTypes = functionType->parameterTypes();
vector<ASTPointer<Expression const>> const& callArguments = _functionCall.arguments(); vector<ASTPointer<Expression const>> const& callArguments = _functionCall.arguments();
@ -639,6 +647,34 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
arguments.push_back(callArguments[static_cast<size_t>(std::distance(callArgumentNames.begin(), it))]); arguments.push_back(callArguments[static_cast<size_t>(std::distance(callArgumentNames.begin(), it))]);
} }
if (_functionCall.annotation().kind == FunctionCallKind::StructConstructorCall)
{
TypeType const& type = dynamic_cast<TypeType const&>(*_functionCall.expression().annotation().type);
auto const& structType = dynamic_cast<StructType const&>(*type.actualType());
define(_functionCall) << m_utils.allocateMemoryStructFunction(structType) << "()\n";
MemberList::MemberMap members = structType.nativeMembers(nullptr);
solAssert(members.size() == arguments.size(), "Struct parameter mismatch.");
for (size_t i = 0; i < arguments.size(); i++)
{
IRVariable converted = convert(*arguments[i], *parameterTypes[i]);
m_code <<
m_utils.writeToMemoryFunction(*functionType->parameterTypes()[i]) <<
"(add(" <<
IRVariable(_functionCall).part("mpos").name() <<
", " <<
structType.memoryOffsetOfMember(members[i].name) <<
"), " <<
converted.commaSeparatedList() <<
")\n";
}
return;
}
auto memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()); auto memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression());
if (memberAccess) if (memberAccess)
{ {

View File

@ -6,5 +6,7 @@ contract C {
return (s.a, s.b); return (s.a, s.b);
} }
} }
// ====
// compileViaYul: also
// ---- // ----
// f((uint256,uint256)): 42, 23 -> 42, 23 // f((uint256,uint256)): 42, 23 -> 42, 23

View File

@ -0,0 +1,18 @@
pragma experimental ABIEncoderV2;
contract C {
struct S {
uint256 a;
bool x;
}
function s() public returns(S memory)
{
return S({x: true, a: 8});
}
}
// ====
// compileViaYul: also
// ----
// s() -> 8, true

View File

@ -28,5 +28,7 @@ contract Test {
} }
} }
// ====
// compileViaYul: also
// ---- // ----
// test() -> 1, 2, 3 // test() -> 1, 2, 3

View File

@ -38,5 +38,7 @@ contract Test {
} }
} }
// ====
// compileViaYul: also
// ---- // ----
// test() -> 1, 2, 3, 4 // test() -> 1, 2, 3, 4

View File

@ -0,0 +1,24 @@
contract C {
struct I {
uint b;
uint c;
function(uint) external returns (uint) x;
}
struct S {
I a;
}
function o(uint a) external returns(uint) { return a+1; }
function f() external returns (uint) {
S memory s = S(I(1,2, this.o));
return s.a.x(1);
}
}
// ====
// compileViaYul: also
// ----
// f() -> 2

View File

@ -0,0 +1,18 @@
contract C {
struct I {
uint b;
uint c;
}
struct S {
I a;
}
function f() external returns (uint) {
S memory s = S(I(1,2));
return s.a.b;
}
}
// ====
// compileViaYul: also
// ----
// f() -> 1

View File

@ -0,0 +1,14 @@
contract C {
struct S {
uint a;
}
function f() external returns (uint) {
S memory s = S(1);
return s.a;
}
}
// ====
// compileViaYul: also
// ----
// f() -> 1

View File

@ -27,6 +27,8 @@ contract test {
data.recursive[4].z = 9; data.recursive[4].z = 9;
} }
} }
// ====
// compileViaYul: also
// ---- // ----
// check() -> false // check() -> false
// set() -> // set() ->