diff --git a/libsolidity/formal/SMTCheckerImpl.cpp b/libsolidity/formal/SMTCheckerImpl.cpp index 0a9ad66e7..028e4400a 100644 --- a/libsolidity/formal/SMTCheckerImpl.cpp +++ b/libsolidity/formal/SMTCheckerImpl.cpp @@ -25,7 +25,6 @@ using namespace dev; using namespace dev::solidity; SMTCheckerImpl::SMTCheckerImpl(ErrorReporter& _errorReporter): - m_solver(m_context), m_errorReporter(_errorReporter) { } @@ -39,7 +38,7 @@ void SMTCheckerImpl::analyze(SourceUnit const& _source) pragmaFound = true; if (pragmaFound) { - m_solver.reset(); + m_interface.reset(); m_currentSequenceCounter.clear(); _source.accept(*this); } @@ -69,14 +68,14 @@ bool SMTCheckerImpl::visit(FunctionDefinition const& _function) ); // TODO actually we probably also have to reset all local variables and similar things. m_currentFunction = &_function; - m_solver.push(); + m_interface.push(); return true; } void SMTCheckerImpl::endVisit(FunctionDefinition const&) { // TOOD we could check for "reachability", i.e. satisfiability here. - m_solver.pop(); + m_interface.pop(); m_currentFunction = nullptr; } @@ -90,7 +89,7 @@ void SMTCheckerImpl::endVisit(VariableDeclarationStatement const& _varDecl) else if (knownVariable(*_varDecl.declarations()[0]) && _varDecl.initialValue()) // TODO more checks? // TODO add restrictions about type (might be assignment from smaller type) - m_solver.add(newValue(*_varDecl.declarations()[0]) == expr(*_varDecl.initialValue())); + m_interface.addAssertion(newValue(*_varDecl.declarations()[0]) == expr(*_varDecl.initialValue())); else m_errorReporter.warning( _varDecl.location(), @@ -120,7 +119,7 @@ void SMTCheckerImpl::endVisit(Assignment const& _assignment) if (knownVariable(*decl)) // TODO more checks? // TODO add restrictions about type (might be assignment from smaller type) - m_solver.add(newValue(*decl) == expr(_assignment.rightHandSide())); + m_interface.addAssertion(newValue(*decl) == expr(_assignment.rightHandSide())); else m_errorReporter.warning( _assignment.location(), @@ -142,7 +141,7 @@ void SMTCheckerImpl::endVisit(TupleExpression const& _tuple) "Assertion checker does not yet implement tules and inline arrays." ); else - m_solver.add(expr(_tuple) == expr(*_tuple.components()[0])); + m_interface.addAssertion(expr(_tuple) == expr(*_tuple.components()[0])); } void SMTCheckerImpl::endVisit(BinaryOperation const& _op) @@ -170,13 +169,13 @@ void SMTCheckerImpl::endVisit(FunctionCall const& _funCall) solAssert(args.size() == 1, ""); solAssert(args[0]->annotation().type->category() == Type::Category::Bool, ""); checkCondition(!(expr(*args[0])), _funCall.location(), "Assertion violation"); - m_solver.add(expr(*args[0])); + m_interface.addAssertion(expr(*args[0])); } else if (funType.kind() == FunctionType::Kind::Require) { solAssert(args.size() == 1, ""); solAssert(args[0]->annotation().type->category() == Type::Category::Bool, ""); - m_solver.add(expr(*args[0])); + m_interface.addAssertion(expr(*args[0])); checkCondition(!(expr(*args[0])), _funCall.location(), "Unreachable code"); // TODO is there something meaningful we can check here? // We can check whether the condition is always fulfilled or never fulfilled. @@ -189,7 +188,7 @@ void SMTCheckerImpl::endVisit(Identifier const& _identifier) solAssert(decl, ""); if (dynamic_cast(_identifier.annotation().type.get())) { - m_solver.add(expr(_identifier) == currentValue(*decl)); + m_interface.addAssertion(expr(_identifier) == currentValue(*decl)); return; } else if (FunctionType const* fun = dynamic_cast(_identifier.annotation().type.get())) @@ -214,7 +213,7 @@ void SMTCheckerImpl::endVisit(Literal const& _literal) if (RationalNumberType const* rational = dynamic_cast(&type)) solAssert(!rational->isFractional(), ""); - m_solver.add(expr(_literal) == m_context.int_val(type.literalValue(&_literal).str().c_str())); + m_interface.addAssertion(expr(_literal) == smt::Expression(type.literalValue(&_literal))); } else m_errorReporter.warning( @@ -235,10 +234,10 @@ void SMTCheckerImpl::arithmeticOperation(BinaryOperation const& _op) { solAssert(_op.annotation().commonType, ""); solAssert(_op.annotation().commonType->category() == Type::Category::Integer, ""); - z3::expr left(expr(_op.leftExpression())); - z3::expr right(expr(_op.rightExpression())); + smt::Expression left(expr(_op.leftExpression())); + smt::Expression right(expr(_op.rightExpression())); Token::Value op = _op.getOperator(); - z3::expr value( + smt::Expression value( op == Token::Add ? left + right : op == Token::Sub ? left - right : /*op == Token::Mul*/ left * right @@ -261,7 +260,7 @@ void SMTCheckerImpl::arithmeticOperation(BinaryOperation const& _op) &value ); - m_solver.add(expr(_op) == value); + m_interface.addAssertion(expr(_op) == value); break; } default: @@ -277,10 +276,10 @@ void SMTCheckerImpl::compareOperation(BinaryOperation const& _op) solAssert(_op.annotation().commonType, ""); if (_op.annotation().commonType->category() == Type::Category::Integer) { - z3::expr left(expr(_op.leftExpression())); - z3::expr right(expr(_op.rightExpression())); + smt::Expression left(expr(_op.leftExpression())); + smt::Expression right(expr(_op.rightExpression())); Token::Value op = _op.getOperator(); - z3::expr value = ( + smt::Expression value = ( op == Token::Equal ? (left == right) : op == Token::NotEqual ? (left != right) : op == Token::LessThan ? (left < right) : @@ -289,7 +288,7 @@ void SMTCheckerImpl::compareOperation(BinaryOperation const& _op) /*op == Token::GreaterThanOrEqual*/ (left >= right) ); // TODO: check that other values for op are not possible. - m_solver.add(expr(_op) == value); + m_interface.addAssertion(expr(_op) == value); } else m_errorReporter.warning( @@ -305,9 +304,9 @@ void SMTCheckerImpl::booleanOperation(BinaryOperation const& _op) if (_op.annotation().commonType->category() == Type::Category::Bool) { if (_op.getOperator() == Token::And) - m_solver.add(expr(_op) == expr(_op.leftExpression()) && expr(_op.rightExpression())); + m_interface.addAssertion(expr(_op) == expr(_op.leftExpression()) && expr(_op.rightExpression())); else - m_solver.add(expr(_op) == expr(_op.leftExpression()) || expr(_op.rightExpression())); + m_interface.addAssertion(expr(_op) == expr(_op.leftExpression()) || expr(_op.rightExpression())); } else m_errorReporter.warning( @@ -317,33 +316,32 @@ void SMTCheckerImpl::booleanOperation(BinaryOperation const& _op) } void SMTCheckerImpl::checkCondition( - z3::expr _condition, + smt::Expression _condition, SourceLocation const& _location, string const& _description, string const& _additionalValueName, - z3::expr* _additionalValue + smt::Expression* _additionalValue ) { - m_solver.push(); - m_solver.add(_condition); - switch (m_solver.check()) + m_interface.push(); + m_interface.addAssertion(_condition); + switch (m_interface.check()) { - case z3::check_result::sat: + case smt::CheckResult::SAT: { std::ostringstream message; message << _description << " happens here"; if (m_currentFunction) { message << " for:\n"; - z3::model m = m_solver.get_model(); if (_additionalValue) - message << " " << _additionalValueName << " = " << m.eval(*_additionalValue) << "\n"; + message << " " << _additionalValueName << " = " << m_interface.eval(*_additionalValue) << "\n"; for (auto const& param: m_currentFunction->parameters()) if (knownVariable(*param)) - message << " " << param->name() << " = " << m.eval(currentValue(*param)) << "\n"; + message << " " << param->name() << " = " << m_interface.eval(currentValue(*param)) << "\n"; for (auto const& var: m_currentFunction->localVariables()) if (knownVariable(*var)) - message << " " << var->name() << " = " << m.eval(currentValue(*var)) << "\n"; + message << " " << var->name() << " = " << m_interface.eval(currentValue(*var)) << "\n"; // message << m << endl; // message << m_solver << endl; } @@ -352,13 +350,13 @@ void SMTCheckerImpl::checkCondition( m_errorReporter.warning(_location, message.str()); break; } - case z3::check_result::unsat: + case smt::CheckResult::UNSAT: break; - case z3::check_result::unknown: + case smt::CheckResult::UNKNOWN: m_errorReporter.warning(_location, _description + " might happen here."); break; } - m_solver.pop(); + m_interface.pop(); } void SMTCheckerImpl::createVariable(VariableDeclaration const& _varDecl, bool _setToZero) @@ -368,13 +366,13 @@ void SMTCheckerImpl::createVariable(VariableDeclaration const& _varDecl, bool _s solAssert(m_currentSequenceCounter.count(&_varDecl) == 0, ""); solAssert(m_z3Variables.count(&_varDecl) == 0, ""); m_currentSequenceCounter[&_varDecl] = 0; - m_z3Variables.emplace(&_varDecl, m_context.function(uniqueSymbol(_varDecl).c_str(), m_context.int_sort(), m_context.int_sort())); + m_z3Variables.emplace(&_varDecl, m_interface.newFunction(uniqueSymbol(_varDecl), smt::Sort::Int, smt::Sort::Int)); if (_setToZero) - m_solver.add(currentValue(_varDecl) == 0); + m_interface.addAssertion(currentValue(_varDecl) == 0); else { - m_solver.add(currentValue(_varDecl) >= minValue(*intType)); - m_solver.add(currentValue(_varDecl) <= maxValue(*intType)); + m_interface.addAssertion(currentValue(_varDecl) >= minValue(*intType)); + m_interface.addAssertion(currentValue(_varDecl) <= maxValue(*intType)); } } else @@ -399,30 +397,30 @@ bool SMTCheckerImpl::knownVariable(Declaration const& _decl) return m_currentSequenceCounter.count(&_decl); } -z3::expr SMTCheckerImpl::currentValue(Declaration const& _decl) +smt::Expression SMTCheckerImpl::currentValue(Declaration const& _decl) { solAssert(m_currentSequenceCounter.count(&_decl), ""); return var(_decl)(m_currentSequenceCounter.at(&_decl)); } -z3::expr SMTCheckerImpl::newValue(const Declaration& _decl) +smt::Expression SMTCheckerImpl::newValue(const Declaration& _decl) { solAssert(m_currentSequenceCounter.count(&_decl), ""); m_currentSequenceCounter[&_decl]++; return currentValue(_decl); } -z3::expr SMTCheckerImpl::minValue(IntegerType const& _t) +smt::Expression SMTCheckerImpl::minValue(IntegerType const& _t) { - return m_context.int_val(_t.minValue().str().c_str()); + return m_interface.newInteger(_t.minValue()); } -z3::expr SMTCheckerImpl::maxValue(IntegerType const& _t) +smt::Expression SMTCheckerImpl::maxValue(IntegerType const& _t) { - return m_context.int_val(_t.maxValue().str().c_str()); + return m_interface.newInteger(_t.maxValue()); } -z3::expr SMTCheckerImpl::expr(Expression const& _e) +smt::Expression SMTCheckerImpl::expr(Expression const& _e) { if (!m_z3Expressions.count(&_e)) { @@ -433,14 +431,14 @@ z3::expr SMTCheckerImpl::expr(Expression const& _e) { if (RationalNumberType const* rational = dynamic_cast(_e.annotation().type.get())) solAssert(!rational->isFractional(), ""); - m_z3Expressions.emplace(&_e, m_context.int_const(uniqueSymbol(_e).c_str())); + m_z3Expressions.emplace(&_e, m_interface.newInteger(uniqueSymbol(_e))); break; } case Type::Category::Integer: - m_z3Expressions.emplace(&_e, m_context.int_const(uniqueSymbol(_e).c_str())); + m_z3Expressions.emplace(&_e, m_interface.newInteger(uniqueSymbol(_e))); break; case Type::Category::Bool: - m_z3Expressions.emplace(&_e, m_context.bool_const(uniqueSymbol(_e).c_str())); + m_z3Expressions.emplace(&_e, m_interface.newBool(uniqueSymbol(_e))); break; default: solAssert(false, "Type not implemented."); @@ -449,7 +447,7 @@ z3::expr SMTCheckerImpl::expr(Expression const& _e) return m_z3Expressions.at(&_e); } -z3::func_decl SMTCheckerImpl::var(Declaration const& _decl) +smt::Expression SMTCheckerImpl::var(Declaration const& _decl) { solAssert(m_z3Variables.count(&_decl), ""); return m_z3Variables.at(&_decl); diff --git a/libsolidity/formal/SMTCheckerImpl.h b/libsolidity/formal/SMTCheckerImpl.h index a82ec92ab..e51794409 100644 --- a/libsolidity/formal/SMTCheckerImpl.h +++ b/libsolidity/formal/SMTCheckerImpl.h @@ -18,8 +18,7 @@ #pragma once #include - -#include +#include #include #include @@ -60,11 +59,11 @@ private: void booleanOperation(BinaryOperation const& _op); void checkCondition( - z3::expr _condition, + smt::Expression _condition, SourceLocation const& _location, std::string const& _description, std::string const& _additionalValueName = "", - z3::expr* _additionalValue = nullptr + smt::Expression* _additionalValue = nullptr ); void createVariable(VariableDeclaration const& _varDecl, bool _setToZero); @@ -72,24 +71,23 @@ private: std::string uniqueSymbol(Declaration const& _decl); std::string uniqueSymbol(Expression const& _expr); bool knownVariable(Declaration const& _decl); - z3::expr currentValue(Declaration const& _decl); - z3::expr newValue(Declaration const& _decl); + smt::Expression currentValue(Declaration const& _decl); + smt::Expression newValue(Declaration const& _decl); - z3::expr minValue(IntegerType const& _t); - z3::expr maxValue(IntegerType const& _t); + smt::Expression minValue(IntegerType const& _t); + smt::Expression maxValue(IntegerType const& _t); /// Returns the z3 expression corresponding to the AST node. Creates a new expression /// if it does not exist yet. - z3::expr expr(Expression const& _e); + smt::Expression expr(Expression const& _e); /// Returns the z3 function declaration corresponding to the given variable. /// The function takes one argument which is the "sequence number". - z3::func_decl var(Declaration const& _decl); + smt::Expression var(Declaration const& _decl); - z3::context m_context; - z3::solver m_solver; + smt::SMTLib2Interface m_interface; std::map m_currentSequenceCounter; - std::map m_z3Expressions; - std::map m_z3Variables; + std::map m_z3Expressions; + std::map m_z3Variables; ErrorReporter& m_errorReporter; FunctionDefinition const* m_currentFunction = nullptr; diff --git a/libsolidity/formal/SMTLib2Interface.cpp b/libsolidity/formal/SMTLib2Interface.cpp new file mode 100644 index 000000000..c736ed2a3 --- /dev/null +++ b/libsolidity/formal/SMTLib2Interface.cpp @@ -0,0 +1,24 @@ +/* + This file is part of solidity. + + solidity is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + solidity is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with solidity. If not, see . +*/ + +#include + + +using namespace std; +using namespace dev; +using namespace dev::solidity::smt; + diff --git a/libsolidity/formal/SMTLib2Interface.h b/libsolidity/formal/SMTLib2Interface.h new file mode 100644 index 000000000..f984cfb50 --- /dev/null +++ b/libsolidity/formal/SMTLib2Interface.h @@ -0,0 +1,150 @@ +/* + This file is part of solidity. + + solidity is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + solidity is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with solidity. If not, see . +*/ + +#pragma once + +#include +#include +#include + +#include + +namespace dev +{ +namespace solidity +{ +namespace smt +{ + +enum class CheckResult +{ + SAT, UNSAT, UNKNOWN +}; + +enum class Sort +{ + Int, Bool +}; + +class Expression +{ + friend class SMTLib2Interface; + /// Manual constructor, should only be used by SMTLib2Interface and the class itself. + Expression(std::string _name, std::vector _arguments): + m_name(std::move(_name)), m_arguments(std::move(_arguments)) {} + +public: + Expression(size_t _number): m_name(std::to_string(_number)) {} + Expression(u256 const& _number): m_name(std::to_string(_number)) {} + + Expression(Expression const& _other) = default; + Expression(Expression&& _other) = default; + Expression& operator=(Expression const& _other) = default; + Expression& operator=(Expression&& _other) = default; + + friend Expression operator!(Expression _a) + { + return Expression("not", _a); + } + friend Expression operator&&(Expression _a, Expression _b) + { + return Expression("and", _a, _b); + } + friend Expression operator||(Expression _a, Expression _b) + { + return Expression("or", _a, _b); + } + friend Expression operator==(Expression _a, Expression _b) + { + return Expression("=", _a, _b); + } + friend Expression operator!=(Expression _a, Expression _b) + { + return !(_a == _b); + } + friend Expression operator<(Expression _a, Expression _b) + { + return Expression("<", std::move(_a), std::move(_b)); + } + friend Expression operator<=(Expression _a, Expression _b) + { + return Expression("<=", std::move(_a), std::move(_b)); + } + friend Expression operator>(Expression _a, Expression _b) + { + return Expression(">", std::move(_a), std::move(_b)); + } + friend Expression operator>=(Expression _a, Expression _b) + { + return Expression(">=", std::move(_a), std::move(_b)); + } + friend Expression operator+(Expression _a, Expression _b) + { + return Expression("+", std::move(_a), std::move(_b)); + } + friend Expression operator-(Expression _a, Expression _b) + { + return Expression("-", std::move(_a), std::move(_b)); + } + friend Expression operator*(Expression _a, Expression _b) + { + return Expression("*", std::move(_a), std::move(_b)); + } + + std::string toSExpr() const + { + std::string sexpr = "(" + m_name; + for (auto const& arg: m_arguments) + sexpr += " " + arg.toSExpr(); + sexpr += ")"; + return sexpr; + } + +private: + explicit Expression(std::string _name): + Expression(std::move(_name), std::vector{}) {} + Expression(std::string _name, Expression _arg): + Expression(std::move(_name), std::vector{std::move(_arg)}) {} + Expression(std::string _name, Expression _arg1, Expression _arg2): + Expression(std::move(_name), std::vector{std::move(_arg1), std::move(_arg2)}) {} + + std::string const m_name; + std::vector const m_arguments; +}; + +class SMTLib2Interface +{ +public: + + void reset(); + + void push(); + void pop(); + + Expression newFunction(std::string _name, Sort _domain, Sort _codomain); + Expression newInteger(std::string _name); + Expression newBool(std::string _name); + + void addAssertion(Expression _expr); + CheckResult check(); + std::string eval(Expression _expr); +}; + + +} +} +}