From 8ea8a1eb998c168fdf13a133c5822ddf01552379 Mon Sep 17 00:00:00 2001 From: Martin Blicha Date: Fri, 21 Jul 2023 15:40:00 +0200 Subject: [PATCH] Cache sorts already in SMTLib2Interface This allows us to ask for a sort of a sort from its string representation parsed from an SMT-LIB solver response --- libsmtutil/CHCSmtLib2Interface.cpp | 159 +++++++++++++++-------------- libsmtutil/CHCSmtLib2Interface.h | 3 - libsmtutil/SMTLib2Interface.cpp | 11 ++ libsmtutil/SMTLib2Interface.h | 8 ++ 4 files changed, 100 insertions(+), 81 deletions(-) diff --git a/libsmtutil/CHCSmtLib2Interface.cpp b/libsmtutil/CHCSmtLib2Interface.cpp index 63bc03165..f7c3ed286 100644 --- a/libsmtutil/CHCSmtLib2Interface.cpp +++ b/libsmtutil/CHCSmtLib2Interface.cpp @@ -29,6 +29,8 @@ #include #include +#include +#include #include #include @@ -38,7 +40,6 @@ #include #include #include -#include using namespace solidity; using namespace solidity::util; @@ -66,8 +67,6 @@ void CHCSmtLib2Interface::reset() m_accumulatedOutput.clear(); m_variables.clear(); m_unhandledQueries.clear(); - m_sortNames.clear(); - m_knownSorts.clear(); } void CHCSmtLib2Interface::registerRelation(Expression const& _expr) @@ -134,22 +133,12 @@ void CHCSmtLib2Interface::declareVariable(std::string const& _name, SortPointer std::string CHCSmtLib2Interface::toSmtLibSort(Sort const& _sort) { - if (!m_sortNames.count(&_sort)) - { - auto smtLibName = m_smtlib2->toSmtLibSort(_sort); - m_sortNames[&_sort] = smtLibName; - m_knownSorts[smtLibName] = &_sort; - } - return m_sortNames.at(&_sort); + return m_smtlib2->toSmtLibSort(_sort); } std::string CHCSmtLib2Interface::toSmtLibSort(std::vector const& _sorts) { - std::string ssort("("); - for (auto const& sort: _sorts) - ssort += toSmtLibSort(*sort) + " "; - ssort += ")"; - return ssort; + return m_smtlib2->toSmtLibSort(_sorts); } std::string CHCSmtLib2Interface::forall() @@ -295,66 +284,81 @@ namespace { return std::get(expr.data); } - SortPointer toSort(SMTLib2Expression const& expr) + class SMTLibTranslationContext { - if (isAtom(expr)) { - auto const& name = asAtom(expr); - if (name == "Int") - return SortProvider::sintSort; - } else { - auto const& args = asSubExpressions(expr); - if (asAtom(args[0]) == "Array") { - assert(args.size() == 3); - auto domainSort = toSort(args[1]); - auto codomainSort = toSort(args[2]); - return std::make_shared(std::move(domainSort), std::move(codomainSort)); - } - } - // FIXME: This is not correct, we need to track sorts properly! - return SortProvider::boolSort; -// smtAssert(false, "Unknown sort encountered"); + SMTLib2Interface const& m_smtlib2Interface; - } + public: + SMTLibTranslationContext(SMTLib2Interface const& _smtlib2Interface) : m_smtlib2Interface(_smtlib2Interface) {} - smtutil::Expression toSMTUtilExpression(SMTLib2Expression const& _expr) - { - return std::visit(GenericVisitor{ - [&](std::string const& _atom) { - if (_atom == "true" || _atom == "false") - return smtutil::Expression(_atom == "true"); - else if (isNumber(_atom)) - return smtutil::Expression(_atom, {}, SortProvider::sintSort); - else - return smtutil::Expression(_atom, {}, SortProvider::boolSort); - }, - [&](std::vector const& _subExpr) { - SortPointer sort; - std::vector arguments; - if (isAtom(_subExpr.front())) { - std::string const &op = std::get(_subExpr.front().data); - std::set boolOperators{"and", "or", "not", "=", "<", ">", "<=", ">=", "=>"}; - for (size_t i = 1; i < _subExpr.size(); i++) - arguments.emplace_back(toSMTUtilExpression(_subExpr[i])); - sort = contains(boolOperators, op) ? SortProvider::boolSort : arguments.back().sort; - return smtutil::Expression(op, move(arguments), move(sort)); - } else { - // check for const array - if (_subExpr.size() == 2 and !isAtom(_subExpr[0])) - { - auto const& typeArgs = asSubExpressions(_subExpr.front()); - if (typeArgs.size() == 3 && typeArgs[0].toString() == "as" && typeArgs[1].toString() == "const") - { - auto arraySort = toSort(typeArgs[2]); - auto sortSort = std::make_shared(arraySort); - return smtutil::Expression::const_array(Expression(sortSort), toSMTUtilExpression(_subExpr[1])); - } - } - - smtAssert(false, "Unhandled case in expression conversion"); - } + SortPointer toSort(SMTLib2Expression const& expr) + { + if (isAtom(expr)) { + auto const& name = asAtom(expr); + if (name == "Int") + return SortProvider::sintSort; + if (name == "Bool") + return SortProvider::boolSort; + std::string quotedName = "|" + name + "|"; + auto it = ranges::find_if(m_smtlib2Interface.sortNames(), [&](auto const& entry) { + return entry.second == name || entry.second == quotedName; + }); + if (it != m_smtlib2Interface.sortNames().end()) + return std::make_shared(*it->first); + } else { + auto const& args = asSubExpressions(expr); + if (asAtom(args[0]) == "Array") { + assert(args.size() == 3); + auto domainSort = toSort(args[1]); + auto codomainSort = toSort(args[2]); + return std::make_shared(std::move(domainSort), std::move(codomainSort)); } - }, _expr.data); - } + } + // FIXME: This is not correct, we need to track sorts properly! +// return SortProvider::boolSort; + smtAssert(false, "Unknown sort encountered"); + } + + smtutil::Expression toSMTUtilExpression(SMTLib2Expression const& _expr) + { + return std::visit(GenericVisitor{ + [&](std::string const& _atom) { + if (_atom == "true" || _atom == "false") + return smtutil::Expression(_atom == "true"); + else if (isNumber(_atom)) + return smtutil::Expression(_atom, {}, SortProvider::sintSort); + else + return smtutil::Expression(_atom, {}, SortProvider::boolSort); + }, + [&](std::vector const& _subExpr) { + SortPointer sort; + std::vector arguments; + if (isAtom(_subExpr.front())) { + std::string const &op = std::get(_subExpr.front().data); + std::set boolOperators{"and", "or", "not", "=", "<", ">", "<=", ">=", "=>"}; + for (size_t i = 1; i < _subExpr.size(); i++) + arguments.emplace_back(toSMTUtilExpression(_subExpr[i])); + sort = contains(boolOperators, op) ? SortProvider::boolSort : arguments.back().sort; + return smtutil::Expression(op, move(arguments), move(sort)); + } else { + // check for const array + if (_subExpr.size() == 2 and !isAtom(_subExpr[0])) + { + auto const& typeArgs = asSubExpressions(_subExpr.front()); + if (typeArgs.size() == 3 && typeArgs[0].toString() == "as" && typeArgs[1].toString() == "const") + { + auto arraySort = toSort(typeArgs[2]); + auto sortSort = std::make_shared(arraySort); + return smtutil::Expression::const_array(Expression(sortSort), toSMTUtilExpression(_subExpr[1])); + } + } + + smtAssert(false, "Unhandled case in expression conversion"); + } + } + }, _expr.data); + } + }; @@ -563,6 +567,7 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMT return {}; CHCSolverInterface::CexGraph graph; + SMTLibTranslationContext context(*m_smtlib2); std::stack proofStack; proofStack.push(&asSubExpressions(proofNode).at(1)); @@ -572,11 +577,9 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMT auto const* root = proofStack.top(); - std::cout << root->toString() << std::endl; auto const& derivedRootFact = fact(*root); - std::cout << derivedRootFact.toString() << std::endl; visitedIds.insert({root, nextId++}); - graph.nodes.emplace(visitedIds.at(root), toSMTUtilExpression(derivedRootFact)); + graph.nodes.emplace(visitedIds.at(root), context.toSMTUtilExpression(derivedRootFact)); auto isHyperRes = [](SMTLib2Expression const& expr) { if (isAtom(expr)) return false; @@ -618,7 +621,7 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMT auto childId = visitedIds.at(child); if (!graph.nodes.count(childId)) { - graph.nodes.emplace(childId, toSMTUtilExpression(fact(*child))); + graph.nodes.emplace(childId, context.toSMTUtilExpression(fact(*child))); graph.edges[childId] = {}; } @@ -645,10 +648,10 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromZ3Proof(const std::st // std::cout << command.toString() << '\n' << std::endl; if (!isAtom(command)) { auto const& head = std::get(command.data)[0]; - if(isAtom(head) && std::get(head.data) == "proof") { - std::cout << "Proof expression!\n" << command.toString() << std::endl; + if (isAtom(head) && std::get(head.data) == "proof") { +// std::cout << "Proof expression!\n" << command.toString() << std::endl; inlineLetExpressions(command); - std::cout << "Cleaned Proof expression!\n" << command.toString() << std::endl; +// std::cout << "Cleaned Proof expression!\n" << command.toString() << std::endl; return graphFromSMTLib2Expression(command); } } diff --git a/libsmtutil/CHCSmtLib2Interface.h b/libsmtutil/CHCSmtLib2Interface.h index 75225cc3d..2f71890f0 100644 --- a/libsmtutil/CHCSmtLib2Interface.h +++ b/libsmtutil/CHCSmtLib2Interface.h @@ -97,9 +97,6 @@ private: frontend::ReadCallback::Callback m_smtCallback; SMTSolverChoice m_enabledSolvers; - - std::map m_sortNames; - std::map m_knownSorts; }; } diff --git a/libsmtutil/SMTLib2Interface.cpp b/libsmtutil/SMTLib2Interface.cpp index 5746dbe4e..d125890a5 100644 --- a/libsmtutil/SMTLib2Interface.cpp +++ b/libsmtutil/SMTLib2Interface.cpp @@ -57,6 +57,7 @@ void SMTLib2Interface::reset() m_accumulatedOutput.emplace_back(); m_variables.clear(); m_userSorts.clear(); + m_sortNames.clear(); write("(set-option :produce-models true)"); if (m_queryTimeout) write("(set-option :timeout " + std::to_string(*m_queryTimeout) + ")"); @@ -276,6 +277,16 @@ std::string SMTLib2Interface::toSExpr(Expression const& _expr) } std::string SMTLib2Interface::toSmtLibSort(Sort const& _sort) +{ + if (!m_sortNames.count(&_sort)) + { + auto smtLibName = sortToString(_sort); + m_sortNames[&_sort] = smtLibName; + } + return m_sortNames.at(&_sort); +} + +std::string SMTLib2Interface::sortToString(Sort const& _sort) { switch (_sort.kind) { diff --git a/libsmtutil/SMTLib2Interface.h b/libsmtutil/SMTLib2Interface.h index 8bff1300a..12c5b7e1c 100644 --- a/libsmtutil/SMTLib2Interface.h +++ b/libsmtutil/SMTLib2Interface.h @@ -69,9 +69,13 @@ public: std::vector> const& userSorts() const { return m_userSorts; } + auto const& sortNames() const { return m_sortNames; } + std::string dumpQuery(std::vector const& _expressionsToEvaluate); private: + std::string sortToString(Sort const& _sort); + void declareFunction(std::string const& _name, SortPointer const& _sort); void write(std::string _data); @@ -86,6 +90,10 @@ private: /// It needs to be a vector so that the declaration order is kept, /// otherwise solvers cannot parse the queries. std::vector> m_userSorts; + // TODO: Should this remember shared_pointer? + // TODO: Shouldn't sorts be unique objects? + std::map m_sortNames; + std::vector m_unhandledQueries;