[SMTChecker] Add FunctionSort and refactors the solver interface to create variables

This commit is contained in:
Leonardo Alt 2018-11-21 16:57:02 +01:00
parent dc748bc771
commit 13a142b039
14 changed files with 137 additions and 131 deletions

View File

@ -33,8 +33,7 @@ CVC4Interface::CVC4Interface():
void CVC4Interface::reset()
{
m_constants.clear();
m_functions.clear();
m_variables.clear();
m_solver.reset();
m_solver.setOption("produce-models", true);
m_solver.setTimeLimit(queryTimeout);
@ -50,25 +49,10 @@ void CVC4Interface::pop()
m_solver.pop();
}
void CVC4Interface::declareFunction(string _name, vector<SortPointer> const& _domain, Sort const& _codomain)
void CVC4Interface::declareVariable(string const& _name, Sort const& _sort)
{
if (!m_functions.count(_name))
{
CVC4::Type fType = m_context.mkFunctionType(cvc4Sort(_domain), cvc4Sort(_codomain));
m_functions.insert({_name, m_context.mkVar(_name.c_str(), fType)});
}
}
void CVC4Interface::declareInteger(string _name)
{
if (!m_constants.count(_name))
m_constants.insert({_name, m_context.mkVar(_name.c_str(), m_context.integerType())});
}
void CVC4Interface::declareBool(string _name)
{
if (!m_constants.count(_name))
m_constants.insert({_name, m_context.mkVar(_name.c_str(), m_context.booleanType())});
if (!m_variables.count(_name))
m_variables.insert({_name, m_context.mkVar(_name.c_str(), cvc4Sort(_sort))});
}
void CVC4Interface::addAssertion(Expression const& _expr)
@ -129,20 +113,19 @@ pair<CheckResult, vector<string>> CVC4Interface::check(vector<Expression> const&
CVC4::Expr CVC4Interface::toCVC4Expr(Expression const& _expr)
{
if (_expr.arguments.empty() && m_constants.count(_expr.name))
return m_constants.at(_expr.name);
// Variable
if (_expr.arguments.empty() && m_variables.count(_expr.name))
return m_variables.at(_expr.name);
vector<CVC4::Expr> arguments;
for (auto const& arg: _expr.arguments)
arguments.push_back(toCVC4Expr(arg));
string const& n = _expr.name;
if (m_functions.count(n))
return m_context.mkExpr(CVC4::kind::APPLY_UF, m_functions[n], arguments);
else if (m_constants.count(n))
{
solAssert(arguments.empty(), "");
return m_constants.at(n);
}
// Function application
if (!arguments.empty() && m_variables.count(_expr.name))
return m_context.mkExpr(CVC4::kind::APPLY_UF, m_variables.at(n), arguments);
// Literal
else if (arguments.empty())
{
if (n == "true")
@ -194,6 +177,11 @@ CVC4::Type CVC4Interface::cvc4Sort(Sort const& _sort)
return m_context.booleanType();
case Kind::Int:
return m_context.integerType();
case Kind::Function:
{
FunctionSort const& fSort = dynamic_cast<FunctionSort const&>(_sort);
return m_context.mkFunctionType(cvc4Sort(fSort.domain), cvc4Sort(*fSort.codomain));
}
default:
break;
}

View File

@ -51,9 +51,7 @@ public:
void push() override;
void pop() override;
void declareFunction(std::string _name, std::vector<SortPointer> const& _domain, Sort const& _codomain) override;
void declareInteger(std::string _name) override;
void declareBool(std::string _name) override;
void declareVariable(std::string const&, Sort const&) override;
void addAssertion(Expression const& _expr) override;
std::pair<CheckResult, std::vector<std::string>> check(std::vector<Expression> const& _expressionsToEvaluate) override;
@ -65,8 +63,7 @@ private:
CVC4::ExprManager m_context;
CVC4::SmtEngine m_solver;
std::map<std::string, CVC4::Expr> m_constants;
std::map<std::string, CVC4::Expr> m_functions;
std::map<std::string, CVC4::Expr> m_variables;
};
}

View File

