[Sol -> Yul] Fix ForLoops and implement WhileLoops

This commit is contained in:
Mathias Baumann 2019-05-08 13:27:36 +02:00
parent f61348728c
commit 0abe00d393
6 changed files with 196 additions and 38 deletions

View File

@ -152,27 +152,28 @@ bool IRGeneratorForStatements::visit(TupleExpression const& _tuple)
return false;
}
bool IRGeneratorForStatements::visit(ForStatement const& _for)
bool IRGeneratorForStatements::visit(ForStatement const& _forStatement)
{
m_code << "for {\n";
if (_for.initializationExpression())
_for.initializationExpression()->accept(*this);
m_code << "} return_flag {\n";
if (_for.loopExpression())
_for.loopExpression()->accept(*this);
m_code << "}\n";
if (_for.condition())
{
_for.condition()->accept(*this);
m_code <<
"if iszero(" <<
expressionAsType(*_for.condition(), *TypeProvider::boolean()) <<
") { break }\n";
}
_for.body().accept(*this);
m_code << "}\n";
// Bubble up the return condition.
m_code << "if iszero(return_flag) { break }\n";
generateLoop(
_forStatement.body(),
_forStatement.condition(),
_forStatement.initializationExpression(),
_forStatement.loopExpression()
);
return false;
}
bool IRGeneratorForStatements::visit(WhileStatement const& _whileStatement)
{
generateLoop(
_whileStatement.body(),
&_whileStatement.condition(),
nullptr,
nullptr,
_whileStatement.isDoWhile()
);
return false;
}
@ -796,6 +797,54 @@ void IRGeneratorForStatements::setLValue(Expression const& _expression, unique_p
defineExpression(_expression) << _lvalue->retrieveValue() << "\n";
}
void IRGeneratorForStatements::generateLoop(
Statement const& _body,
Expression const* _conditionExpression,
Statement const* _initExpression,
ExpressionStatement const* _loopExpression,
bool _isDoWhile
)
{
string firstRun;
if (_isDoWhile)
{
solAssert(_conditionExpression, "Expected condition for doWhile");
firstRun = m_context.newYulVariable();
m_code << "let " << firstRun << " := 1\n";
}
m_code << "for {\n";
if (_initExpression)
_initExpression->accept(*this);
m_code << "} return_flag {\n";
if (_loopExpression)
_loopExpression->accept(*this);
m_code << "}\n";
m_code << "{\n";
if (_conditionExpression)
{
if (_isDoWhile)
m_code << "if iszero(" << firstRun << ") {\n";
_conditionExpression->accept(*this);
m_code <<
"if iszero(" <<
expressionAsType(*_conditionExpression, *TypeProvider::boolean()) <<
") { break }\n";
if (_isDoWhile)
m_code << "}\n" << firstRun << " := 0\n";
}
_body.accept(*this);
m_code << "}\n";
// Bubble up the return condition.
m_code << "if iszero(return_flag) { break }\n";
}
Type const& IRGeneratorForStatements::type(Expression const& _expression)
{
solAssert(_expression.annotation().type, "Type of expression not set.");

View File

@ -49,6 +49,7 @@ public:
bool visit(Assignment const& _assignment) override;
bool visit(TupleExpression const& _tuple) override;
bool visit(ForStatement const& _forStatement) override;
bool visit(WhileStatement const& _whileStatement) override;
bool visit(Continue const& _continueStatement) override;
bool visit(Break const& _breakStatement) override;
void endVisit(Return const& _return) override;
@ -70,6 +71,13 @@ private:
void appendAndOrOperatorCode(BinaryOperation const& _binOp);
void setLValue(Expression const& _expression, std::unique_ptr<IRLValue> _lvalue);
void generateLoop(
Statement const& _body,
Expression const* _conditionExpression,
Statement const* _initExpression = nullptr,
ExpressionStatement const* _loopExpression = nullptr,
bool _isDoWhile = false
);
static Type const& type(Expression const& _expression);

View File

@ -0,0 +1,31 @@
contract C {
function f() public returns (uint x) {
x = 1;
for (uint a = 0; a < 10; a = a + 1) {
x = x + x;
break;
}
}
function g() public returns (uint x) {
x = 1;
uint a = 0;
while (a < 10) {
x = x + x;
break;
a = a + 1;
}
}
function h() public returns (uint x) {
x = 1;
do {
x = x + 1;
break;
} while (x < 3);
}
}
// ====
// compileViaYul: true
// ----
// f() -> 2
// g() -> 2
// h() -> 2

View File

@ -0,0 +1,37 @@
contract C {
function f() public returns (uint x) {
x = 1;
uint a = 0;
for (; a < 10; a = a + 1) {
continue;
x = x + x;
}
x = x + a;
}
function g() public returns (uint x) {
x = 1;
uint a = 0;
while (a < 10) {
a = a + 1;
continue;
x = x + x;
}
x = x + a;
}
function h() public returns (uint x) {
x = 1;
uint a = 0;
do {
a = a + 1;
continue;
x = x + x;
} while (a < 4);
x = x + a;
}
}
// ====
// compileViaYul: true
// ----
// f() -> 11
// g() -> 11
// h() -> 5

View File

@ -0,0 +1,34 @@
contract C {
function f() public returns (uint x) {
x = 1;
uint a;
for (; a < 10; a = a + 1) {
return x;
x = x + x;
}
x = x + a;
}
function g() public returns (uint x) {
x = 1;
uint a;
while (a < 10) {
return x;
x = x + x;
a = a + 1;
}
x = x + a;
}
function h() public returns (uint x) {
x = 1;
do {
x = x + 1;
return x;
} while (x < 3);
}
}
// ====
// compileViaYul: true
// ----
// f() -> 1
// g() -> 1
// h() -> 2

View File

@ -7,34 +7,33 @@ contract C {
}
function g() public returns (uint x) {
x = 1;
for (uint a = 0; a < 10; a = a + 1) {
uint a = 0;
while (a < 10) {
x = x + x;
break;
a = a + 1;
}
}
function h() public returns (uint x) {
x = 1;
uint a = 0;
for (; a < 10; a = a + 1) {
continue;
x = x + x;
}
x = x + a;
do {
x = x + 1;
} while (false);
}
function i() public returns (uint x) {
x = 1;
uint a;
for (; a < 10; a = a + 1) {
return x;
x = x + x;
}
x = x + a;
do {
x = x + 1;
} while (x < 3);
}
function j() public {
for (;;) {break;}
}
}
// ===
// ====
// compileViaYul: true
// ----
// f() -> 1024
// g() -> 2
// h() -> 11
// i() -> 1
// g() -> 1024
// h() -> 2
// i() -> 3
// j() ->