Cache sorts already in SMTLib2Interface

This allows us to ask for a sort of a sort from its string
representation parsed from an SMT-LIB solver response
This commit is contained in:
Martin Blicha 2023-07-21 15:40:00 +02:00
parent 6acbe2ec35
commit 8ea8a1eb99
4 changed files with 100 additions and 81 deletions

View File

@ -29,6 +29,8 @@
#include <boost/algorithm/string/predicate.hpp> #include <boost/algorithm/string/predicate.hpp>
#include <range/v3/view.hpp> #include <range/v3/view.hpp>
#include <range/v3/algorithm/find.hpp>
#include <range/v3/algorithm/find_if.hpp>
#include <array> #include <array>
#include <fstream> #include <fstream>
@ -38,7 +40,6 @@
#include <stdexcept> #include <stdexcept>
#include <unordered_map> #include <unordered_map>
#include <variant> #include <variant>
#include <range/v3/algorithm/find.hpp>
using namespace solidity; using namespace solidity;
using namespace solidity::util; using namespace solidity::util;
@ -66,8 +67,6 @@ void CHCSmtLib2Interface::reset()
m_accumulatedOutput.clear(); m_accumulatedOutput.clear();
m_variables.clear(); m_variables.clear();
m_unhandledQueries.clear(); m_unhandledQueries.clear();
m_sortNames.clear();
m_knownSorts.clear();
} }
void CHCSmtLib2Interface::registerRelation(Expression const& _expr) void CHCSmtLib2Interface::registerRelation(Expression const& _expr)
@ -134,22 +133,12 @@ void CHCSmtLib2Interface::declareVariable(std::string const& _name, SortPointer
std::string CHCSmtLib2Interface::toSmtLibSort(Sort const& _sort) std::string CHCSmtLib2Interface::toSmtLibSort(Sort const& _sort)
{ {
if (!m_sortNames.count(&_sort)) return m_smtlib2->toSmtLibSort(_sort);
{
auto smtLibName = m_smtlib2->toSmtLibSort(_sort);
m_sortNames[&_sort] = smtLibName;
m_knownSorts[smtLibName] = &_sort;
}
return m_sortNames.at(&_sort);
} }
std::string CHCSmtLib2Interface::toSmtLibSort(std::vector<SortPointer> const& _sorts) std::string CHCSmtLib2Interface::toSmtLibSort(std::vector<SortPointer> const& _sorts)
{ {
std::string ssort("("); return m_smtlib2->toSmtLibSort(_sorts);
for (auto const& sort: _sorts)
ssort += toSmtLibSort(*sort) + " ";
ssort += ")";
return ssort;
} }
std::string CHCSmtLib2Interface::forall() std::string CHCSmtLib2Interface::forall()
@ -295,12 +284,27 @@ namespace {
return std::get<SMTLib2Expression::args_t>(expr.data); return std::get<SMTLib2Expression::args_t>(expr.data);
} }
class SMTLibTranslationContext
{
SMTLib2Interface const& m_smtlib2Interface;
public:
SMTLibTranslationContext(SMTLib2Interface const& _smtlib2Interface) : m_smtlib2Interface(_smtlib2Interface) {}
SortPointer toSort(SMTLib2Expression const& expr) SortPointer toSort(SMTLib2Expression const& expr)
{ {
if (isAtom(expr)) { if (isAtom(expr)) {
auto const& name = asAtom(expr); auto const& name = asAtom(expr);
if (name == "Int") if (name == "Int")
return SortProvider::sintSort; return SortProvider::sintSort;
if (name == "Bool")
return SortProvider::boolSort;
std::string quotedName = "|" + name + "|";
auto it = ranges::find_if(m_smtlib2Interface.sortNames(), [&](auto const& entry) {
return entry.second == name || entry.second == quotedName;
});
if (it != m_smtlib2Interface.sortNames().end())
return std::make_shared<Sort>(*it->first);
} else { } else {
auto const& args = asSubExpressions(expr); auto const& args = asSubExpressions(expr);
if (asAtom(args[0]) == "Array") { if (asAtom(args[0]) == "Array") {
@ -311,9 +315,8 @@ namespace {
} }
} }
// FIXME: This is not correct, we need to track sorts properly! // FIXME: This is not correct, we need to track sorts properly!
return SortProvider::boolSort; // return SortProvider::boolSort;
// smtAssert(false, "Unknown sort encountered"); smtAssert(false, "Unknown sort encountered");
} }
smtutil::Expression toSMTUtilExpression(SMTLib2Expression const& _expr) smtutil::Expression toSMTUtilExpression(SMTLib2Expression const& _expr)
@ -355,6 +358,7 @@ namespace {
} }
}, _expr.data); }, _expr.data);
} }
};
@ -563,6 +567,7 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMT
return {}; return {};
CHCSolverInterface::CexGraph graph; CHCSolverInterface::CexGraph graph;
SMTLibTranslationContext context(*m_smtlib2);
std::stack<SMTLib2Expression const*> proofStack; std::stack<SMTLib2Expression const*> proofStack;
proofStack.push(&asSubExpressions(proofNode).at(1)); proofStack.push(&asSubExpressions(proofNode).at(1));
@ -572,11 +577,9 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMT
auto const* root = proofStack.top(); auto const* root = proofStack.top();
std::cout << root->toString() << std::endl;
auto const& derivedRootFact = fact(*root); auto const& derivedRootFact = fact(*root);
std::cout << derivedRootFact.toString() << std::endl;
visitedIds.insert({root, nextId++}); visitedIds.insert({root, nextId++});
graph.nodes.emplace(visitedIds.at(root), toSMTUtilExpression(derivedRootFact)); graph.nodes.emplace(visitedIds.at(root), context.toSMTUtilExpression(derivedRootFact));
auto isHyperRes = [](SMTLib2Expression const& expr) { auto isHyperRes = [](SMTLib2Expression const& expr) {
if (isAtom(expr)) return false; if (isAtom(expr)) return false;
@ -618,7 +621,7 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromSMTLib2Expression(SMT
auto childId = visitedIds.at(child); auto childId = visitedIds.at(child);
if (!graph.nodes.count(childId)) if (!graph.nodes.count(childId))
{ {
graph.nodes.emplace(childId, toSMTUtilExpression(fact(*child))); graph.nodes.emplace(childId, context.toSMTUtilExpression(fact(*child)));
graph.edges[childId] = {}; graph.edges[childId] = {};
} }
@ -645,10 +648,10 @@ CHCSolverInterface::CexGraph CHCSmtLib2Interface::graphFromZ3Proof(const std::st
// std::cout << command.toString() << '\n' << std::endl; // std::cout << command.toString() << '\n' << std::endl;
if (!isAtom(command)) { if (!isAtom(command)) {
auto const& head = std::get<SMTLib2Expression::args_t>(command.data)[0]; auto const& head = std::get<SMTLib2Expression::args_t>(command.data)[0];
if(isAtom(head) && std::get<std::string>(head.data) == "proof") { if (isAtom(head) && std::get<std::string>(head.data) == "proof") {
std::cout << "Proof expression!\n" << command.toString() << std::endl; // std::cout << "Proof expression!\n" << command.toString() << std::endl;
inlineLetExpressions(command); inlineLetExpressions(command);
std::cout << "Cleaned Proof expression!\n" << command.toString() << std::endl; // std::cout << "Cleaned Proof expression!\n" << command.toString() << std::endl;
return graphFromSMTLib2Expression(command); return graphFromSMTLib2Expression(command);
} }
} }

View File

@ -97,9 +97,6 @@ private:
frontend::ReadCallback::Callback m_smtCallback; frontend::ReadCallback::Callback m_smtCallback;
SMTSolverChoice m_enabledSolvers; SMTSolverChoice m_enabledSolvers;
std::map<Sort const*, std::string> m_sortNames;
std::map<std::string, Sort const*> m_knownSorts;
}; };
} }

