diff --git a/libsolidity/formal/SMTChecker.cpp b/libsolidity/formal/SMTChecker.cpp index 129822049..e1fd2bfdf 100644 --- a/libsolidity/formal/SMTChecker.cpp +++ b/libsolidity/formal/SMTChecker.cpp @@ -25,6 +25,8 @@ #include +#include + using namespace std; using namespace dev; using namespace dev::solidity; @@ -51,6 +53,7 @@ void SMTChecker::analyze(SourceUnit const& _source) { m_interface->reset(); m_currentSequenceCounter.clear(); + m_nextFreeSequenceCounter.clear(); _source.accept(*this); } } @@ -89,10 +92,68 @@ void SMTChecker::endVisit(FunctionDefinition const&) // We only handle local variables, so we clear everything. // If we add storage variables, those should be cleared differently. m_currentSequenceCounter.clear(); + m_nextFreeSequenceCounter.clear(); m_interface->pop(); m_currentFunction = nullptr; } +bool SMTChecker::visit(IfStatement const& _node) +{ + _node.condition().accept(*this); + + // TODO Check if condition is always true + + auto countersAtStart = m_currentSequenceCounter; + m_interface->push(); + m_interface->addAssertion(expr(_node.condition())); + _node.trueStatement().accept(*this); + auto countersAtEndOfTrue = m_currentSequenceCounter; + m_interface->pop(); + + decltype(m_currentSequenceCounter) countersAtEndOfFalse; + if (_node.falseStatement()) + { + m_currentSequenceCounter = countersAtStart; + m_interface->push(); + m_interface->addAssertion(!expr(_node.condition())); + _node.falseStatement()->accept(*this); + countersAtEndOfFalse = m_currentSequenceCounter; + m_interface->pop(); + } + else + countersAtEndOfFalse = countersAtStart; + + // Reset all values that have been touched. + + // TODO this should use a previously generated side-effect structure + + solAssert(countersAtEndOfFalse.size() == countersAtEndOfTrue.size(), ""); + for (auto const& declCounter: countersAtEndOfTrue) + { + solAssert(countersAtEndOfFalse.count(declCounter.first), ""); + auto decl = declCounter.first; + int trueCounter = countersAtEndOfTrue.at(decl); + int falseCounter = countersAtEndOfFalse.at(decl); + if (trueCounter == falseCounter) + continue; // Was not modified + newValue(*decl); + setValue(*decl, 0); + } + return false; +} + +bool SMTChecker::visit(WhileStatement const& _node) +{ + _node.condition().accept(*this); + + //m_interface->push(); + //m_interface->addAssertion(expr(_node.condition())); + // TDOO clear knowledge (increment sequence numbers and add bounds assertions ) apart from assertions + + // TODO combine similar to if + return true; +} + void SMTChecker::endVisit(VariableDeclarationStatement const& _varDecl) { if (_varDecl.declarations().size() != 1) @@ -100,10 +161,13 @@ void SMTChecker::endVisit(VariableDeclarationStatement const& _varDecl) _varDecl.location(), "Assertion checker does not yet support such variable declarations." ); - else if (knownVariable(*_varDecl.declarations()[0]) && _varDecl.initialValue()) - // TODO more checks? - // TODO add restrictions about type (might be assignment from smaller type) - m_interface->addAssertion(newValue(*_varDecl.declarations()[0]) == expr(*_varDecl.initialValue())); + else if (knownVariable(*_varDecl.declarations()[0])) + { + if (_varDecl.initialValue()) + // TODO more checks? + // TODO add restrictions about type (might be assignment from smaller type) + m_interface->addAssertion(newValue(*_varDecl.declarations()[0]) == expr(*_varDecl.initialValue())); + } else m_errorReporter.warning( _varDecl.location(), @@ -421,19 +485,15 @@ void SMTChecker::checkCondition( void SMTChecker::createVariable(VariableDeclaration const& _varDecl, bool _setToZero) { - if (auto intType = dynamic_cast(_varDecl.type().get())) + if (dynamic_cast(_varDecl.type().get())) { solAssert(m_currentSequenceCounter.count(&_varDecl) == 0, ""); + solAssert(m_nextFreeSequenceCounter.count(&_varDecl) == 0, ""); solAssert(m_z3Variables.count(&_varDecl) == 0, ""); m_currentSequenceCounter[&_varDecl] = 0; + m_nextFreeSequenceCounter[&_varDecl] = 1; m_z3Variables.emplace(&_varDecl, m_interface->newFunction(uniqueSymbol(_varDecl), smt::Sort::Int, smt::Sort::Int)); - if (_setToZero) - m_interface->addAssertion(currentValue(_varDecl) == 0); - else - { - m_interface->addAssertion(currentValue(_varDecl) >= minValue(*intType)); - m_interface->addAssertion(currentValue(_varDecl) <= maxValue(*intType)); - } + setValue(_varDecl, _setToZero); } else m_errorReporter.warning( @@ -460,16 +520,35 @@ bool SMTChecker::knownVariable(Declaration const& _decl) smt::Expression SMTChecker::currentValue(Declaration const& _decl) { solAssert(m_currentSequenceCounter.count(&_decl), ""); - return var(_decl)(m_currentSequenceCounter.at(&_decl)); + return valueAtSequence(_decl, m_currentSequenceCounter.at(&_decl)); } -smt::Expression SMTChecker::newValue(const Declaration& _decl) +smt::Expression SMTChecker::valueAtSequence(const Declaration& _decl, int _sequence) +{ + return var(_decl)(_sequence); +} + +smt::Expression SMTChecker::newValue(Declaration const& _decl) { solAssert(m_currentSequenceCounter.count(&_decl), ""); - m_currentSequenceCounter[&_decl]++; + solAssert(m_nextFreeSequenceCounter.count(&_decl), ""); + m_currentSequenceCounter[&_decl] = m_nextFreeSequenceCounter[&_decl]++; return currentValue(_decl); } +void SMTChecker::setValue(Declaration const& _decl, bool _setToZero) +{ + auto const& intType = dynamic_cast(*_decl.type()); + + if (_setToZero) + m_interface->addAssertion(currentValue(_decl) == 0); + else + { + m_interface->addAssertion(currentValue(_decl) >= minValue(intType)); + m_interface->addAssertion(currentValue(_decl) <= maxValue(intType)); + } +} + smt::Expression SMTChecker::minValue(IntegerType const& _t) { return smt::Expression(_t.minValue()); diff --git a/libsolidity/formal/SMTChecker.h b/libsolidity/formal/SMTChecker.h index f0968cc78..d49351164 100644 --- a/libsolidity/formal/SMTChecker.h +++ b/libsolidity/formal/SMTChecker.h @@ -46,6 +46,8 @@ private: virtual void endVisit(VariableDeclaration const& _node) override; virtual bool visit(FunctionDefinition const& _node) override; virtual void endVisit(FunctionDefinition const& _node) override; + virtual bool visit(IfStatement const& _node) override; + virtual bool visit(WhileStatement const& _node) override; virtual void endVisit(VariableDeclarationStatement const& _node) override; virtual void endVisit(ExpressionStatement const& _node) override; virtual void endVisit(Assignment const& _node) override; @@ -71,10 +73,23 @@ private: std::string uniqueSymbol(Declaration const& _decl); std::string uniqueSymbol(Expression const& _expr); + + /// @returns true if _delc is a variable that is known at the current point, i.e. + /// has a valid sequence number bool knownVariable(Declaration const& _decl); + /// @returns an expression denoting the value of the variable declared in @a _decl + /// at the current point. smt::Expression currentValue(Declaration const& _decl); + /// @returns an expression denoting the value of the variable declared in @a _decl + /// at the given sequence point. Does not ensure that this sequence point exists. + smt::Expression valueAtSequence(Declaration const& _decl, int _sequence); + /// Allocates a new sequence number for the declaration, updates the current + /// sequence number to this value and returns the expression. smt::Expression newValue(Declaration const& _decl); + /// Sets the value of the declaration either to zero or to its intrinsic range. + void setValue(Declaration const& _decl, bool _setToZero); + smt::Expression minValue(IntegerType const& _t); smt::Expression maxValue(IntegerType const& _t); @@ -87,6 +102,7 @@ private: std::shared_ptr m_interface; std::map m_currentSequenceCounter; + std::map m_nextFreeSequenceCounter; std::map m_z3Expressions; std::map m_z3Variables; ErrorReporter& m_errorReporter; diff --git a/libsolidity/formal/SolverInterface.h b/libsolidity/formal/SolverInterface.h index 2c00d030b..8423c4a70 100644 --- a/libsolidity/formal/SolverInterface.h +++ b/libsolidity/formal/SolverInterface.h @@ -61,6 +61,13 @@ public: Expression& operator=(Expression const& _other) = default; Expression& operator=(Expression&& _other) = default; + static Expression ite(Expression _condition, Expression _trueValue, Expression _falseValue) + { + return Expression("ite", std::vector{ + std::move(_condition), std::move(_trueValue), std::move(_falseValue) + }); + } + friend Expression operator!(Expression _a) { return Expression("not", std::move(_a)); diff --git a/libsolidity/formal/Z3Interface.cpp b/libsolidity/formal/Z3Interface.cpp index 0d59a3c77..bb0d6f6ff 100644 --- a/libsolidity/formal/Z3Interface.cpp +++ b/libsolidity/formal/Z3Interface.cpp @@ -72,21 +72,28 @@ void Z3Interface::addAssertion(Expression const& _expr) pair> Z3Interface::check(vector const& _expressionsToEvaluate) { +// cout << "---------------------------------" << endl; +// cout << m_solver << endl; CheckResult result; switch (m_solver.check()) { case z3::check_result::sat: result = CheckResult::SAT; + cout << "sat" << endl; break; case z3::check_result::unsat: result = CheckResult::UNSAT; + cout << "unsat" << endl; break; case z3::check_result::unknown: result = CheckResult::UNKNOWN; + cout << "unknown" << endl; break; default: solAssert(false, ""); } +// cout << "---------------------------------" << endl; + vector values; if (result != CheckResult::UNSAT) @@ -107,6 +114,7 @@ z3::expr Z3Interface::toZ3Expr(Expression const& _expr) arguments.push_back(toZ3Expr(arg)); static map arity{ + {"ite", 3}, {"not", 1}, {"and", 2}, {"or", 2}, @@ -135,7 +143,9 @@ z3::expr Z3Interface::toZ3Expr(Expression const& _expr) } assert(arity.count(n) && arity.at(n) == arguments.size()); - if (n == "not") + if (n == "ite") + return z3::ite(arguments[0], arguments[1], arguments[2]); + else if (n == "not") return !arguments[0]; else if (n == "and") return arguments[0] && arguments[1];