[SMTChecker] Implement short circuit

This commit is contained in:
Leonardo Alt 2019-03-22 17:05:58 +01:00
parent a1d59dfb4c
commit a7e826a224
3 changed files with 38 additions and 15 deletions

View File

@ -200,12 +200,12 @@ bool SMTChecker::visit(IfStatement const& _node)
if (isRootFunction()) if (isRootFunction())
checkBooleanNotConstant(_node.condition(), "Condition is always $VALUE."); checkBooleanNotConstant(_node.condition(), "Condition is always $VALUE.");
auto indicesEndTrue = visitBranch(_node.trueStatement(), expr(_node.condition())); auto indicesEndTrue = visitBranch(&_node.trueStatement(), expr(_node.condition()));
vector<VariableDeclaration const*> touchedVariables = m_variableUsage->touchedVariables(_node.trueStatement()); vector<VariableDeclaration const*> touchedVariables = m_variableUsage->touchedVariables(_node.trueStatement());
decltype(indicesEndTrue) indicesEndFalse; decltype(indicesEndTrue) indicesEndFalse;
if (_node.falseStatement()) if (_node.falseStatement())
{ {
indicesEndFalse = visitBranch(*_node.falseStatement(), !expr(_node.condition())); indicesEndFalse = visitBranch(_node.falseStatement(), !expr(_node.condition()));
touchedVariables += m_variableUsage->touchedVariables(*_node.falseStatement()); touchedVariables += m_variableUsage->touchedVariables(*_node.falseStatement());
} }
else else
@ -233,7 +233,7 @@ bool SMTChecker::visit(WhileStatement const& _node)
decltype(indicesBeforeLoop) indicesAfterLoop; decltype(indicesBeforeLoop) indicesAfterLoop;
if (_node.isDoWhile()) if (_node.isDoWhile())
{ {
indicesAfterLoop = visitBranch(_node.body()); indicesAfterLoop = visitBranch(&_node.body());
// TODO the assertions generated in the body should still be active in the condition // TODO the assertions generated in the body should still be active in the condition
_node.condition().accept(*this); _node.condition().accept(*this);
if (isRootFunction()) if (isRootFunction())
@ -245,7 +245,7 @@ bool SMTChecker::visit(WhileStatement const& _node)
if (isRootFunction()) if (isRootFunction())
checkBooleanNotConstant(_node.condition(), "While loop condition is always $VALUE."); checkBooleanNotConstant(_node.condition(), "While loop condition is always $VALUE.");
indicesAfterLoop = visitBranch(_node.body(), expr(_node.condition())); indicesAfterLoop = visitBranch(&_node.body(), expr(_node.condition()));
} }
// We reset the execution to before the loop // We reset the execution to before the loop
@ -502,20 +502,27 @@ bool SMTChecker::visit(UnaryOperation const& _op)
bool SMTChecker::visit(BinaryOperation const& _op) bool SMTChecker::visit(BinaryOperation const& _op)
{ {
return !shortcutRationalNumber(_op); if (shortcutRationalNumber(_op))
return false;
if (TokenTraits::isBooleanOp(_op.getOperator()))
{
booleanOperation(_op);
return false;
}
return true;
} }
void SMTChecker::endVisit(BinaryOperation const& _op) void SMTChecker::endVisit(BinaryOperation const& _op)
{ {
if (_op.annotation().type->category() == Type::Category::RationalNumber) if (_op.annotation().type->category() == Type::Category::RationalNumber)
return; return;
if (TokenTraits::isBooleanOp(_op.getOperator()))
return;
if (TokenTraits::isArithmeticOp(_op.getOperator())) if (TokenTraits::isArithmeticOp(_op.getOperator()))
arithmeticOperation(_op); arithmeticOperation(_op);
else if (TokenTraits::isCompareOp(_op.getOperator())) else if (TokenTraits::isCompareOp(_op.getOperator()))
compareOperation(_op); compareOperation(_op);
else if (TokenTraits::isBooleanOp(_op.getOperator()))
booleanOperation(_op);
else else
m_errorReporter.warning( m_errorReporter.warning(
_op.location(), _op.location(),
@ -1095,11 +1102,21 @@ void SMTChecker::booleanOperation(BinaryOperation const& _op)
if (_op.annotation().commonType->category() == Type::Category::Bool) if (_op.annotation().commonType->category() == Type::Category::Bool)
{ {
// @TODO check that both of them are not constant // @TODO check that both of them are not constant
_op.leftExpression().accept(*this);
auto touchedVariables = m_variableUsage->touchedVariables(_op.leftExpression());
if (_op.getOperator() == Token::And) if (_op.getOperator() == Token::And)
{
auto indicesAfterSecond = visitBranch(&_op.rightExpression(), expr(_op.leftExpression()));
mergeVariables(touchedVariables, !expr(_op.leftExpression()), copyVariableIndices(), indicesAfterSecond);
defineExpr(_op, expr(_op.leftExpression()) && expr(_op.rightExpression())); defineExpr(_op, expr(_op.leftExpression()) && expr(_op.rightExpression()));
}
else else
{
auto indicesAfterSecond = visitBranch(&_op.rightExpression(), !expr(_op.leftExpression()));
mergeVariables(touchedVariables, expr(_op.leftExpression()), copyVariableIndices(), indicesAfterSecond);
defineExpr(_op, expr(_op.leftExpression()) || expr(_op.rightExpression())); defineExpr(_op, expr(_op.leftExpression()) || expr(_op.rightExpression()));
} }
}
else else
m_errorReporter.warning( m_errorReporter.warning(
_op.location(), _op.location(),
@ -1137,17 +1154,17 @@ void SMTChecker::assignment(VariableDeclaration const& _variable, smt::Expressio
m_interface->addAssertion(newValue(_variable) == _value); m_interface->addAssertion(newValue(_variable) == _value);
} }
SMTChecker::VariableIndices SMTChecker::visitBranch(Statement const& _statement, smt::Expression _condition) SMTChecker::VariableIndices SMTChecker::visitBranch(ASTNode const* _statement, smt::Expression _condition)
{ {
return visitBranch(_statement, &_condition); return visitBranch(_statement, &_condition);
} }
SMTChecker::VariableIndices SMTChecker::visitBranch(Statement const& _statement, smt::Expression const* _condition) SMTChecker::VariableIndices SMTChecker::visitBranch(ASTNode const* _statement, smt::Expression const* _condition)
{ {
auto indicesBeforeBranch = copyVariableIndices(); auto indicesBeforeBranch = copyVariableIndices();
if (_condition) if (_condition)
pushPathCondition(*_condition); pushPathCondition(*_condition);
_statement.accept(*this); _statement->accept(*this);
if (_condition) if (_condition)
popPathCondition(); popPathCondition();
auto indicesAfterBranch = copyVariableIndices(); auto indicesAfterBranch = copyVariableIndices();
@ -1456,7 +1473,12 @@ TypePointer SMTChecker::typeWithoutPointer(TypePointer const& _type)
void SMTChecker::mergeVariables(vector<VariableDeclaration const*> const& _variables, smt::Expression const& _condition, VariableIndices const& _indicesEndTrue, VariableIndices const& _indicesEndFalse) void SMTChecker::mergeVariables(vector<VariableDeclaration const*> const& _variables, smt::Expression const& _condition, VariableIndices const& _indicesEndTrue, VariableIndices const& _indicesEndFalse)
{ {
set<VariableDeclaration const*> uniqueVars(_variables.begin(), _variables.end()); set<VariableDeclaration const*> uniqueVars(_variables.begin(), _variables.end());
for (auto const* decl: uniqueVars) mergeVariables(uniqueVars, _condition, _indicesEndTrue, _indicesEndFalse);
}
void SMTChecker::mergeVariables(set<VariableDeclaration const*> const& _variables, smt::Expression const& _condition, VariableIndices const& _indicesEndTrue, VariableIndices const& _indicesEndFalse)
{
for (auto const* decl: _variables)
{ {
solAssert(_indicesEndTrue.count(decl) && _indicesEndFalse.count(decl), ""); solAssert(_indicesEndTrue.count(decl) && _indicesEndFalse.count(decl), "");
int trueIndex = _indicesEndTrue.at(decl); int trueIndex = _indicesEndTrue.at(decl);

View File

@ -132,8 +132,8 @@ private:
/// Visits the branch given by the statement, pushes and pops the current path conditions. /// Visits the branch given by the statement, pushes and pops the current path conditions.
/// @param _condition if present, asserts that this condition is true within the branch. /// @param _condition if present, asserts that this condition is true within the branch.
/// @returns the variable indices after visiting the branch. /// @returns the variable indices after visiting the branch.
VariableIndices visitBranch(Statement const& _statement, smt::Expression const* _condition = nullptr); VariableIndices visitBranch(ASTNode const* _statement, smt::Expression const* _condition = nullptr);
VariableIndices visitBranch(Statement const& _statement, smt::Expression _condition); VariableIndices visitBranch(ASTNode const* _statement, smt::Expression _condition);
/// Check that a condition can be satisfied. /// Check that a condition can be satisfied.
void checkCondition( void checkCondition(
@ -199,6 +199,7 @@ private:
/// merge the touched variables into after-branch ite variables /// merge the touched variables into after-branch ite variables
/// using the branch condition as guard. /// using the branch condition as guard.
void mergeVariables(std::vector<VariableDeclaration const*> const& _variables, smt::Expression const& _condition, VariableIndices const& _indicesEndTrue, VariableIndices const& _indicesEndFalse); void mergeVariables(std::vector<VariableDeclaration const*> const& _variables, smt::Expression const& _condition, VariableIndices const& _indicesEndTrue, VariableIndices const& _indicesEndFalse);
void mergeVariables(std::set<VariableDeclaration const*> const& _variables, smt::Expression const& _condition, VariableIndices const& _indicesEndTrue, VariableIndices const& _indicesEndFalse);
/// Tries to create an uninitialized variable and returns true on success. /// Tries to create an uninitialized variable and returns true on success.
/// This fails if the type is not supported. /// This fails if the type is not supported.
bool createVariable(VariableDeclaration const& _varDecl); bool createVariable(VariableDeclaration const& _varDecl);

View File

@ -8,7 +8,7 @@ contract c {
} }
function g() public { function g() public {
x = 0; x = 0;
assert((f() > 0) || (f() > 0)); bool b = (f() > 0) || (f() > 0);
// This assertion should NOT fail. // This assertion should NOT fail.
// It currently does because the SMTChecker does not // It currently does because the SMTChecker does not
// handle short-circuiting properly and inlines f() twice. // handle short-circuiting properly and inlines f() twice.