@ -417,10 +417,15 @@ void SMTChecker::visitGasLeft(FunctionCall const& _funCall)
void SMTChecker::visitBlockHash(FunctionCall const& _funCall)
{
string blockHash = "blockhash";
defineUninterpretedFunction(blockHash, {make_shared<smt::Sort>(smt::Kind::Int)}, smt::Kind::Int);
auto const& arguments = _funCall.arguments();
solAssert(arguments.size() == 1, "");
defineExpr(_funCall, m_uninterpretedFunctions.at(blockHash)({expr(*arguments[0])}));
smt::SortPointer paramSort = smtSort(*arguments.at(0)->annotation().type);
smt::SortPointer returnSort = smtSort(*_funCall.annotation().type);
defineUninterpretedFunction(
blockHash,
make_shared<smt::FunctionSort>(vector<smt::SortPointer>{paramSort}, returnSort)
);
defineExpr(_funCall, m_uninterpretedFunctions.at(blockHash)({expr(*arguments.at(0))}));
m_uninterpretedTerms.push_back(&_funCall);
}
@ -606,10 +611,10 @@ void SMTChecker::defineSpecialVariable(string const& _name, Expression const& _e
defineExpr(_expr, m_specialVariables.at(_name)->currentValue());
}
void SMTChecker::defineUninterpretedFunction(string const& _name, vector<smt::SortPointer> const& _domain, smt::Sort const& _codomain)
void SMTChecker::defineUninterpretedFunction(string const& _name, smt::SortPointer _sort)
{
if (!m_uninterpretedFunctions.count(_name))
m_uninterpretedFunctions.emplace(_name, m_interface->newFunction(_name, _domain, _codomain));
m_uninterpretedFunctions.emplace(_name, m_interface->newVariable(_name, _sort));
}
void SMTChecker::arithmeticOperation(BinaryOperation const& _op)

View File

@ -88,7 +88,7 @@ private:
void inlineFunctionCall(FunctionCall const&);
void defineSpecialVariable(std::string const& _name, Expression const& _expr, bool _increaseIndex = false);
void defineUninterpretedFunction(std::string const& _name, std::vector<smt::SortPointer> const& _domain, smt::Sort const& _codomain);
void defineUninterpretedFunction(std::string const& _name, smt::SortPointer _sort);
/// Division expression in the given type. Requires special treatment because
/// of rounding for signed division.

View File

@ -47,8 +47,7 @@ void SMTLib2Interface::reset()
{
m_accumulatedOutput.clear();
m_accumulatedOutput.emplace_back();
m_constants.clear();
m_functions.clear();
m_variables.clear();
write("(set-option :produce-models true)");
write("(set-logic QF_UFLIA)");
}
@ -64,45 +63,39 @@ void SMTLib2Interface::pop()
m_accumulatedOutput.pop_back();
}
void SMTLib2Interface::declareFunction(string _name, vector<SortPointer> const& _domain, Sort const& _codomain)
void SMTLib2Interface::declareVariable(string const& _name, Sort const& _sort)
{
// TODO Use domain and codomain as key as well
string domain("");
for (auto const& sort: _domain)
domain += toSmtLibSort(*sort) + ' ';
if (!m_functions.count(_name))
if (_sort.kind == Kind::Function)
declareFunction(_name, _sort);
else if (!m_variables.count(_name))
{
m_functions.insert(_name);
m_variables.insert(_name);
write("(declare-fun |" + _name + "| () " + toSmtLibSort(_sort) + ')');
}
}
void SMTLib2Interface::declareFunction(string const& _name, Sort const& _sort)
{
solAssert(_sort.kind == smt::Kind::Function, "");
// TODO Use domain and codomain as key as well
if (!m_variables.count(_name))
{
FunctionSort fSort = dynamic_cast<FunctionSort const&>(_sort);
string domain = toSmtLibSort(fSort.domain);
string codomain = toSmtLibSort(*fSort.codomain);
m_variables.insert(_name);
write(
"(declare-fun |" +
_name +
"| (" +
"| " +
domain +
") " +
(_codomain.kind == Kind::Int ? "Int" : "Bool") +
" " +
codomain +
")"
);
}
}
void SMTLib2Interface::declareInteger(string _name)
{
if (!m_constants.count(_name))
{
m_constants.insert(_name);
write("(declare-const |" + _name + "| Int)");
}
}
void SMTLib2Interface::declareBool(string _name)
{
if (!m_constants.count(_name))
{
m_constants.insert(_name);
write("(declare-const |" + _name + "| Bool)");
}
}
void SMTLib2Interface::addAssertion(Expression const& _expr)
{
write("(assert " + toSExpr(_expr) + ")");
@ -156,6 +149,15 @@ string SMTLib2Interface::toSmtLibSort(Sort const& _sort)
}
}
string SMTLib2Interface::toSmtLibSort(vector<SortPointer> const& _sorts)
{
string ssort("(");
for (auto const& sort: _sorts)
ssort += toSmtLibSort(*sort) + " ";
ssort += ")";
return ssort;
}
void SMTLib2Interface::write(string _data)
{
solAssert(!m_accumulatedOutput.empty(), "");

View File

@ -49,16 +49,17 @@ public:
void push() override;
void pop() override;
void declareFunction(std::string _name, std::vector<SortPointer> const& _domain, Sort const& _codomain) override;
void declareInteger(std::string _name) override;
void declareBool(std::string _name) override;
void declareVariable(std::string const&, Sort const&) override;
void addAssertion(Expression const& _expr) override;
std::pair<CheckResult, std::vector<std::string>> check(std::vector<Expression> const& _expressionsToEvaluate) override;
private:
void declareFunction(std::string const&, Sort const&);
std::string toSExpr(Expression const& _expr);
std::string toSmtLibSort(Sort const& _sort);
std::string toSmtLibSort(std::vector<SortPointer> const& _sort);
void write(std::string _data);
@ -70,8 +71,7 @@ private:
ReadCallback::Callback m_queryCallback;
std::vector<std::string> m_accumulatedOutput;
std::set<std::string> m_constants;
std::set<std::string> m_functions;
std::set<std::string> m_variables;
};
}

