From 8c99c125c429fbe17c74b7c3fe3c9e21480f2441 Mon Sep 17 00:00:00 2001 From: Martin Blicha Date: Thu, 10 Aug 2023 20:14:42 +0200 Subject: [PATCH] Add support for parsing invariants --- libsmtutil/CHCSmtLib2Interface.cpp | 531 +++++++++++++++++++---------- libsmtutil/CHCSmtLib2Interface.h | 2 +- 2 files changed, 358 insertions(+), 175 deletions(-) diff --git a/libsmtutil/CHCSmtLib2Interface.cpp b/libsmtutil/CHCSmtLib2Interface.cpp index 245f36190..154158a1e 100644 --- a/libsmtutil/CHCSmtLib2Interface.cpp +++ b/libsmtutil/CHCSmtLib2Interface.cpp @@ -106,7 +106,10 @@ std::tuple CHCSmtLib2Inte CheckResult result; // TODO proper parsing if (boost::starts_with(response, "sat")) + { result = CheckResult::UNSATISFIABLE; + return {result, invariantsFromSMTLib(response), {}}; + } else if (boost::starts_with(response, "unsat")) { result = CheckResult::SATISFIABLE; @@ -199,7 +202,9 @@ std::string CHCSmtLib2Interface::querySolver(std::string const& _input) return "z3 rlimit=1000000 fp.spacer.q3.use_qgen=true fp.spacer.mbqi=false fp.spacer.ground_pobs=false"; return ""; }(); - auto result = m_smtCallback(ReadCallback::kindString(ReadCallback::Kind::SMTQuery) + " " + solverBinary, _input); + std::string z3Input = _input + "(get-model)\n"; + auto const& query = boost::starts_with(solverBinary, "z3") ? z3Input : _input; + auto result = m_smtCallback(ReadCallback::kindString(ReadCallback::Kind::SMTQuery) + " " + solverBinary, query); if (result.success) { if (m_enabledSolvers.z3 and boost::starts_with(result.responseOrErrorMessage, "unsat")) @@ -293,104 +298,6 @@ namespace return std::get(expr.data); } - class SMTLibTranslationContext - { - SMTLib2Interface const& m_smtlib2Interface; - - public: - SMTLibTranslationContext(SMTLib2Interface const& _smtlib2Interface) : m_smtlib2Interface(_smtlib2Interface) {} - - std::optional lookupKnownTupleSort(std::string const& name) { - std::string quotedName = "|" + name + "|"; - auto it = ranges::find_if(m_smtlib2Interface.sortNames(), [&](auto const& entry) { - return entry.second == name || entry.second == quotedName; - }); - if (it != m_smtlib2Interface.sortNames().end() && it->first->kind == Kind::Tuple) - { - auto tupleSort = std::dynamic_pointer_cast(it->first); - smtAssert(tupleSort); - return tupleSort; - } - return {}; - } - - SortPointer toSort(SMTLib2Expression const& expr) - { - if (isAtom(expr)) - { - auto const& name = asAtom(expr); - if (name == "Int") - return SortProvider::sintSort; - if (name == "Bool") - return SortProvider::boolSort; - auto tupleSort = lookupKnownTupleSort(name); - if (tupleSort) - return tupleSort.value(); - } else { - auto const& args = asSubExpressions(expr); - if (asAtom(args[0]) == "Array") - { - assert(args.size() == 3); - auto domainSort = toSort(args[1]); - auto codomainSort = toSort(args[2]); - return std::make_shared(std::move(domainSort), std::move(codomainSort)); - } - } - smtAssert(false, "Unknown sort encountered"); - } - - smtutil::Expression toSMTUtilExpression(SMTLib2Expression const& _expr) - { - return std::visit(GenericVisitor{ - [&](std::string const& _atom) { - if (_atom == "true" || _atom == "false") - return smtutil::Expression(_atom == "true"); - else if (isNumber(_atom)) - return smtutil::Expression(_atom, {}, SortProvider::sintSort); - else - return smtutil::Expression(_atom, {}, SortProvider::boolSort); - }, - [&](std::vector const& _subExpr) { - SortPointer sort; - std::vector arguments; - if (isAtom(_subExpr.front())) - { - for (size_t i = 1; i < _subExpr.size(); i++) - arguments.emplace_back(toSMTUtilExpression(_subExpr[i])); - std::string const& op = asAtom(_subExpr.front()); - if (auto tupleSort = lookupKnownTupleSort(op); tupleSort) - { - auto sortSort = std::make_shared(tupleSort.value()); - return Expression::tuple_constructor(Expression(sortSort), arguments); - } else { - std::set boolOperators{"and", "or", "not", "=", "<", ">", "<=", ">=", - "=>"}; - sort = contains(boolOperators, op) ? SortProvider::boolSort : arguments.back().sort; - return smtutil::Expression(op, std::move(arguments), std::move(sort)); - } - smtAssert(false, "Unhandled case in expression conversion"); - } else { - // check for const array - if (_subExpr.size() == 2 and !isAtom(_subExpr[0])) - { - auto const& typeArgs = asSubExpressions(_subExpr.front()); - if (typeArgs.size() == 3 && typeArgs[0].toString() == "as" && typeArgs[1].toString() == "const") - { - auto arraySort = toSort(typeArgs[2]); - auto sortSort = std::make_shared(arraySort); - return smtutil::Expression::const_array(Expression(sortSort), toSMTUtilExpression(_subExpr[1])); - } - } - - smtAssert(false, "Unhandled case in expression conversion"); - } - } - }, _expr.data); - } - }; - - - class SMTLib2Parser { public: @@ -476,6 +383,213 @@ namespace char m_token = 0; }; + class SMTLibTranslationContext + { + SMTLib2Interface const& m_smtlib2Interface; + std::map knownVariables; + + public: + SMTLibTranslationContext(SMTLib2Interface const& _smtlib2Interface) : m_smtlib2Interface(_smtlib2Interface) + { + // fill user defined sorts and constructors + auto const& userSorts = _smtlib2Interface.userSorts(); // TODO: It would be better to remember userSorts as SortPointers + for (auto const& [_, definition] : userSorts) + { + std::stringstream ss(definition); + SMTLib2Parser parser(ss); + auto expr = parser.parseExpression(); + solAssert(parser.isEOF()); + solAssert(!isAtom(expr)); + auto const& args = asSubExpressions(expr); + solAssert(args.size() == 3); + solAssert(isAtom(args[0]) && asAtom(args[0]) == "declare-datatypes"); + // args[1] is the name of the type + // args[2] is the constructor with the members + solAssert(!isAtom(args[2]) && asSubExpressions(args[2]).size() == 1 && !isAtom(asSubExpressions(args[2])[0])); + auto const& constructors = asSubExpressions(asSubExpressions(args[2])[0]); + solAssert(constructors.size() == 1); + auto const& constructor = constructors[0]; + // constructor is a list: name + members + solAssert(!isAtom(constructor)); + auto const& constructorArgs = asSubExpressions(constructor); + for (unsigned i = 1u; i < constructorArgs.size(); ++i) + { + auto const& carg = constructorArgs[i]; + solAssert(!isAtom(carg) && asSubExpressions(carg).size() == 2); + auto const& nameSortPair = asSubExpressions(carg); + solAssert(isAtom(nameSortPair[0])); + knownVariables.insert({asAtom(nameSortPair[0]), toSort(nameSortPair[1])}); + } + } + } + + void addVariableDeclaration(std::string name, SortPointer sort) + { + solAssert(knownVariables.find(name) == knownVariables.end()); + knownVariables.insert({std::move(name), std::move(sort)}); + } + + std::optional lookupKnownTupleSort(std::string const& name) { + std::string quotedName = "|" + name + "|"; + auto it = ranges::find_if(m_smtlib2Interface.sortNames(), [&](auto const& entry) { + return entry.second == name || entry.second == quotedName; + }); + if (it != m_smtlib2Interface.sortNames().end() && it->first->kind == Kind::Tuple) + { + auto tupleSort = std::dynamic_pointer_cast(it->first); + smtAssert(tupleSort); + return tupleSort; + } + return {}; + } + + SortPointer toSort(SMTLib2Expression const& expr) + { + if (isAtom(expr)) + { + auto const& name = asAtom(expr); + if (name == "Int") + return SortProvider::sintSort; + if (name == "Bool") + return SortProvider::boolSort; + auto tupleSort = lookupKnownTupleSort(name); + if (tupleSort) + return tupleSort.value(); + } else { + auto const& args = asSubExpressions(expr); + if (asAtom(args[0]) == "Array") + { + assert(args.size() == 3); + auto domainSort = toSort(args[1]); + auto codomainSort = toSort(args[2]); + return std::make_shared(std::move(domainSort), std::move(codomainSort)); + } + if (args.size() == 3 && isAtom(args[0]) && asAtom(args[0]) == "_" && isAtom(args[1]) && asAtom(args[1]) == "int2bv") + { + solAssert(isAtom(args[2])); + return std::make_shared(std::stoul(asAtom(args[2]))); + } + } + smtAssert(false, "Unknown sort encountered"); + } + + smtutil::Expression parseQuantifier( + std::string const& quantifierName, + std::vector const& varList, + SMTLib2Expression const& coreExpression + ) + { + std::vector> boundVariables; + for (auto const& sortedVar: varList) + { + solAssert(!isAtom(sortedVar)); + auto varSortPair = asSubExpressions(sortedVar); + solAssert(varSortPair.size() == 2); + solAssert(isAtom(varSortPair[0])); + boundVariables.emplace_back(asAtom(varSortPair[0]), toSort(varSortPair[1])); + } + for (auto const& [var, sort] : boundVariables) + { + solAssert(knownVariables.find(var) == knownVariables.end()); // TODO: deal with shadowing? + knownVariables.insert({var, sort}); + } + auto core = toSMTUtilExpression(coreExpression); + for (auto const& [var, sort] : boundVariables) + { + solAssert(knownVariables.find(var) != knownVariables.end()); + knownVariables.erase(var); + } + return Expression(quantifierName, {core}, SortProvider::boolSort); // TODO: what about the bound variables? + + } + + smtutil::Expression toSMTUtilExpression(SMTLib2Expression const& _expr) + { + return std::visit(GenericVisitor{ + [&](std::string const& _atom) { + if (_atom == "true" || _atom == "false") + return smtutil::Expression(_atom == "true"); + else if (isNumber(_atom)) + return smtutil::Expression(_atom, {}, SortProvider::sintSort); + else if (knownVariables.find(_atom) != knownVariables.end()) + return smtutil::Expression(_atom, {}, knownVariables.at(_atom)); + else // assume this is a predicate, so has sort bool; TODO: Context should be aware of the predicates! + return smtutil::Expression(_atom, {}, SortProvider::boolSort); + }, + [&](std::vector const& _subExpr) { + SortPointer sort; + std::vector arguments; + if (isAtom(_subExpr.front())) + { + std::string const& op = asAtom(_subExpr.front()); + if (op == "!") + { + // named term, we ignore the name + solAssert(_subExpr.size() > 2); + return toSMTUtilExpression(_subExpr[1]); + } + if (op == "exists" || op == "forall") + { + solAssert(_subExpr.size() == 3); + solAssert(!isAtom(_subExpr[1])); + return parseQuantifier(op, asSubExpressions(_subExpr[1]), _subExpr[2]); + } + for (size_t i = 1; i < _subExpr.size(); i++) + arguments.emplace_back(toSMTUtilExpression(_subExpr[i])); + if (auto tupleSort = lookupKnownTupleSort(op); tupleSort) + { + auto sortSort = std::make_shared(tupleSort.value()); + return Expression::tuple_constructor(Expression(sortSort), arguments); + } + if (knownVariables.find(op) != knownVariables.end()) + { + return smtutil::Expression(op, std::move(arguments), knownVariables.at(op)); + } + else { + std::set boolOperators{"and", "or", "not", "=", "<", ">", "<=", ">=", + "=>"}; + sort = contains(boolOperators, op) ? SortProvider::boolSort : arguments.back().sort; + return smtutil::Expression(op, std::move(arguments), std::move(sort)); + } + smtAssert(false, "Unhandled case in expression conversion"); + } else { + // check for const array + if (_subExpr.size() == 2 and !isAtom(_subExpr[0])) + { + auto const& typeArgs = asSubExpressions(_subExpr.front()); + if (typeArgs.size() == 3 && typeArgs[0].toString() == "as" && typeArgs[1].toString() == "const") + { + auto arraySort = toSort(typeArgs[2]); + auto sortSort = std::make_shared(arraySort); + return smtutil::Expression::const_array(Expression(sortSort), toSMTUtilExpression(_subExpr[1])); + } + if (typeArgs.size() == 3 && typeArgs[0].toString() == "_" && typeArgs[1].toString() == "int2bv") + { + auto bvSort = std::dynamic_pointer_cast(toSort(_subExpr[0])); + solAssert(bvSort); + return smtutil::Expression::int2bv(toSMTUtilExpression(_subExpr[1]), bvSort->size); + } + if (typeArgs.size() == 4 && typeArgs[0].toString() == "_") + { + if (typeArgs[1].toString() == "extract") + { + return smtutil::Expression( + "extract", + {toSMTUtilExpression(typeArgs[2]), toSMTUtilExpression(typeArgs[3])}, + SortProvider::bitVectorSort // TODO: Compute bit size properly? + ); + } + } + } + + smtAssert(false, "Unhandled case in expression conversion"); + } + } + }, _expr.data); + } + }; + + struct LetBindings { using BindingRecord = std::vector; std::unordered_map bindings; @@ -592,91 +706,92 @@ namespace solAssert(!asSubExpressions(_node).empty()); return asSubExpressions(_node).back(); } -} -CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMTLib2Expression const& _proof) -{ - assert(!isAtom(_proof)); - auto const& args = asSubExpressions(_proof); - smtAssert(args.size() == 2); - smtAssert(isAtom(args.at(0)) && asAtom(args.at(0)) == "proof"); - auto const& proofNode = args.at(1); - auto derivedFact = fact(proofNode); - if (isAtom(proofNode) || !isAtom(derivedFact) || asAtom(derivedFact) != "false") - return {}; - - CHCSolverInterface::CexGraph graph; - SMTLibTranslationContext context(*m_smtlib2); - - std::stack proofStack; - proofStack.push(&asSubExpressions(proofNode).at(1)); - - std::map visitedIds; - unsigned nextId = 0; - - - auto const* root = proofStack.top(); - auto const& derivedRootFact = fact(*root); - visitedIds.insert({root, nextId++}); - graph.nodes.emplace(visitedIds.at(root), context.toSMTUtilExpression(derivedRootFact)); - - auto isHyperRes = [](SMTLib2Expression const& expr) { - if (isAtom(expr)) return false; - auto const& subExprs = asSubExpressions(expr); - solAssert(!subExprs.empty()); - auto const& op = subExprs.at(0); - if (isAtom(op)) return false; - auto const& opExprs = asSubExpressions(op); - if (opExprs.size() < 2) return false; - auto const& ruleName = opExprs.at(1); - return isAtom(ruleName) && asAtom(ruleName) == "hyper-res"; - }; - - while (!proofStack.empty()) + CHCSolverInterface::CexGraph graphFromSMTLib2Expression(SMTLib2Expression const& _proof, SMTLibTranslationContext & context) { - auto const* proofNode = proofStack.top(); - smtAssert(visitedIds.find(proofNode) != visitedIds.end(), ""); - auto id = visitedIds.at(proofNode); - smtAssert(graph.nodes.count(id), ""); - proofStack.pop(); + assert(!isAtom(_proof)); + auto const& args = asSubExpressions(_proof); + smtAssert(args.size() == 2); + smtAssert(isAtom(args.at(0)) && asAtom(args.at(0)) == "proof"); + auto const& proofNode = args.at(1); + auto derivedFact = fact(proofNode); + if (isAtom(proofNode) || !isAtom(derivedFact) || asAtom(derivedFact) != "false") + return {}; - if (isHyperRes(*proofNode)) + CHCSolverInterface::CexGraph graph; + + std::stack proofStack; + proofStack.push(&asSubExpressions(proofNode).at(1)); + + std::map visitedIds; + unsigned nextId = 0; + + + auto const* root = proofStack.top(); + auto const& derivedRootFact = fact(*root); + visitedIds.insert({root, nextId++}); + graph.nodes.emplace(visitedIds.at(root), context.toSMTUtilExpression(derivedRootFact)); + + auto isHyperRes = [](SMTLib2Expression const& expr) { + if (isAtom(expr)) return false; + auto const& subExprs = asSubExpressions(expr); + solAssert(!subExprs.empty()); + auto const& op = subExprs.at(0); + if (isAtom(op)) return false; + auto const& opExprs = asSubExpressions(op); + if (opExprs.size() < 2) return false; + auto const& ruleName = opExprs.at(1); + return isAtom(ruleName) && asAtom(ruleName) == "hyper-res"; + }; + + while (!proofStack.empty()) { - auto const& args = asSubExpressions(*proofNode); - smtAssert(args.size() > 1, ""); - // args[0] is the name of the rule - // args[1] is the clause used - // last argument is the derived fact - // the arguments in the middle are the facts where we need to recurse - for (unsigned i = 2; i < args.size() - 1; ++i) + auto const* proofNode = proofStack.top(); + smtAssert(visitedIds.find(proofNode) != visitedIds.end(), ""); + auto id = visitedIds.at(proofNode); + smtAssert(graph.nodes.count(id), ""); + proofStack.pop(); + + if (isHyperRes(*proofNode)) { - auto const* child = &args[i]; - if (!visitedIds.count(child)) + auto const& args = asSubExpressions(*proofNode); + smtAssert(args.size() > 1, ""); + // args[0] is the name of the rule + // args[1] is the clause used + // last argument is the derived fact + // the arguments in the middle are the facts where we need to recurse + for (unsigned i = 2; i < args.size() - 1; ++i) { - visitedIds.insert({child, nextId++}); - proofStack.push(child); - } + auto const* child = &args[i]; + if (!visitedIds.count(child)) + { + visitedIds.insert({child, nextId++}); + proofStack.push(child); + } - auto childId = visitedIds.at(child); - if (!graph.nodes.count(childId)) - { - graph.nodes.emplace(childId, context.toSMTUtilExpression(fact(*child))); - graph.edges[childId] = {}; - } + auto childId = visitedIds.at(child); + if (!graph.nodes.count(childId)) + { + graph.nodes.emplace(childId, context.toSMTUtilExpression(fact(*child))); + graph.edges[childId] = {}; + } - graph.edges[id].push_back(childId); + graph.edges[id].push_back(childId); + } } } + return graph; } - return graph; } + CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromZ3Proof(std::string const& _proof) { std::stringstream ss(_proof); std::string answer; ss >> answer; solAssert(answer == "unsat"); + SMTLib2Parser parser(ss); if (parser.isEOF()) // No proof from Z3 return {}; @@ -685,19 +800,87 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromZ3Proof(std::string c solAssert(parser.isEOF()); solAssert(!isAtom(all)); auto& commands = std::get(all.data); - for (auto& command: commands) { -// std::cout << command.toString() << '\n' << std::endl; - if (!isAtom(command)) + SMTLibTranslationContext context(*m_smtlib2); + for (auto& command: commands) + { + if (isAtom(command)) + continue; + + auto const& args = asSubExpressions(command); + auto const& head = args[0]; + if (!isAtom(head)) + continue; + + if (asAtom(head) == "declare-fun") + { + solAssert(args.size() == 4); + auto const& name = args[1]; + auto const& domainSorts = args[2]; + auto const& codomainSort = args[3]; + solAssert(isAtom(name)); + solAssert(!isAtom(domainSorts)); + context.addVariableDeclaration(asAtom(name), context.toSort(codomainSort)); + } + else if (asAtom(head) == "proof") { - auto const& head = std::get(command.data)[0]; - if (isAtom(head) && std::get(head.data) == "proof") - { // std::cout << "Proof expression!\n" << command.toString() << std::endl; - inlineLetExpressions(command); + inlineLetExpressions(command); // std::cout << "Cleaned Proof expression!\n" << command.toString() << std::endl; - return graphFromSMTLib2Expression(command); - } + return graphFromSMTLib2Expression(command, context); } } return {}; } + +smtutil::Expression CHCSmtLib2Interface::invariantsFromSMTLib(std::string const& _invariants) { + std::stringstream ss(_invariants); + std::string answer; + ss >> answer; + solAssert(answer == "sat"); + SMTLib2Parser parser(ss); + if (parser.isEOF()) // No model + return Expression(true); + // For some reason Z3 outputs everything as a single s-expression + auto all = parser.parseExpression(); + solAssert(parser.isEOF()); + solAssert(!isAtom(all)); + auto& commands = std::get(all.data); + std::vector definitions; + for (auto& command: commands) + { + solAssert(!isAtom(command)); + auto& args = asSubExpressions(command); + solAssert(args.size() == 5); + // args[0] = "define-fun" + // args[1] = predicate name + // args[2] = formal arguments of the predicate + // args[3] = return sort + // args[4] = body of the predicate's interpretation + solAssert(isAtom(args[0]) && asAtom(args[0]) == "define-fun"); + solAssert(isAtom(args[1])); + solAssert(!isAtom(args[2])); + solAssert(isAtom(args[3]) && asAtom(args[3]) == "Bool"); + auto& interpretation = args[4]; + inlineLetExpressions(interpretation); + SMTLibTranslationContext context(*m_smtlib2); + auto const& formalArguments = asSubExpressions(args[2]); + std::vector predicateArgs; + for (auto const& formalArgument: formalArguments) + { + solAssert(!isAtom(formalArgument)); + auto const& nameSortPair = asSubExpressions(formalArgument); + solAssert(nameSortPair.size() == 2); + solAssert(isAtom(nameSortPair[0])); + SortPointer varSort = context.toSort(nameSortPair[1]); + context.addVariableDeclaration(asAtom(nameSortPair[0]), varSort); + Expression arg = context.toSMTUtilExpression(nameSortPair[0]); + predicateArgs.push_back(arg); + } + + auto parsedInterpretation = context.toSMTUtilExpression(interpretation); + + Expression predicate(asAtom(args[1]), predicateArgs, SortProvider::boolSort); + definitions.push_back(predicate == parsedInterpretation); + } + return Expression::mkAnd(std::move(definitions)); +} diff --git a/libsmtutil/CHCSmtLib2Interface.h b/libsmtutil/CHCSmtLib2Interface.h index dd8f4342e..99100d663 100644 --- a/libsmtutil/CHCSmtLib2Interface.h +++ b/libsmtutil/CHCSmtLib2Interface.h @@ -84,7 +84,7 @@ private: CexGraph graphFromZ3Proof(std::string const& _proof); - CexGraph graphFromSMTLib2Expression(SMTLib2Expression const& _proof); + smtutil::Expression invariantsFromSMTLib(std::string const& _invariants); /// Used to access toSmtLibSort, SExpr, and handle variables. std::unique_ptr m_smtlib2;