Avoid copying let bindings.

This commit is contained in:
chriseth 2022-06-27 15:04:19 +02:00
parent acfe18cd4e
commit 87bffbba98
2 changed files with 97 additions and 92 deletions

View File

@ -1,4 +1,4 @@
/* /*
This file is part of solidity. This file is part of solidity.
solidity is free software: you can redistribute it and/or modify solidity is free software: you can redistribute it and/or modify
@ -232,7 +232,7 @@ string BooleanLPSolver::toString() const
return result; return result;
} }
void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBindings) void BooleanLPSolver::addAssertion(Expression const& _expr)
{ {
#ifdef DEBUG #ifdef DEBUG
cerr << "adding assertion" << endl; cerr << "adding assertion" << endl;
@ -247,12 +247,12 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind
solAssert(false, "Adding false as top-level assertion."); solAssert(false, "Adding false as top-level assertion.");
size_t varIndex = 0; size_t varIndex = 0;
if (_letBindings->count(_expr.name)) if (m_letBindings.count(_expr.name))
{ {
LetBinding binding = _letBindings->at(_expr.name); LetBinding binding = m_letBindings.at(_expr.name);
if (holds_alternative<smtutil::Expression>(binding)) if (holds_alternative<smtutil::Expression>(binding))
{ {
addAssertion(std::get<smtutil::Expression>(binding), move(_letBindings)); addAssertion(std::get<smtutil::Expression>(binding));
return; return;
} }
else else
@ -266,8 +266,9 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind
} }
else if (_expr.name == "let") else if (_expr.name == "let")
{ {
addLetBindings(_expr, _letBindings); auto newBindings = addLetBindings(_expr);
addAssertion(_expr.arguments.back(), move(_letBindings)); addAssertion(_expr.arguments.back());
removeLetBindings(newBindings);
} }
else if (_expr.name == "=") else if (_expr.name == "=")
{ {
@ -276,21 +277,21 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind
if (_expr.arguments.at(0).sort->kind == Kind::Bool) if (_expr.arguments.at(0).sort->kind == Kind::Bool)
{ {
if (_expr.arguments.at(0).arguments.empty() && isBooleanVariable(_expr.arguments.at(0).name)) if (_expr.arguments.at(0).arguments.empty() && isBooleanVariable(_expr.arguments.at(0).name))
addBooleanEquality(*parseLiteral(_expr.arguments.at(0), _letBindings), _expr.arguments.at(1), _letBindings); addBooleanEquality(*parseLiteral(_expr.arguments.at(0)), _expr.arguments.at(1));
else if (_expr.arguments.at(1).arguments.empty() && isBooleanVariable(_expr.arguments.at(1).name)) else if (_expr.arguments.at(1).arguments.empty() && isBooleanVariable(_expr.arguments.at(1).name))
addBooleanEquality(*parseLiteral(_expr.arguments.at(1), _letBindings), _expr.arguments.at(0), _letBindings); addBooleanEquality(*parseLiteral(_expr.arguments.at(1)), _expr.arguments.at(0));
else else
{ {
Literal newBoolean = *parseLiteral(declareInternalVariable(true), make_shared<map<string, LetBinding>>()); Literal newBoolean = *parseLiteral(declareInternalVariable(true));
addBooleanEquality(newBoolean, _expr.arguments.at(0), _letBindings); addBooleanEquality(newBoolean, _expr.arguments.at(0));
addBooleanEquality(newBoolean, _expr.arguments.at(1), _letBindings); addBooleanEquality(newBoolean, _expr.arguments.at(1));
} }
} }
else if (_expr.arguments.at(0).sort->kind == Kind::Int || _expr.arguments.at(0).sort->kind == Kind::Real) else if (_expr.arguments.at(0).sort->kind == Kind::Int || _expr.arguments.at(0).sort->kind == Kind::Real)
{ {
// Try to see if both sides are linear. // Try to see if both sides are linear.
optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0), _letBindings); optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0));
optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1), _letBindings); optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1));
if (left && right) if (left && right)
{ {
LinearExpression data = *left - *right; LinearExpression data = *left - *right;
@ -314,7 +315,7 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind
} }
else if (_expr.name == "and") else if (_expr.name == "and")
for (auto const& arg: _expr.arguments) for (auto const& arg: _expr.arguments)
addAssertion(arg, _letBindings); addAssertion(arg);
else if (_expr.name == "or") else if (_expr.name == "or")
{ {
if (_expr.arguments.size() == 1) if (_expr.arguments.size() == 1)
@ -324,7 +325,7 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind
vector<Literal> literals; vector<Literal> literals;
// We could try to parse a full clause here instead. // We could try to parse a full clause here instead.
for (auto const& arg: _expr.arguments) for (auto const& arg: _expr.arguments)
literals.emplace_back(parseLiteralOrReturnEqualBoolean(arg, _letBindings)); literals.emplace_back(parseLiteralOrReturnEqualBoolean(arg));
state().clauses.emplace_back(Clause{move(literals)}); state().clauses.emplace_back(Clause{move(literals)});
} }
} }
@ -338,19 +339,19 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind
{ {
solAssert(_expr.arguments.size() == 1); solAssert(_expr.arguments.size() == 1);
// TODO can we still try to add a fixed constraint? // TODO can we still try to add a fixed constraint?
Literal l = negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0), move(_letBindings))); Literal l = negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0)));
state().clauses.emplace_back(Clause{vector<Literal>{l}}); state().clauses.emplace_back(Clause{vector<Literal>{l}});
} }
else if (_expr.name == "=>") else if (_expr.name == "=>")
{ {
solAssert(_expr.arguments.size() == 2); solAssert(_expr.arguments.size() == 2);
addAssertion(!_expr.arguments.at(0) || _expr.arguments.at(1), move(_letBindings)); addAssertion(!_expr.arguments.at(0) || _expr.arguments.at(1));
} }
else if (_expr.name == "<=" || _expr.name == "<") else if (_expr.name == "<=" || _expr.name == "<")
{ {
solAssert(_expr.arguments.size() == 2); solAssert(_expr.arguments.size() == 2);
optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0), _letBindings); optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0));
optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1), _letBindings); optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1));
solAssert(left && right); solAssert(left && right);
LinearExpression data = *left - *right; LinearExpression data = *left - *right;
data[0] *= -1; data[0] *= -1;
@ -362,12 +363,12 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind
else if (_expr.name == ">=") else if (_expr.name == ">=")
{ {
solAssert(_expr.arguments.size() == 2); solAssert(_expr.arguments.size() == 2);
addAssertion(_expr.arguments.at(1) <= _expr.arguments.at(0), move(_letBindings)); addAssertion(_expr.arguments.at(1) <= _expr.arguments.at(0));
} }
else if (_expr.name == ">") else if (_expr.name == ">")
{ {
solAssert(_expr.arguments.size() == 2); solAssert(_expr.arguments.size() == 2);
addAssertion(_expr.arguments.at(1) < _expr.arguments.at(0), move(_letBindings)); addAssertion(_expr.arguments.at(1) < _expr.arguments.at(0));
} }
else else
{ {
@ -392,7 +393,7 @@ void BooleanLPSolver::declareVariable(string const& _name, bool _boolean)
resizeAndSet(state().isBooleanVariable, index, _boolean); resizeAndSet(state().isBooleanVariable, index, _boolean);
} }
void BooleanLPSolver::addLetBindings(Expression const& _let, LetBindings& _letBindings) map<string, BooleanLPSolver::LetBinding> BooleanLPSolver::addLetBindings(Expression const& _let)
{ {
map<string, LetBinding> newBindings; map<string, LetBinding> newBindings;
solAssert(_let.name == "let"); solAssert(_let.name == "let");
@ -406,30 +407,38 @@ void BooleanLPSolver::addLetBindings(Expression const& _let, LetBindings& _letBi
{ {
Expression var = declareInternalVariable(isBool); Expression var = declareInternalVariable(isBool);
newBindings.insert({binding.name, state().variables.at(var.name)}); newBindings.insert({binding.name, state().variables.at(var.name)});
addAssertion(var == binding.arguments.at(0), _letBindings); addAssertion(var == binding.arguments.at(0));
} }
} }
_letBindings = make_shared<std::map<std::string, LetBinding>>(*_letBindings);
for (auto& [name, value]: newBindings) for (auto& [name, value]: newBindings)
_letBindings->insert({name, move(value)}); m_letBindings.insert({name, move(value)});
return newBindings;
} }
optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr, LetBindings _letBindings) void BooleanLPSolver::removeLetBindings(map<string, BooleanLPSolver::LetBinding> const& _bindings)
{
for (auto& [name, value]: _bindings)
m_letBindings.erase(name);
}
optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr)
{ {
if (_expr.name == "let") if (_expr.name == "let")
{ {
addLetBindings(_expr, _letBindings); map<string, LetBinding> newBindings = addLetBindings(_expr);
return parseLiteral(_expr.arguments.back(), move(_letBindings)); optional<Literal> literal = parseLiteral(_expr.arguments.back());
removeLetBindings(newBindings);
return literal;
} }
if (_expr.arguments.empty()) if (_expr.arguments.empty())
{ {
size_t varIndex = 0; size_t varIndex = 0;
if (_letBindings->count(_expr.name)) if (m_letBindings.count(_expr.name))
{ {
LetBinding binding = _letBindings->at(_expr.name); LetBinding binding = m_letBindings.at(_expr.name);
if (holds_alternative<smtutil::Expression>(binding)) if (holds_alternative<smtutil::Expression>(binding))
return parseLiteral(std::get<smtutil::Expression>(binding), move(_letBindings)); return parseLiteral(std::get<smtutil::Expression>(binding));
else else
varIndex = std::get<size_t>(binding); varIndex = std::get<size_t>(binding);
} }
@ -441,8 +450,8 @@ optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr
if (!state().trueConstant) if (!state().trueConstant)
{ {
Expression var = declareInternalVariable(true); Expression var = declareInternalVariable(true);
addAssertion(var, make_shared<map<string, LetBinding>>()); addAssertion(var);
state().trueConstant = parseLiteral(var, make_shared<map<string, LetBinding>>())->variable; state().trueConstant = parseLiteral(var)->variable;
} }
return Literal{_expr.name == "true", *state().trueConstant}; return Literal{_expr.name == "true", *state().trueConstant};
} }
@ -452,11 +461,11 @@ optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr
return Literal{true, varIndex}; return Literal{true, varIndex};
} }
else if (_expr.name == "not") else if (_expr.name == "not")
return negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0), move(_letBindings))); return negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0)));
else if (_expr.name == "<=" || _expr.name == "<" || _expr.name == "=") else if (_expr.name == "<=" || _expr.name == "<" || _expr.name == "=")
{ {
optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0), _letBindings); optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0));
optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1), _letBindings); optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1));
if (!left || !right) if (!left || !right)
return {}; return {};
@ -471,9 +480,9 @@ optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr
return Literal{true, addConditionalConstraint(Constraint{move(data), kind})}; return Literal{true, addConditionalConstraint(Constraint{move(data), kind})};
} }
else if (_expr.name == ">=") else if (_expr.name == ">=")
return parseLiteral(_expr.arguments.at(1) <= _expr.arguments.at(0), move(_letBindings)); return parseLiteral(_expr.arguments.at(1) <= _expr.arguments.at(0));
else if (_expr.name == ">") else if (_expr.name == ">")
return parseLiteral(_expr.arguments.at(1) < _expr.arguments.at(0), move(_letBindings)); return parseLiteral(_expr.arguments.at(1) < _expr.arguments.at(0));
return {}; return {};
} }
@ -516,7 +525,7 @@ Literal BooleanLPSolver::negate(Literal const& _lit)
gt.data *= -1; gt.data *= -1;
Literal gtL{true, addConditionalConstraint(gt)}; Literal gtL{true, addConditionalConstraint(gt)};
Literal equalBoolean = *parseLiteral(declareInternalVariable(true), make_shared<map<string, LetBinding>>()); Literal equalBoolean = *parseLiteral(declareInternalVariable(true));
// a = or(x, y) <=> (-a \/ x \/ y) /\ (a \/ -x) /\ (a \/ -y) // a = or(x, y) <=> (-a \/ x \/ y) /\ (a \/ -x) /\ (a \/ -y)
state().clauses.emplace_back(Clause{vector<Literal>{negate(equalBoolean), ltL, gtL}}); state().clauses.emplace_back(Clause{vector<Literal>{negate(equalBoolean), ltL, gtL}});
state().clauses.emplace_back(Clause{vector<Literal>{equalBoolean, negate(ltL)}}); state().clauses.emplace_back(Clause{vector<Literal>{equalBoolean, negate(ltL)}});
@ -546,37 +555,39 @@ Literal BooleanLPSolver::negate(Literal const& _lit)
return ~_lit; return ~_lit;
} }
Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr, LetBindings _letBindings) Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr)
{ {
if (_expr.sort->kind != Kind::Bool) if (_expr.sort->kind != Kind::Bool)
cerr << "expected bool: " << _expr.toString() << endl; cerr << "expected bool: " << _expr.toString() << endl;
solAssert(_expr.sort->kind == Kind::Bool); solAssert(_expr.sort->kind == Kind::Bool);
// TODO when can this fail? // TODO when can this fail?
if (optional<Literal> literal = parseLiteral(_expr, _letBindings)) if (optional<Literal> literal = parseLiteral(_expr))
return *literal; return *literal;
else else
{ {
Literal newBoolean = *parseLiteral(declareInternalVariable(true), _letBindings); Literal newBoolean = *parseLiteral(declareInternalVariable(true));
addBooleanEquality(newBoolean, _expr, _letBindings); addBooleanEquality(newBoolean, _expr);
return newBoolean; return newBoolean;
} }
} }
optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr, LetBindings _letBindings) optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr)
{ {
if (_expr.name == "let") if (_expr.name == "let")
{ {
addLetBindings(_expr, _letBindings); auto newBindings = addLetBindings(_expr);
return parseLinearSum(_expr.arguments.back(), move(_letBindings)); auto result = parseLinearSum(_expr.arguments.back());
removeLetBindings(newBindings);
return result;
} }
if (_expr.arguments.empty()) if (_expr.arguments.empty())
return parseFactor(_expr, move(_letBindings)); return parseFactor(_expr);
else if (_expr.name == "+") else if (_expr.name == "+")
{ {
optional<LinearExpression> expr = LinearExpression::constant(0); optional<LinearExpression> expr = LinearExpression::constant(0);
for (auto const& arg: _expr.arguments) for (auto const& arg: _expr.arguments)
if (optional<LinearExpression> summand = parseLinearSum(arg, _letBindings)) if (optional<LinearExpression> summand = parseLinearSum(arg))
*expr += move(*summand); *expr += move(*summand);
else else
return std::nullopt; return std::nullopt;
@ -588,13 +599,13 @@ optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression c
optional<LinearExpression> right; optional<LinearExpression> right;
if (_expr.arguments.size() == 2) if (_expr.arguments.size() == 2)
{ {
left = parseLinearSum(_expr.arguments.at(0), _letBindings); left = parseLinearSum(_expr.arguments.at(0));
right = parseLinearSum(_expr.arguments.at(1), _letBindings); right = parseLinearSum(_expr.arguments.at(1));
} }
else if (_expr.arguments.size() == 1) else if (_expr.arguments.size() == 1)
{ {
left = LinearExpression::constant(0); left = LinearExpression::constant(0);
right = parseLinearSum(_expr.arguments.at(0), _letBindings); right = parseLinearSum(_expr.arguments.at(0));
} }
else else
solAssert(false); solAssert(false);
@ -608,13 +619,13 @@ optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression c
// TODO this can also have more than to args // TODO this can also have more than to args
solAssert(_expr.arguments.size() == 2); solAssert(_expr.arguments.size() == 2);
// This will result in nullopt unless one of them is a constant. // This will result in nullopt unless one of them is a constant.
return parseLinearSum(_expr.arguments.at(0), _letBindings) * parseLinearSum(_expr.arguments.at(1), _letBindings); return parseLinearSum(_expr.arguments.at(0)) * parseLinearSum(_expr.arguments.at(1));
} }
else if (_expr.name == "/" || _expr.name == "div") else if (_expr.name == "/" || _expr.name == "div")
{ {
solAssert(_expr.arguments.size() == 2); solAssert(_expr.arguments.size() == 2);
optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0), _letBindings); optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0));
optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1), move(_letBindings)); optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1));
if (!left || !right || !right->isConstant()) if (!left || !right || !right->isConstant())
return std::nullopt; return std::nullopt;
*left /= right->get(0); *left /= right->get(0);
@ -624,9 +635,9 @@ optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression c
{ {
solAssert(_expr.arguments.size() == 3); solAssert(_expr.arguments.size() == 3);
Expression result = declareInternalVariable(false); Expression result = declareInternalVariable(false);
addAssertion(!_expr.arguments.at(0) || (result == _expr.arguments.at(1)), _letBindings); addAssertion(!_expr.arguments.at(0) || (result == _expr.arguments.at(1)));
addAssertion(_expr.arguments.at(0) || (result == _expr.arguments.at(2)), _letBindings); addAssertion(_expr.arguments.at(0) || (result == _expr.arguments.at(2)));
return parseLinearSum(result, make_shared<map<string, LetBinding>>()); return parseLinearSum(result);
} }
else else
{ {
@ -673,7 +684,7 @@ bool BooleanLPSolver::isLiteral(smtutil::Expression const& _expr) const
_expr.name == "false"; _expr.name == "false";
} }
optional<LinearExpression> BooleanLPSolver::parseFactor(smtutil::Expression const& _expr, LetBindings _letBindings) const optional<LinearExpression> BooleanLPSolver::parseFactor(smtutil::Expression const& _expr) const
{ {
solAssert(_expr.arguments.empty(), ""); solAssert(_expr.arguments.empty(), "");
solAssert(!_expr.name.empty(), ""); solAssert(!_expr.name.empty(), "");
@ -687,11 +698,11 @@ optional<LinearExpression> BooleanLPSolver::parseFactor(smtutil::Expression cons
return LinearExpression::constant(0); return LinearExpression::constant(0);
size_t varIndex = 0; size_t varIndex = 0;
if (_letBindings->count(_expr.name)) if (m_letBindings.count(_expr.name))
{ {
LetBinding binding = _letBindings->at(_expr.name); LetBinding binding = m_letBindings.at(_expr.name);
if (holds_alternative<smtutil::Expression>(binding)) if (holds_alternative<smtutil::Expression>(binding))
return parseFactor(std::get<smtutil::Expression>(binding), move(_letBindings)); return parseFactor(std::get<smtutil::Expression>(binding));
else else
varIndex = std::get<size_t>(binding); varIndex = std::get<size_t>(binding);
} }
@ -771,10 +782,10 @@ size_t BooleanLPSolver::addConditionalConstraint(Constraint _constraint)
return index; return index;
} }
void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, LetBindings _letBindings) void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right)
{ {
solAssert(_right.sort->kind == Kind::Bool); solAssert(_right.sort->kind == Kind::Bool);
if (optional<Literal> right = parseLiteral(_right, _letBindings)) if (optional<Literal> right = parseLiteral(_right))
{ {
// includes: not, <=, <, >=, >, =, boolean variables. // includes: not, <=, <, >=, >, =, boolean variables.
// a = b <=> (-a \/ b) /\ (a \/ -b) // a = b <=> (-a \/ b) /\ (a \/ -b)
@ -784,15 +795,14 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi
state().clauses.emplace_back(Clause{vector<Literal>{_left, negRight}}); state().clauses.emplace_back(Clause{vector<Literal>{_left, negRight}});
} }
// TODO This parses twice // TODO This parses twice
else if (_right.name == "=" && parseLinearSum(_right.arguments.at(0), _letBindings) && parseLinearSum(_right.arguments.at(1), _letBindings)) else if (_right.name == "=" && parseLinearSum(_right.arguments.at(0)) && parseLinearSum(_right.arguments.at(1)))
{ {
solAssert(false, "This should be covered by the case above"); solAssert(false, "This should be covered by the case above");
// a = (x = y) <=> a = (x <= y && x >= y) // a = (x = y) <=> a = (x <= y && x >= y)
addBooleanEquality( addBooleanEquality(
_left, _left,
_right.arguments.at(0) <= _right.arguments.at(1) && _right.arguments.at(0) <= _right.arguments.at(1) &&
_right.arguments.at(1) <= _right.arguments.at(0), _right.arguments.at(1) <= _right.arguments.at(0)
move(_letBindings)
); );
} }
else if (_right.name == "ite") else if (_right.name == "ite")
@ -809,9 +819,9 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi
// (-c || _left = x) && (c || _left = y) // (-c || _left = x) && (c || _left = y)
// (-c || ((-_left || x) && (_left || -x))) && ... // (-c || ((-_left || x) && (_left || -x))) && ...
// (-c || -_left || x) && (-c || _left || -x) && ... // (-c || -_left || x) && (-c || _left || -x) && ...
Literal c = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0), _letBindings); Literal c = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0));
Literal x = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1), _letBindings); Literal x = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1));
Literal y = parseLiteralOrReturnEqualBoolean(_right.arguments.at(2), _letBindings); Literal y = parseLiteralOrReturnEqualBoolean(_right.arguments.at(2));
state().clauses.emplace_back(Clause{vector<Literal>{negate(c), negate(_left), x}}); state().clauses.emplace_back(Clause{vector<Literal>{negate(c), negate(_left), x}});
state().clauses.emplace_back(Clause{vector<Literal>{negate(c), _left, negate(x)}}); state().clauses.emplace_back(Clause{vector<Literal>{negate(c), _left, negate(x)}});
state().clauses.emplace_back(Clause{vector<Literal>{c, negate(_left), y}}); state().clauses.emplace_back(Clause{vector<Literal>{c, negate(_left), y}});
@ -819,7 +829,7 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi
} }
else else
{ {
Literal a = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0), _letBindings); Literal a = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0));
Literal b; Literal b;
if (_right.arguments.size() > 2) if (_right.arguments.size() > 2)
{ {
@ -827,10 +837,10 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi
// Reduce "a and b and c and ..." to "a and (b and c and ...)" // Reduce "a and b and c and ..." to "a and (b and c and ...)"
smtutil::Expression rightSuffix = _right; smtutil::Expression rightSuffix = _right;
rightSuffix.arguments.erase(rightSuffix.arguments.begin()); rightSuffix.arguments.erase(rightSuffix.arguments.begin());
b = parseLiteralOrReturnEqualBoolean(rightSuffix, _letBindings); b = parseLiteralOrReturnEqualBoolean(rightSuffix);
} }
else else
b = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1), _letBindings); b = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1));
if (_right.name == "and") if (_right.name == "and")
{ {
@ -850,7 +860,7 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi
{ {
solAssert(_right.arguments.size() == 2); solAssert(_right.arguments.size() == 2);
// a = (x => y) <=> a = or(-x, y) // a = (x => y) <=> a = or(-x, y)
addBooleanEquality(_left, !_right.arguments.at(0) || _right.arguments.at(1), move(_letBindings)); addBooleanEquality(_left, !_right.arguments.at(0) || _right.arguments.at(1));
} }
else if (_right.name == "=") else if (_right.name == "=")
{ {
@ -863,7 +873,7 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi
else if (_right.name == "xor") else if (_right.name == "xor")
{ {
solAssert(_right.arguments.size() == 2); solAssert(_right.arguments.size() == 2);
addBooleanEquality(negate(_left), _right.arguments.at(0) == _right.arguments.at(1), move(_letBindings)); addBooleanEquality(negate(_left), _right.arguments.at(0) == _right.arguments.at(1));
} }
else else
solAssert(false, "Unsupported operation: " + _right.name); solAssert(false, "Unsupported operation: " + _right.name);

View File

@ -76,10 +76,7 @@ public:
void declareVariable(std::string const& _name, smtutil::SortPointer const& _sort) override; void declareVariable(std::string const& _name, smtutil::SortPointer const& _sort) override;
void addAssertion(smtutil::Expression const& _expr) override void addAssertion(smtutil::Expression const& _expr);
{
addAssertion(_expr, std::make_shared<std::map<std::string, LetBinding>>());
}
std::pair<smtutil::CheckResult, std::vector<std::string>> std::pair<smtutil::CheckResult, std::vector<std::string>>
check(std::vector<smtutil::Expression> const& _expressionsToEvaluate) override; check(std::vector<smtutil::Expression> const& _expressionsToEvaluate) override;
@ -89,33 +86,28 @@ public:
private: private:
using rational = boost::rational<bigint>; using rational = boost::rational<bigint>;
using LetBinding = std::variant<size_t, smtutil::Expression>; using LetBinding = std::variant<size_t, smtutil::Expression>;
using LetBindings = std::shared_ptr<std::map<std::string, LetBinding>>;
void addAssertion(
smtutil::Expression const& _expr,
LetBindings _letBindings
);
smtutil::Expression declareInternalVariable(bool _boolean); smtutil::Expression declareInternalVariable(bool _boolean);
void declareVariable(std::string const& _name, bool _boolean); void declareVariable(std::string const& _name, bool _boolean);
/// Handles a "let" expression and adds the bindings to @a _letBindings. /// Handles a "let" expression and adds the bindings to @a m_letBindings.
void addLetBindings(smtutil::Expression const& _let, LetBindings& _letBindings); std::map<std::string, LetBinding> addLetBindings(smtutil::Expression const& _let);
void removeLetBindings(std::map<std::string, LetBinding> const& _toRemove);
/// Parses an expression of sort bool and returns a literal. /// Parses an expression of sort bool and returns a literal.
std::optional<Literal> parseLiteral(smtutil::Expression const& _expr, LetBindings _letBindings); std::optional<Literal> parseLiteral(smtutil::Expression const& _expr);
Literal negate(Literal const& _lit); Literal negate(Literal const& _lit);
Literal parseLiteralOrReturnEqualBoolean(smtutil::Expression const& _expr, LetBindings _letBindings); Literal parseLiteralOrReturnEqualBoolean(smtutil::Expression const& _expr);
/// Parses the expression and expects a linear sum of variables. /// Parses the expression and expects a linear sum of variables.
/// Returns a vector with the first element being the constant and the /// Returns a vector with the first element being the constant and the
/// other elements the factors for the respective variables. /// other elements the factors for the respective variables.
/// If the expression cannot be properly parsed or is not linear, /// If the expression cannot be properly parsed or is not linear,
/// returns an empty vector. /// returns an empty vector.
std::optional<LinearExpression> parseLinearSum(smtutil::Expression const& _expression, LetBindings _letBindings); std::optional<LinearExpression> parseLinearSum(smtutil::Expression const& _expression);
bool isLiteral(smtutil::Expression const& _expression) const; bool isLiteral(smtutil::Expression const& _expression) const;
std::optional<LinearExpression> parseFactor(smtutil::Expression const& _expression, LetBindings _letBindings) const; std::optional<LinearExpression> parseFactor(smtutil::Expression const& _expression) const;
bool tryAddDirectBounds(Constraint const& _constraint); bool tryAddDirectBounds(Constraint const& _constraint);
void addUpperBound(size_t _index, RationalWithDelta _value); void addUpperBound(size_t _index, RationalWithDelta _value);
@ -123,7 +115,7 @@ private:
size_t addConditionalConstraint(Constraint _constraint); size_t addConditionalConstraint(Constraint _constraint);
void addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, LetBindings _letBindings); void addBooleanEquality(Literal const& _left, smtutil::Expression const& _right);
//std::string toString(std::vector<SolvingState::Bounds> const& _bounds) const; //std::string toString(std::vector<SolvingState::Bounds> const& _bounds) const;
std::string toString(Clause const& _clause) const; std::string toString(Clause const& _clause) const;
@ -142,6 +134,9 @@ private:
/// Stack of state, to allow for push()/pop(). /// Stack of state, to allow for push()/pop().
std::vector<State> m_state{{State{}}}; std::vector<State> m_state{{State{}}};
/// Current let bindings - only valid/used while parsing a statement.
std::map<std::string, LetBinding> m_letBindings;
}; };