diff --git a/libsolidity/formal/CVC4Interface.cpp b/libsolidity/formal/CVC4Interface.cpp index 118229279..6e17eef17 100644 --- a/libsolidity/formal/CVC4Interface.cpp +++ b/libsolidity/formal/CVC4Interface.cpp @@ -50,7 +50,7 @@ void CVC4Interface::pop() m_solver.pop(); } -void CVC4Interface::declareFunction(string _name, vector const& _domain, Sort _codomain) +void CVC4Interface::declareFunction(string _name, vector const& _domain, Sort const& _codomain) { if (!m_functions.count(_name)) { @@ -186,13 +186,13 @@ CVC4::Expr CVC4Interface::toCVC4Expr(Expression const& _expr) return arguments[0]; } -CVC4::Type CVC4Interface::cvc4Sort(Sort _sort) +CVC4::Type CVC4Interface::cvc4Sort(Sort const& _sort) { - switch (_sort) + switch (_sort.kind) { - case Sort::Bool: + case Kind::Bool: return m_context.booleanType(); - case Sort::Int: + case Kind::Int: return m_context.integerType(); default: break; @@ -202,10 +202,10 @@ CVC4::Type CVC4Interface::cvc4Sort(Sort _sort) return m_context.integerType(); } -vector CVC4Interface::cvc4Sort(vector const& _sorts) +vector CVC4Interface::cvc4Sort(vector const& _sorts) { vector cvc4Sorts; for (auto const& _sort: _sorts) - cvc4Sorts.push_back(cvc4Sort(_sort)); + cvc4Sorts.push_back(cvc4Sort(*_sort)); return cvc4Sorts; } diff --git a/libsolidity/formal/CVC4Interface.h b/libsolidity/formal/CVC4Interface.h index 273dce235..f354c790b 100644 --- a/libsolidity/formal/CVC4Interface.h +++ b/libsolidity/formal/CVC4Interface.h @@ -51,7 +51,7 @@ public: void push() override; void pop() override; - void declareFunction(std::string _name, std::vector const& _domain, Sort _codomain) override; + void declareFunction(std::string _name, std::vector const& _domain, Sort const& _codomain) override; void declareInteger(std::string _name) override; void declareBool(std::string _name) override; @@ -60,8 +60,8 @@ public: private: CVC4::Expr toCVC4Expr(Expression const& _expr); - CVC4::Type cvc4Sort(smt::Sort _sort); - std::vector cvc4Sort(std::vector const& _sort); + CVC4::Type cvc4Sort(smt::Sort const& _sort); + std::vector cvc4Sort(std::vector const& _sorts); CVC4::ExprManager m_context; CVC4::SmtEngine m_solver; diff --git a/libsolidity/formal/SMTChecker.cpp b/libsolidity/formal/SMTChecker.cpp index bbc78c0c7..0a581fc00 100644 --- a/libsolidity/formal/SMTChecker.cpp +++ b/libsolidity/formal/SMTChecker.cpp @@ -416,7 +416,7 @@ void SMTChecker::visitGasLeft(FunctionCall const& _funCall) void SMTChecker::visitBlockHash(FunctionCall const& _funCall) { string blockHash = "blockhash"; - defineUninterpretedFunction(blockHash, {smt::Sort::Int}, smt::Sort::Int); + defineUninterpretedFunction(blockHash, {make_shared(smt::Kind::Int)}, smt::Kind::Int); auto const& arguments = _funCall.arguments(); solAssert(arguments.size() == 1, ""); defineExpr(_funCall, m_uninterpretedFunctions.at(blockHash)({expr(*arguments[0])})); @@ -605,7 +605,7 @@ void SMTChecker::defineSpecialVariable(string const& _name, Expression const& _e defineExpr(_expr, m_specialVariables.at(_name)->currentValue()); } -void SMTChecker::defineUninterpretedFunction(string const& _name, vector const& _domain, smt::Sort _codomain) +void SMTChecker::defineUninterpretedFunction(string const& _name, vector const& _domain, smt::Sort const& _codomain) { if (!m_uninterpretedFunctions.count(_name)) m_uninterpretedFunctions.emplace(_name, m_interface->newFunction(_name, _domain, _codomain)); diff --git a/libsolidity/formal/SMTChecker.h b/libsolidity/formal/SMTChecker.h index 4a5974aac..3bf84ac9a 100644 --- a/libsolidity/formal/SMTChecker.h +++ b/libsolidity/formal/SMTChecker.h @@ -83,7 +83,7 @@ private: void inlineFunctionCall(FunctionCall const&); void defineSpecialVariable(std::string const& _name, Expression const& _expr, bool _increaseIndex = false); - void defineUninterpretedFunction(std::string const& _name, std::vector const& _domain, smt::Sort _codomain); + void defineUninterpretedFunction(std::string const& _name, std::vector const& _domain, smt::Sort const& _codomain); /// Division expression in the given type. Requires special treatment because /// of rounding for signed division. diff --git a/libsolidity/formal/SMTLib2Interface.cpp b/libsolidity/formal/SMTLib2Interface.cpp index 545422338..01386dda9 100644 --- a/libsolidity/formal/SMTLib2Interface.cpp +++ b/libsolidity/formal/SMTLib2Interface.cpp @@ -64,12 +64,12 @@ void SMTLib2Interface::pop() m_accumulatedOutput.pop_back(); } -void SMTLib2Interface::declareFunction(string _name, vector const& _domain, Sort _codomain) +void SMTLib2Interface::declareFunction(string _name, vector const& _domain, Sort const& _codomain) { // TODO Use domain and codomain as key as well string domain(""); for (auto const& sort: _domain) - domain += toSmtLibSort(sort) + ' '; + domain += toSmtLibSort(*sort) + ' '; if (!m_functions.count(_name)) { m_functions.insert(_name); @@ -79,7 +79,7 @@ void SMTLib2Interface::declareFunction(string _name, vector const& _domain "| (" + domain + ") " + - (_codomain == Sort::Int ? "Int" : "Bool") + + (_codomain.kind == Kind::Int ? "Int" : "Bool") + ")" ); } @@ -143,13 +143,13 @@ string SMTLib2Interface::toSExpr(Expression const& _expr) return sexpr; } -string SMTLib2Interface::toSmtLibSort(Sort _sort) +string SMTLib2Interface::toSmtLibSort(Sort const& _sort) { - switch (_sort) + switch (_sort.kind) { - case Sort::Int: + case Kind::Int: return "Int"; - case Sort::Bool: + case Kind::Bool: return "Bool"; default: solAssert(false, "Invalid SMT sort"); @@ -173,8 +173,8 @@ string SMTLib2Interface::checkSatAndGetValuesCommand(vector const& _ for (size_t i = 0; i < _expressionsToEvaluate.size(); i++) { auto const& e = _expressionsToEvaluate.at(i); - solAssert(e.sort == Sort::Int || e.sort == Sort::Bool, "Invalid sort for expression to evaluate."); - command += "(declare-const |EVALEXPR_" + to_string(i) + "| " + (e.sort == Sort::Int ? "Int" : "Bool") + ")\n"; + solAssert(e.sort->kind == Kind::Int || e.sort->kind == Kind::Bool, "Invalid sort for expression to evaluate."); + command += "(declare-const |EVALEXPR_" + to_string(i) + "| " + (e.sort->kind == Kind::Int ? "Int" : "Bool") + ")\n"; command += "(assert (= |EVALEXPR_" + to_string(i) + "| " + toSExpr(e) + "))\n"; } command += "(check-sat)\n"; diff --git a/libsolidity/formal/SMTLib2Interface.h b/libsolidity/formal/SMTLib2Interface.h index 08ad74daf..b140f5557 100644 --- a/libsolidity/formal/SMTLib2Interface.h +++ b/libsolidity/formal/SMTLib2Interface.h @@ -49,7 +49,7 @@ public: void push() override; void pop() override; - void declareFunction(std::string _name, std::vector const& _domain, Sort _codomain) override; + void declareFunction(std::string _name, std::vector const& _domain, Sort const& _codomain) override; void declareInteger(std::string _name) override; void declareBool(std::string _name) override; @@ -58,7 +58,7 @@ public: private: std::string toSExpr(Expression const& _expr); - std::string toSmtLibSort(Sort _sort); + std::string toSmtLibSort(Sort const& _sort); void write(std::string _data); diff --git a/libsolidity/formal/SMTPortfolio.cpp b/libsolidity/formal/SMTPortfolio.cpp index e1cde04c6..e01a5accf 100644 --- a/libsolidity/formal/SMTPortfolio.cpp +++ b/libsolidity/formal/SMTPortfolio.cpp @@ -64,7 +64,7 @@ void SMTPortfolio::pop() s->pop(); } -void SMTPortfolio::declareFunction(string _name, vector const& _domain, Sort _codomain) +void SMTPortfolio::declareFunction(string _name, vector const& _domain, Sort const& _codomain) { for (auto s : m_solvers) s->declareFunction(_name, _domain, _codomain); diff --git a/libsolidity/formal/SMTPortfolio.h b/libsolidity/formal/SMTPortfolio.h index 50bd87d46..712fb545c 100644 --- a/libsolidity/formal/SMTPortfolio.h +++ b/libsolidity/formal/SMTPortfolio.h @@ -49,7 +49,7 @@ public: void push() override; void pop() override; - void declareFunction(std::string _name, std::vector const& _domain, Sort _codomain) override; + void declareFunction(std::string _name, std::vector const& _domain, Sort const& _codomain) override; void declareInteger(std::string _name) override; void declareBool(std::string _name) override; diff --git a/libsolidity/formal/SolverInterface.h b/libsolidity/formal/SolverInterface.h index a6618fb59..55c0a5634 100644 --- a/libsolidity/formal/SolverInterface.h +++ b/libsolidity/formal/SolverInterface.h @@ -42,21 +42,32 @@ enum class CheckResult SATISFIABLE, UNSATISFIABLE, UNKNOWN, CONFLICTING, ERROR }; -enum class Sort +enum class Kind { Int, Bool }; +struct Sort +{ + Sort(Kind _kind): + kind(_kind) {} + virtual ~Sort() = default; + Kind const kind; + bool operator==(Sort const& _other) const { return kind == _other.kind; } +}; +using SortPointer = std::shared_ptr; + + /// C++ representation of an SMTLIB2 expression. class Expression { friend class SolverInterface; public: - explicit Expression(bool _v): name(_v ? "true" : "false"), sort(Sort::Bool) {} - Expression(size_t _number): name(std::to_string(_number)), sort(Sort::Int) {} - Expression(u256 const& _number): name(_number.str()), sort(Sort::Int) {} - Expression(bigint const& _number): name(_number.str()), sort(Sort::Int) {} + explicit Expression(bool _v): Expression(_v ? "true" : "false", Kind::Bool) {} + Expression(size_t _number): Expression(std::to_string(_number), Kind::Int) {} + Expression(u256 const& _number): Expression(_number.str(), Kind::Int) {} + Expression(bigint const& _number): Expression(_number.str(), Kind::Int) {} Expression(Expression const&) = default; Expression(Expression&&) = default; @@ -85,7 +96,7 @@ public: static Expression ite(Expression _condition, Expression _trueValue, Expression _falseValue) { - solAssert(_trueValue.sort == _falseValue.sort, ""); + solAssert(*_trueValue.sort == *_falseValue.sort, ""); return Expression("ite", std::vector{ std::move(_condition), std::move(_trueValue), std::move(_falseValue) }, _trueValue.sort); @@ -98,19 +109,19 @@ public: friend Expression operator!(Expression _a) { - return Expression("not", std::move(_a), Sort::Bool); + return Expression("not", std::move(_a), Kind::Bool); } friend Expression operator&&(Expression _a, Expression _b) { - return Expression("and", std::move(_a), std::move(_b), Sort::Bool); + return Expression("and", std::move(_a), std::move(_b), Kind::Bool); } friend Expression operator||(Expression _a, Expression _b) { - return Expression("or", std::move(_a), std::move(_b), Sort::Bool); + return Expression("or", std::move(_a), std::move(_b), Kind::Bool); } friend Expression operator==(Expression _a, Expression _b) { - return Expression("=", std::move(_a), std::move(_b), Sort::Bool); + return Expression("=", std::move(_a), std::move(_b), Kind::Bool); } friend Expression operator!=(Expression _a, Expression _b) { @@ -118,35 +129,35 @@ public: } friend Expression operator<(Expression _a, Expression _b) { - return Expression("<", std::move(_a), std::move(_b), Sort::Bool); + return Expression("<", std::move(_a), std::move(_b), Kind::Bool); } friend Expression operator<=(Expression _a, Expression _b) { - return Expression("<=", std::move(_a), std::move(_b), Sort::Bool); + return Expression("<=", std::move(_a), std::move(_b), Kind::Bool); } friend Expression operator>(Expression _a, Expression _b) { - return Expression(">", std::move(_a), std::move(_b), Sort::Bool); + return Expression(">", std::move(_a), std::move(_b), Kind::Bool); } friend Expression operator>=(Expression _a, Expression _b) { - return Expression(">=", std::move(_a), std::move(_b), Sort::Bool); + return Expression(">=", std::move(_a), std::move(_b), Kind::Bool); } friend Expression operator+(Expression _a, Expression _b) { - return Expression("+", std::move(_a), std::move(_b), Sort::Int); + return Expression("+", std::move(_a), std::move(_b), Kind::Int); } friend Expression operator-(Expression _a, Expression _b) { - return Expression("-", std::move(_a), std::move(_b), Sort::Int); + return Expression("-", std::move(_a), std::move(_b), Kind::Int); } friend Expression operator*(Expression _a, Expression _b) { - return Expression("*", std::move(_a), std::move(_b), Sort::Int); + return Expression("*", std::move(_a), std::move(_b), Kind::Int); } friend Expression operator/(Expression _a, Expression _b) { - return Expression("/", std::move(_a), std::move(_b), Sort::Int); + return Expression("/", std::move(_a), std::move(_b), Kind::Int); } Expression operator()(std::vector _arguments) const { @@ -154,36 +165,26 @@ public: arguments.empty(), "Attempted function application to non-function." ); - switch (sort) - { - case Sort::Int: - return Expression(name, std::move(_arguments), Sort::Int); - case Sort::Bool: - return Expression(name, std::move(_arguments), Sort::Bool); - default: - solAssert( - false, - "Attempted function application to invalid type." - ); - break; - } + return Expression(name, std::move(_arguments), sort); } std::string name; std::vector arguments; - Sort sort; + SortPointer sort; private: - /// Manual constructor, should only be used by SolverInterface and this class itself. - Expression(std::string _name, std::vector _arguments, Sort _sort): - name(std::move(_name)), arguments(std::move(_arguments)), sort(_sort) {} + /// Manual constructors, should only be used by SolverInterface and this class itself. + Expression(std::string _name, std::vector _arguments, SortPointer _sort): + name(std::move(_name)), arguments(std::move(_arguments)), sort(std::move(_sort)) {} + Expression(std::string _name, std::vector _arguments, Kind _kind): + Expression(std::move(_name), std::move(_arguments), std::make_shared(_kind)) {} - explicit Expression(std::string _name, Sort _sort): - Expression(std::move(_name), std::vector{}, _sort) {} - Expression(std::string _name, Expression _arg, Sort _sort): - Expression(std::move(_name), std::vector{std::move(_arg)}, _sort) {} - Expression(std::string _name, Expression _arg1, Expression _arg2, Sort _sort): - Expression(std::move(_name), std::vector{std::move(_arg1), std::move(_arg2)}, _sort) {} + explicit Expression(std::string _name, Kind _kind): + Expression(std::move(_name), std::vector{}, _kind) {} + Expression(std::string _name, Expression _arg, Kind _kind): + Expression(std::move(_name), std::vector{std::move(_arg)}, _kind) {} + Expression(std::string _name, Expression _arg1, Expression _arg2, Kind _kind): + Expression(std::move(_name), std::vector{std::move(_arg1), std::move(_arg2)}, _kind) {} }; DEV_SIMPLE_EXCEPTION(SolverError); @@ -197,39 +198,26 @@ public: virtual void push() = 0; virtual void pop() = 0; - virtual void declareFunction(std::string _name, std::vector const& _domain, Sort _codomain) = 0; - void declareFunction(std::string _name, Sort _domain, Sort _codomain) - { - declareFunction(std::move(_name), std::vector{std::move(_domain)}, std::move(_codomain)); - } - Expression newFunction(std::string _name, std::vector const& _domain, Sort _codomain) + virtual void declareFunction(std::string _name, std::vector const& _domain, Sort const& _codomain) = 0; + Expression newFunction(std::string _name, std::vector const& _domain, Sort const& _codomain) { declareFunction(_name, _domain, _codomain); // Subclasses should do something here - switch (_codomain) - { - case Sort::Int: - return Expression(std::move(_name), {}, Sort::Int); - case Sort::Bool: - return Expression(std::move(_name), {}, Sort::Bool); - default: - solAssert(false, "Function sort not supported."); - break; - } + return Expression(std::move(_name), {}, _codomain.kind); } virtual void declareInteger(std::string _name) = 0; Expression newInteger(std::string _name) { // Subclasses should do something here declareInteger(_name); - return Expression(std::move(_name), {}, Sort::Int); + return Expression(std::move(_name), {}, Kind::Int); } virtual void declareBool(std::string _name) = 0; Expression newBool(std::string _name) { // Subclasses should do something here declareBool(_name); - return Expression(std::move(_name), {}, Sort::Bool); + return Expression(std::move(_name), {}, Kind::Bool); } virtual void addAssertion(Expression const& _expr) = 0; diff --git a/libsolidity/formal/SymbolicTypes.cpp b/libsolidity/formal/SymbolicTypes.cpp index 78bf861bf..a3b6e3a8b 100644 --- a/libsolidity/formal/SymbolicTypes.cpp +++ b/libsolidity/formal/SymbolicTypes.cpp @@ -24,12 +24,24 @@ using namespace std; using namespace dev::solidity; -smt::Sort dev::solidity::smtSort(Type::Category _category) +smt::SortPointer dev::solidity::smtSort(Type const& _type) +{ + switch (smtKind(_type.category())) + { + case smt::Kind::Int: + return make_shared(smt::Kind::Int); + case smt::Kind::Bool: + return make_shared(smt::Kind::Bool); + } + solAssert(false, "Invalid type"); +} + +smt::Kind dev::solidity::smtKind(Type::Category _category) { if (isNumber(_category)) - return smt::Sort::Int; + return smt::Kind::Int; else if (isBool(_category)) - return smt::Sort::Bool; + return smt::Kind::Bool; solAssert(false, "Invalid type"); } diff --git a/libsolidity/formal/SymbolicTypes.h b/libsolidity/formal/SymbolicTypes.h index 2639fcb94..c802c5b46 100644 --- a/libsolidity/formal/SymbolicTypes.h +++ b/libsolidity/formal/SymbolicTypes.h @@ -29,7 +29,9 @@ namespace solidity { /// Returns the SMT sort that models the Solidity type _type. -smt::Sort smtSort(Type::Category _type); +smt::SortPointer smtSort(Type const& _type); +/// Returns the SMT kind that models the Solidity type type category _category. +smt::Kind smtKind(Type::Category _category); /// So far int, bool and address are supported. /// Returns true if type is supported. diff --git a/libsolidity/formal/Z3Interface.cpp b/libsolidity/formal/Z3Interface.cpp index 2519e41b7..09696aeba 100644 --- a/libsolidity/formal/Z3Interface.cpp +++ b/libsolidity/formal/Z3Interface.cpp @@ -51,7 +51,7 @@ void Z3Interface::pop() m_solver.pop(); } -void Z3Interface::declareFunction(string _name, vector const& _domain, Sort _codomain) +void Z3Interface::declareFunction(string _name, vector const& _domain, Sort const& _codomain) { if (!m_functions.count(_name)) m_functions.insert({_name, m_context.function(_name.c_str(), z3Sort(_domain), z3Sort(_codomain))}); @@ -168,13 +168,13 @@ z3::expr Z3Interface::toZ3Expr(Expression const& _expr) return arguments[0]; } -z3::sort Z3Interface::z3Sort(Sort _sort) +z3::sort Z3Interface::z3Sort(Sort const& _sort) { - switch (_sort) + switch (_sort.kind) { - case Sort::Bool: + case Kind::Bool: return m_context.bool_sort(); - case Sort::Int: + case Kind::Int: return m_context.int_sort(); default: break; @@ -184,10 +184,10 @@ z3::sort Z3Interface::z3Sort(Sort _sort) return m_context.int_sort(); } -z3::sort_vector Z3Interface::z3Sort(vector const& _sorts) +z3::sort_vector Z3Interface::z3Sort(vector const& _sorts) { z3::sort_vector z3Sorts(m_context); for (auto const& _sort: _sorts) - z3Sorts.push_back(z3Sort(_sort)); + z3Sorts.push_back(z3Sort(*_sort)); return z3Sorts; } diff --git a/libsolidity/formal/Z3Interface.h b/libsolidity/formal/Z3Interface.h index 5eae618e8..8c1fcf614 100644 --- a/libsolidity/formal/Z3Interface.h +++ b/libsolidity/formal/Z3Interface.h @@ -40,7 +40,7 @@ public: void push() override; void pop() override; - void declareFunction(std::string _name, std::vector const& _domain, Sort _codomain) override; + void declareFunction(std::string _name, std::vector const& _domain, Sort const& _codomain) override; void declareInteger(std::string _name) override; void declareBool(std::string _name) override; @@ -49,8 +49,8 @@ public: private: z3::expr toZ3Expr(Expression const& _expr); - z3::sort z3Sort(smt::Sort _sort); - z3::sort_vector z3Sort(std::vector const& _sort); + z3::sort z3Sort(smt::Sort const& _sort); + z3::sort_vector z3Sort(std::vector const& _sorts); z3::context m_context; z3::solver m_solver;