View File

@ -57,6 +57,7 @@ void SMTLib2Interface::reset()
m_accumulatedOutput.emplace_back(); m_accumulatedOutput.emplace_back();
m_variables.clear(); m_variables.clear();
m_userSorts.clear(); m_userSorts.clear();
m_sortNames.clear();
write("(set-option :produce-models true)"); write("(set-option :produce-models true)");
if (m_queryTimeout) if (m_queryTimeout)
write("(set-option :timeout " + std::to_string(*m_queryTimeout) + ")"); write("(set-option :timeout " + std::to_string(*m_queryTimeout) + ")");
@ -276,6 +277,16 @@ std::string SMTLib2Interface::toSExpr(Expression const& _expr)
} }
std::string SMTLib2Interface::toSmtLibSort(Sort const& _sort) std::string SMTLib2Interface::toSmtLibSort(Sort const& _sort)
{
if (!m_sortNames.count(&_sort))
{
auto smtLibName = sortToString(_sort);
m_sortNames[&_sort] = smtLibName;
}
return m_sortNames.at(&_sort);
}
std::string SMTLib2Interface::sortToString(Sort const& _sort)
{ {
switch (_sort.kind) switch (_sort.kind)
{ {

View File

@ -69,9 +69,13 @@ public:
std::vector<std::pair<std::string, std::string>> const& userSorts() const { return m_userSorts; } std::vector<std::pair<std::string, std::string>> const& userSorts() const { return m_userSorts; }
auto const& sortNames() const { return m_sortNames; }
std::string dumpQuery(std::vector<Expression> const& _expressionsToEvaluate); std::string dumpQuery(std::vector<Expression> const& _expressionsToEvaluate);
private: private:
std::string sortToString(Sort const& _sort);
void declareFunction(std::string const& _name, SortPointer const& _sort); void declareFunction(std::string const& _name, SortPointer const& _sort);
void write(std::string _data); void write(std::string _data);
@ -86,6 +90,10 @@ private:
/// It needs to be a vector so that the declaration order is kept, /// It needs to be a vector so that the declaration order is kept,
/// otherwise solvers cannot parse the queries. /// otherwise solvers cannot parse the queries.
std::vector<std::pair<std::string, std::string>> m_userSorts; std::vector<std::pair<std::string, std::string>> m_userSorts;
// TODO: Should this remember shared_pointer?
// TODO: Shouldn't sorts be unique objects?
std::map<Sort const*, std::string> m_sortNames;
std::vector<std::string> m_unhandledQueries; std::vector<std::string> m_unhandledQueries;