From 6acbe2ec3536d90d4cb41a380b24fc17b006fb75 Mon Sep 17 00:00:00 2001 From: Martin Blicha Date: Sun, 16 Jul 2023 12:56:40 +0200 Subject: [PATCH] Towards translating proof from SMT-LIB response --- libsmtutil/CHCSmtLib2Interface.cpp | 422 ++++++++++++++++++++++++++++- libsmtutil/CHCSmtLib2Interface.h | 13 + 2 files changed, 433 insertions(+), 2 deletions(-) diff --git a/libsmtutil/CHCSmtLib2Interface.cpp b/libsmtutil/CHCSmtLib2Interface.cpp index 3f0d3f881..63bc03165 100644 --- a/libsmtutil/CHCSmtLib2Interface.cpp +++ b/libsmtutil/CHCSmtLib2Interface.cpp @@ -19,6 +19,11 @@ #include #include +#include +#include + +#include "liblangutil/Common.h" + #include #include @@ -29,11 +34,16 @@ #include #include #include +#include #include +#include +#include +#include using namespace solidity; using namespace solidity::util; using namespace solidity::frontend; +using namespace solidity::langutil; using namespace solidity::smtutil; CHCSmtLib2Interface::CHCSmtLib2Interface( @@ -57,6 +67,7 @@ void CHCSmtLib2Interface::reset() m_variables.clear(); m_unhandledQueries.clear(); m_sortNames.clear(); + m_knownSorts.clear(); } void CHCSmtLib2Interface::registerRelation(Expression const& _expr) @@ -97,8 +108,10 @@ std::tuple CHCSmtLib2Inte // TODO proper parsing if (boost::starts_with(response, "sat")) result = CheckResult::UNSATISFIABLE; - else if (boost::starts_with(response, "unsat")) + else if (boost::starts_with(response, "unsat")) { result = CheckResult::SATISFIABLE; + return {result, Expression(true), graphFromZ3Proof(response)}; + } else if (boost::starts_with(response, "unknown")) result = CheckResult::UNKNOWN; else @@ -122,7 +135,11 @@ void CHCSmtLib2Interface::declareVariable(std::string const& _name, SortPointer std::string CHCSmtLib2Interface::toSmtLibSort(Sort const& _sort) { if (!m_sortNames.count(&_sort)) - m_sortNames[&_sort] = m_smtlib2->toSmtLibSort(_sort); + { + auto smtLibName = m_smtlib2->toSmtLibSort(_sort); + m_sortNames[&_sort] = smtLibName; + m_knownSorts[smtLibName] = &_sort; + } return m_sortNames.at(&_sort); } @@ -237,3 +254,404 @@ std::string CHCSmtLib2Interface::createHeaderAndDeclarations() { std::string CHCSmtLib2Interface::createQueryAssertion(std::string name) { return "(assert\n(forall " + forall() + "\n" + "(=> " + name + " false)))"; } + +std::string CHCSmtLib2Interface::SMTLib2Expression::toString() const +{ + return std::visit(GenericVisitor{ + [](std::string const& _sv) { return _sv; }, + [](std::vector const& _subExpr) { + std::vector formatted; + for (auto const& item: _subExpr) + formatted.emplace_back(item.toString()); + return "(" + joinHumanReadable(formatted, " ") + ")"; + } + }, data); +} + +namespace { + using SMTLib2Expression = CHCSmtLib2Interface::SMTLib2Expression; + bool isNumber(std::string const& _expr) + { + for (char c: _expr) + if (!isDigit(c) && c != '.') + return false; + return true; + } + + bool isAtom(SMTLib2Expression const & expr) + { + return std::holds_alternative(expr.data); + }; + + std::string const& asAtom(SMTLib2Expression const& expr) + { + assert(isAtom(expr)); + return std::get(expr.data); + } + + auto const& asSubExpressions(SMTLib2Expression const& expr) + { + assert(!isAtom(expr)); + return std::get(expr.data); + } + + SortPointer toSort(SMTLib2Expression const& expr) + { + if (isAtom(expr)) { + auto const& name = asAtom(expr); + if (name == "Int") + return SortProvider::sintSort; + } else { + auto const& args = asSubExpressions(expr); + if (asAtom(args[0]) == "Array") { + assert(args.size() == 3); + auto domainSort = toSort(args[1]); + auto codomainSort = toSort(args[2]); + return std::make_shared(std::move(domainSort), std::move(codomainSort)); + } + } + // FIXME: This is not correct, we need to track sorts properly! + return SortProvider::boolSort; +// smtAssert(false, "Unknown sort encountered"); + + } + + smtutil::Expression toSMTUtilExpression(SMTLib2Expression const& _expr) + { + return std::visit(GenericVisitor{ + [&](std::string const& _atom) { + if (_atom == "true" || _atom == "false") + return smtutil::Expression(_atom == "true"); + else if (isNumber(_atom)) + return smtutil::Expression(_atom, {}, SortProvider::sintSort); + else + return smtutil::Expression(_atom, {}, SortProvider::boolSort); + }, + [&](std::vector const& _subExpr) { + SortPointer sort; + std::vector arguments; + if (isAtom(_subExpr.front())) { + std::string const &op = std::get(_subExpr.front().data); + std::set boolOperators{"and", "or", "not", "=", "<", ">", "<=", ">=", "=>"}; + for (size_t i = 1; i < _subExpr.size(); i++) + arguments.emplace_back(toSMTUtilExpression(_subExpr[i])); + sort = contains(boolOperators, op) ? SortProvider::boolSort : arguments.back().sort; + return smtutil::Expression(op, move(arguments), move(sort)); + } else { + // check for const array + if (_subExpr.size() == 2 and !isAtom(_subExpr[0])) + { + auto const& typeArgs = asSubExpressions(_subExpr.front()); + if (typeArgs.size() == 3 && typeArgs[0].toString() == "as" && typeArgs[1].toString() == "const") + { + auto arraySort = toSort(typeArgs[2]); + auto sortSort = std::make_shared(arraySort); + return smtutil::Expression::const_array(Expression(sortSort), toSMTUtilExpression(_subExpr[1])); + } + } + + smtAssert(false, "Unhandled case in expression conversion"); + } + } + }, _expr.data); + } + + + + class SMTLib2Parser + { + public: + SMTLib2Parser(std::istream& _input): + m_input(_input), + m_token(static_cast(m_input.get())) + {} + + SMTLib2Expression parseExpression() + { + skipWhitespace(); + if (token() == '(') + { + advance(); + skipWhitespace(); + std::vector subExpressions; + while (token() != 0 && token() != ')') + { + subExpressions.emplace_back(parseExpression()); + skipWhitespace(); + } + solAssert(token() == ')'); + // simulate whitespace because we do not want to read the next token + // since it might block. + m_token = ' '; + return {move(subExpressions)}; + } + else + return {parseToken()}; + } + + bool isEOF() + { + skipWhitespace(); + return m_input.eof(); + } + + private: + std::string parseToken() + { + std::string result; + + skipWhitespace(); + bool isPipe = token() == '|'; + if (isPipe) + advance(); + while (token() != 0) + { + char c = token(); + if (isPipe && c == '|') + { + advance(); + break; + } + else if (!isPipe && (isWhiteSpace(c) || c == '(' || c == ')')) + break; + result.push_back(c); + advance(); + } + return result; + } + + void skipWhitespace() + { + while (isWhiteSpace(token())) + advance(); + } + + char token() const + { + return m_token; + } + + void advance() + { + m_token = static_cast(m_input.get()); + if (token() == ';') + while (token() != '\n' && token() != 0) + m_token = static_cast(m_input.get()); + } + + std::istream& m_input; + char m_token = 0; + }; + + struct LetBindings { + std::unordered_map bindings; + std::vector varNames; + std::vector scopeBounds; + + bool has(std::string const& varName) + { + return bindings.find(varName) != bindings.end(); + } + + SMTLib2Expression & operator[](std::string const& varName) + { + auto it = bindings.find(varName); + assert(it != bindings.end()); + return it->second; + } + + void pushScope() + { + scopeBounds.push_back(varNames.size()); + } + + void popScope() + { + assert(scopeBounds.size() > 0); + auto bound = scopeBounds.back(); + while (varNames.size() > bound) { + auto const& varName = varNames.back(); + auto it = bindings.find(varName); + assert(it != bindings.end()); + bindings.erase(it); + varNames.pop_back(); + } + scopeBounds.pop_back(); + } + + void addBinding(std::string name, SMTLib2Expression expression) + { + assert(!has(name)); + bindings.insert({name, std::move(expression)}); + varNames.push_back(std::move(name)); + } + }; + + void inlineLetExpressions(SMTLib2Expression& expr, LetBindings & bindings) + { + if (isAtom(expr)) + { + auto const& atom = std::get(expr.data); + if (bindings.has(atom)) + expr = bindings[atom]; + } + else + { + auto& subexprs = std::get(expr.data); + auto const& first = subexprs[0]; + if (isAtom(first) && std::get(first.data) == "let") + { + assert(!isAtom(subexprs[1])); + auto & bindingExpressions = std::get(subexprs[1].data); + // process new bindings + std::vector> newBindings; + for (auto & binding : bindingExpressions) + { + assert(!isAtom(binding)); + auto & bindingPair = std::get(binding.data); + assert(bindingPair.size() == 2); + assert(isAtom(bindingPair.at(0))); + inlineLetExpressions(bindingPair.at(1), bindings); + newBindings.emplace_back(std::get(bindingPair.at(0).data), bindingPair.at(1)); + } + bindings.pushScope(); + for (auto && [name, expr] : newBindings) + bindings.addBinding(std::move(name), std::move(expr)); + newBindings.clear(); + + // get new subexpression + inlineLetExpressions(subexprs.at(2), bindings); + // remove the new bindings + bindings.popScope(); + + // update the expression + auto tmp = std::move(subexprs.at(2)); + expr = std::move(tmp); + return; + } + // not a let expression, just process all arguments + for (auto& subexpr : subexprs) + { + inlineLetExpressions(subexpr, bindings); + } + } + } + + void inlineLetExpressions(SMTLib2Expression& expr) + { + LetBindings bindings; + inlineLetExpressions(expr, bindings); + } + + SMTLib2Expression const& fact(SMTLib2Expression const& _node) + { + if (isAtom(_node)) + return _node; + return asSubExpressions(_node).back(); + } +} + +CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMTLib2Expression const& _proof) +{ + assert(!isAtom(_proof)); + auto const& args = asSubExpressions(_proof); + smtAssert(args.size() == 2); + smtAssert(isAtom(args.at(0)) && asAtom(args.at(0)) == "proof"); + auto const& proofNode = args.at(1); + auto derivedFact = fact(proofNode); + if (isAtom(proofNode) || !isAtom(derivedFact) || asAtom(derivedFact) != "false") + return {}; + + CHCSolverInterface::CexGraph graph; + + std::stack proofStack; + proofStack.push(&asSubExpressions(proofNode).at(1)); + + std::map visitedIds; + unsigned nextId = 0; + + + auto const* root = proofStack.top(); + std::cout << root->toString() << std::endl; + auto const& derivedRootFact = fact(*root); + std::cout << derivedRootFact.toString() << std::endl; + visitedIds.insert({root, nextId++}); + graph.nodes.emplace(visitedIds.at(root), toSMTUtilExpression(derivedRootFact)); + + auto isHyperRes = [](SMTLib2Expression const& expr) { + if (isAtom(expr)) return false; + auto const& subExprs = asSubExpressions(expr); + assert(!subExprs.empty()); + auto const& op = subExprs.at(0); + if (isAtom(op)) return false; + auto const& opExprs = asSubExpressions(op); + if (opExprs.size() < 2) return false; + auto const& ruleName = opExprs.at(1); + return isAtom(ruleName) && asAtom(ruleName) == "hyper-res"; + }; + + while (!proofStack.empty()) + { + auto const* proofNode = proofStack.top(); + smtAssert(visitedIds.find(proofNode) != visitedIds.end(), ""); + auto id = visitedIds.at(proofNode); + smtAssert(graph.nodes.count(id), ""); + proofStack.pop(); + + if (isHyperRes(*proofNode)) + { + auto const& args = asSubExpressions(*proofNode); + smtAssert(args.size() > 1, ""); + // args[0] is the name of the rule + // args[1] is the clause used + // last argument is the derived fact + // the arguments in the middle are the facts where we need to recurse + for (unsigned i = 2; i < args.size() - 1; ++i) + { + auto const* child = &args[i]; + if (!visitedIds.count(child)) + { + visitedIds.insert({child, nextId++}); + proofStack.push(child); + } + + auto childId = visitedIds.at(child); + if (!graph.nodes.count(childId)) + { + graph.nodes.emplace(childId, toSMTUtilExpression(fact(*child))); + graph.edges[childId] = {}; + } + + graph.edges[id].push_back(childId); + } + } + } + return graph; +} + +CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromZ3Proof(const std::string & _proof) { + std::stringstream ss(_proof); + std::string answer; + ss >> answer; + solAssert(answer == "unsat"); + SMTLib2Parser parser(ss); + solAssert(!parser.isEOF()); + // For some reason Z3 outputs everything as a single s-expression + auto all = parser.parseExpression(); + solAssert(parser.isEOF()); + solAssert(!isAtom(all)); + auto& commands = std::get(all.data); + for (auto& command : commands) { +// std::cout << command.toString() << '\n' << std::endl; + if (!isAtom(command)) { + auto const& head = std::get(command.data)[0]; + if(isAtom(head) && std::get(head.data) == "proof") { + std::cout << "Proof expression!\n" << command.toString() << std::endl; + inlineLetExpressions(command); + std::cout << "Cleaned Proof expression!\n" << command.toString() << std::endl; + return graphFromSMTLib2Expression(command); + } + } + } + return {}; +} \ No newline at end of file diff --git a/libsmtutil/CHCSmtLib2Interface.h b/libsmtutil/CHCSmtLib2Interface.h index 0f3a68ee4..75225cc3d 100644 --- a/libsmtutil/CHCSmtLib2Interface.h +++ b/libsmtutil/CHCSmtLib2Interface.h @@ -32,6 +32,14 @@ namespace solidity::smtutil class CHCSmtLib2Interface: public CHCSolverInterface { public: + struct SMTLib2Expression + { + using args_t = std::vector; + std::variant data; + + std::string toString() const; + }; + explicit CHCSmtLib2Interface( std::map const& _queryResponses = {}, frontend::ReadCallback::Callback _smtCallback = {}, @@ -74,6 +82,10 @@ private: /// Communicates with the solver via the callback. Throws SMTSolverError on error. std::string querySolver(std::string const& _input); + CexGraph graphFromZ3Proof(std::string const& _proof); + + CexGraph graphFromSMTLib2Expression(SMTLib2Expression const& _proof); + /// Used to access toSmtLibSort, SExpr, and handle variables. std::unique_ptr m_smtlib2; @@ -87,6 +99,7 @@ private: SMTSolverChoice m_enabledSolvers; std::map m_sortNames; + std::map m_knownSorts; }; }