View File

@ -64,22 +64,10 @@ void SMTPortfolio::pop()
s->pop();
}
void SMTPortfolio::declareFunction(string _name, vector<SortPointer> const& _domain, Sort const& _codomain)
void SMTPortfolio::declareVariable(string const& _name, Sort const& _sort)
{
for (auto s : m_solvers)
s->declareFunction(_name, _domain, _codomain);
}
void SMTPortfolio::declareInteger(string _name)
{
for (auto s : m_solvers)
s->declareInteger(_name);
}
void SMTPortfolio::declareBool(string _name)
{
for (auto s : m_solvers)
s->declareBool(_name);
s->declareVariable(_name, _sort);
}
void SMTPortfolio::addAssertion(Expression const& _expr)

View File

@ -49,9 +49,7 @@ public:
void push() override;
void pop() override;
void declareFunction(std::string _name, std::vector<SortPointer> const& _domain, Sort const& _codomain) override;
void declareInteger(std::string _name) override;
void declareBool(std::string _name) override;
void declareVariable(std::string const&, Sort const&) override;
void addAssertion(Expression const& _expr) override;
std::pair<CheckResult, std::vector<std::string>> check(std::vector<Expression> const& _expressionsToEvaluate) override;

View File

@ -45,7 +45,8 @@ enum class CheckResult
enum class Kind
{
Int,
Bool
Bool,
Function
};
struct Sort
@ -58,6 +59,25 @@ struct Sort
};
using SortPointer = std::shared_ptr<Sort>;
struct FunctionSort: public Sort
{
FunctionSort(std::vector<SortPointer> _domain, SortPointer _codomain):
Sort(Kind::Function), domain(std::move(_domain)), codomain(std::move(_codomain)) {}
std::vector<SortPointer> domain;
SortPointer codomain;
bool operator==(FunctionSort const& _other) const
{
if (!std::equal(
domain.begin(),
domain.end(),
_other.domain.begin(),
[&](SortPointer _a, SortPointer _b) { return *_a == *_b; }
)
)
return false;
return Sort::operator==(_other) && *codomain == *_other.codomain;
}
};
/// C++ representation of an SMTLIB2 expression.
class Expression
@ -162,10 +182,12 @@ public:
Expression operator()(std::vector<Expression> _arguments) const
{
solAssert(
arguments.empty(),
sort->kind == Kind::Function,
"Attempted function application to non-function."
);
return Expression(name, std::move(_arguments), sort);
auto fSort = dynamic_cast<FunctionSort const*>(sort.get());
solAssert(fSort, "");
return Expression(name, std::move(_arguments), fSort->codomain);
}
std::string name;
@ -198,26 +220,12 @@ public:
virtual void push() = 0;
virtual void pop() = 0;
virtual void declareFunction(std::string _name, std::vector<SortPointer> const& _domain, Sort const& _codomain) = 0;
Expression newFunction(std::string _name, std::vector<SortPointer> const& _domain, Sort const& _codomain)
{
declareFunction(_name, _domain, _codomain);
// Subclasses should do something here
return Expression(std::move(_name), {}, _codomain.kind);
}
virtual void declareInteger(std::string _name) = 0;
Expression newInteger(std::string _name)
virtual void declareVariable(std::string const& _name, Sort const& _sort) = 0;
Expression newVariable(std::string _name, SortPointer _sort)
{
// Subclasses should do something here
declareInteger(_name);
return Expression(std::move(_name), {}, Kind::Int);
}
virtual void declareBool(std::string _name) = 0;
Expression newBool(std::string _name)
{
// Subclasses should do something here
declareBool(_name);
return Expression(std::move(_name), {}, Kind::Bool);
declareVariable(_name, *_sort);
return Expression(std::move(_name), {}, std::move(_sort));
}
virtual void addAssertion(Expression const& _expr) = 0;

View File

@ -32,10 +32,29 @@ smt::SortPointer dev::solidity::smtSort(Type const& _type)
return make_shared<smt::Sort>(smt::Kind::Int);
case smt::Kind::Bool:
return make_shared<smt::Sort>(smt::Kind::Bool);
case smt::Kind::Function:
{
auto fType = dynamic_cast<FunctionType const*>(&_type);
solAssert(fType, "");
vector<smt::SortPointer> parameterSorts = smtSort(fType->parameterTypes());
auto returnTypes = fType->returnParameterTypes();
// TODO remove this when we support tuples.
solAssert(returnTypes.size() == 1, "");
smt::SortPointer returnSort = smtSort(*returnTypes.at(0));
return make_shared<smt::FunctionSort>(parameterSorts, returnSort);
}
}
solAssert(false, "Invalid type");
}
vector<smt::SortPointer> dev::solidity::smtSort(vector<TypePointer> const& _types)
{
vector<smt::SortPointer> sorts;
for (auto const& type: _types)
sorts.push_back(smtSort(*type));
return sorts;
}
smt::Kind dev::solidity::smtKind(Type::Category _category)
{
if (isNumber(_category))

View File

@ -30,6 +30,7 @@ namespace solidity
/// Returns the SMT sort that models the Solidity type _type.
smt::SortPointer smtSort(Type const& _type);
std::vector<smt::SortPointer> smtSort(std::vector<TypePointer> const& _types);
/// Returns the SMT kind that models the Solidity type type category _category.
smt::Kind smtKind(Type::Category _category);

View File

@ -59,7 +59,7 @@ SymbolicBoolVariable::SymbolicBoolVariable(
smt::Expression SymbolicBoolVariable::valueAtIndex(int _index) const
{
return m_interface.newBool(uniqueSymbol(_index));
return m_interface.newVariable(uniqueSymbol(_index), make_shared<smt::Sort>(smt::Kind::Bool));
}
void SymbolicBoolVariable::setZeroValue()
@ -83,7 +83,7 @@ SymbolicIntVariable::SymbolicIntVariable(
smt::Expression SymbolicIntVariable::valueAtIndex(int _index) const
{
return m_interface.newInteger(uniqueSymbol(_index));
return m_interface.newVariable(uniqueSymbol(_index), make_shared<smt::Sort>(smt::Kind::Int));
}
void SymbolicIntVariable::setZeroValue()

View File

@ -51,22 +51,22 @@ void Z3Interface::pop()
m_solver.pop();
}
void Z3Interface::declareFunction(string _name, vector<SortPointer> const& _domain, Sort const& _codomain)
void Z3Interface::declareVariable(string const& _name, Sort const& _sort)
{
if (_sort.kind == Kind::Function)
declareFunction(_name, _sort);
else if (!m_constants.count(_name))
m_constants.insert({_name, m_context.constant(_name.c_str(), z3Sort(_sort))});
}
void Z3Interface::declareFunction(string const& _name, Sort const& _sort)
{
solAssert(_sort.kind == smt::Kind::Function, "");
if (!m_functions.count(_name))
m_functions.insert({_name, m_context.function(_name.c_str(), z3Sort(_domain), z3Sort(_codomain))});
}
void Z3Interface::declareInteger(string _name)
{
if (!m_constants.count(_name))
m_constants.insert({_name, m_context.int_const(_name.c_str())});
}
void Z3Interface::declareBool(string _name)
{
if (!m_constants.count(_name))
m_constants.insert({_name, m_context.bool_const(_name.c_str())});
{
FunctionSort fSort = dynamic_cast<FunctionSort const&>(_sort);
m_functions.insert({_name, m_context.function(_name.c_str(), z3Sort(fSort.domain), z3Sort(*fSort.codomain))});
}
}
void Z3Interface::addAssertion(Expression const& _expr)

View File

@ -40,14 +40,14 @@ public:
void push() override;
void pop() override;
void declareFunction(std::string _name, std::vector<SortPointer> const& _domain, Sort const& _codomain) override;
void declareInteger(std::string _name) override;
void declareBool(std::string _name) override;
void declareVariable(std::string const& _name, Sort const& _sort) override;
void addAssertion(Expression const& _expr) override;
std::pair<CheckResult, std::vector<std::string>> check(std::vector<Expression> const& _expressionsToEvaluate) override;
private:
void declareFunction(std::string const& _name, Sort const& _sort);
z3::expr toZ3Expr(Expression const& _expr);
z3::sort z3Sort(smt::Sort const& _sort);
z3::sort_vector z3Sort(std::vector<smt::SortPointer> const& _sorts);