Copy-on-write let bindings.

This commit is contained in:
chriseth 2022-05-30 19:37:20 +02:00
parent 0a4f4f6f55
commit 0f4cc05667
3 changed files with 31 additions and 32 deletions

View File

@ -232,7 +232,7 @@ string BooleanLPSolver::toString() const
return result;
}
void BooleanLPSolver::addAssertion(Expression const& _expr, map<string, LetBinding> _letBindings)
void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBindings)
{
#ifdef DEBUG
cerr << "adding assertion" << endl;
@ -242,9 +242,9 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, map<string, LetBindi
if (_expr.arguments.empty())
{
size_t varIndex = 0;
if (_letBindings.count(_expr.name))
if (_letBindings->count(_expr.name))
{
LetBinding binding = _letBindings.at(_expr.name);
LetBinding binding = _letBindings->at(_expr.name);
if (holds_alternative<smtutil::Expression>(binding))
{
addAssertion(std::get<smtutil::Expression>(binding), move(_letBindings));
@ -276,7 +276,7 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, map<string, LetBindi
addBooleanEquality(*parseLiteral(_expr.arguments.at(1), _letBindings), _expr.arguments.at(0), _letBindings);
else
{
Literal newBoolean = *parseLiteral(declareInternalVariable(true), {});
Literal newBoolean = *parseLiteral(declareInternalVariable(true), make_shared<map<string, LetBinding>>());
addBooleanEquality(newBoolean, _expr.arguments.at(0), _letBindings);
addBooleanEquality(newBoolean, _expr.arguments.at(1), _letBindings);
}
@ -382,7 +382,7 @@ void BooleanLPSolver::declareVariable(string const& _name, bool _boolean)
resizeAndSet(state().isBooleanVariable, index, _boolean);
}
void BooleanLPSolver::addLetBindings(Expression const& _let, map<string, LetBinding>& _letBindings)
void BooleanLPSolver::addLetBindings(Expression const& _let, LetBindings& _letBindings)
{
map<string, LetBinding> newBindings;
solAssert(_let.name == "let");
@ -399,11 +399,12 @@ void BooleanLPSolver::addLetBindings(Expression const& _let, map<string, LetBind
addAssertion(var == binding.arguments.at(0), _letBindings);
}
}
_letBindings = make_shared<std::map<std::string, LetBinding>>(*_letBindings);
for (auto& [name, value]: newBindings)
_letBindings.insert({name, move(value)});
_letBindings->insert({name, move(value)});
}
optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr, map<string, LetBinding> _letBindings)
optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr, LetBindings _letBindings)
{
// TODO constanst true/false?
@ -416,9 +417,9 @@ optional<Literal> BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr
if (_expr.arguments.empty())
{
size_t varIndex = 0;
if (_letBindings.count(_expr.name))
if (_letBindings->count(_expr.name))
{
LetBinding binding = _letBindings.at(_expr.name);
LetBinding binding = _letBindings->at(_expr.name);
if (holds_alternative<smtutil::Expression>(binding))
return parseLiteral(std::get<smtutil::Expression>(binding), move(_letBindings));
else
@ -494,7 +495,7 @@ Literal BooleanLPSolver::negate(Literal const& _lit)
gt.data *= -1;
Literal gtL{true, addConditionalConstraint(gt)};
Literal equalBoolean = *parseLiteral(declareInternalVariable(true), {});
Literal equalBoolean = *parseLiteral(declareInternalVariable(true), make_shared<map<string, LetBinding>>());
// a = or(x, y) <=> (-a \/ x \/ y) /\ (a \/ -x) /\ (a \/ -y)
state().clauses.emplace_back(Clause{vector<Literal>{negate(equalBoolean), ltL, gtL}});
state().clauses.emplace_back(Clause{vector<Literal>{equalBoolean, negate(ltL)}});
@ -524,7 +525,7 @@ Literal BooleanLPSolver::negate(Literal const& _lit)
return ~_lit;
}
Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr, map<string, LetBinding> _letBindings)
Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr, LetBindings _letBindings)
{
if (_expr.sort->kind != Kind::Bool)
cerr << "expected bool: " << _expr.toString() << endl;
@ -540,7 +541,7 @@ Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _exp
}
}
optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr, map<string, LetBinding> _letBindings)
optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr, LetBindings _letBindings)
{
if (_expr.name == "let")
{
@ -604,7 +605,7 @@ optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression c
Expression result = declareInternalVariable(false);
addAssertion(!_expr.arguments.at(0) || (result == _expr.arguments.at(1)), _letBindings);
addAssertion(_expr.arguments.at(0) || (result == _expr.arguments.at(2)), _letBindings);
return parseLinearSum(result, {});
return parseLinearSum(result, make_shared<map<string, LetBinding>>());
}
else
{
@ -625,7 +626,7 @@ bool BooleanLPSolver::isLiteral(smtutil::Expression const& _expr) const
_expr.name == "false";
}
optional<LinearExpression> BooleanLPSolver::parseFactor(smtutil::Expression const& _expr, map<string, LetBinding> _letBindings) const
optional<LinearExpression> BooleanLPSolver::parseFactor(smtutil::Expression const& _expr, LetBindings _letBindings) const
{
solAssert(_expr.arguments.empty(), "");
solAssert(!_expr.name.empty(), "");
@ -639,9 +640,9 @@ optional<LinearExpression> BooleanLPSolver::parseFactor(smtutil::Expression cons
return LinearExpression::constant(0);
size_t varIndex = 0;
if (_letBindings.count(_expr.name))
if (_letBindings->count(_expr.name))
{
LetBinding binding = _letBindings.at(_expr.name);
LetBinding binding = _letBindings->at(_expr.name);
if (holds_alternative<smtutil::Expression>(binding))
return parseFactor(std::get<smtutil::Expression>(binding), move(_letBindings));
else
@ -723,7 +724,7 @@ size_t BooleanLPSolver::addConditionalConstraint(Constraint _constraint)
return index;
}
void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, map<string, LetBinding> _letBindings)
void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, LetBindings _letBindings)
{
solAssert(_right.sort->kind == Kind::Bool);
if (optional<Literal> right = parseLiteral(_right, _letBindings))

View File

@ -74,7 +74,10 @@ public:
void declareVariable(std::string const& _name, smtutil::SortPointer const& _sort) override;
void addAssertion(smtutil::Expression const& _expr) override { addAssertion(_expr, {}); }
void addAssertion(smtutil::Expression const& _expr) override
{
addAssertion(_expr, std::make_shared<std::map<std::string, LetBinding>>());
}
std::pair<smtutil::CheckResult, std::vector<std::string>>
check(std::vector<smtutil::Expression> const& _expressionsToEvaluate) override;
@ -84,38 +87,33 @@ public:
private:
using rational = boost::rational<bigint>;
using LetBinding = std::variant<size_t, smtutil::Expression>;
using LetBindings = std::shared_ptr<std::map<std::string, LetBinding>>;
void addAssertion(
smtutil::Expression const& _expr,
std::map<std::string, LetBinding> _letBindings
LetBindings _letBindings
);
smtutil::Expression declareInternalVariable(bool _boolean);
void declareVariable(std::string const& _name, bool _boolean);
/// Handles a "let" expression and adds the bindings to @a _letBindings.
void addLetBindings(smtutil::Expression const& _let, std::map<std::string, LetBinding>& _letBindings);
void addLetBindings(smtutil::Expression const& _let, LetBindings& _letBindings);
/// Parses an expression of sort bool and returns a literal.
std::optional<Literal> parseLiteral(
smtutil::Expression const& _expr,
std::map<std::string, LetBinding> _letBindings
);
std::optional<Literal> parseLiteral(smtutil::Expression const& _expr, LetBindings _letBindings);
Literal negate(Literal const& _lit);
Literal parseLiteralOrReturnEqualBoolean(
smtutil::Expression const& _expr,
std::map<std::string, LetBinding> _letBindings
);
Literal parseLiteralOrReturnEqualBoolean(smtutil::Expression const& _expr, LetBindings _letBindings);
/// Parses the expression and expects a linear sum of variables.
/// Returns a vector with the first element being the constant and the
/// other elements the factors for the respective variables.
/// If the expression cannot be properly parsed or is not linear,
/// returns an empty vector.
std::optional<LinearExpression> parseLinearSum(smtutil::Expression const& _expression, std::map<std::string, LetBinding> _letBindings);
std::optional<LinearExpression> parseLinearSum(smtutil::Expression const& _expression, LetBindings _letBindings);
bool isLiteral(smtutil::Expression const& _expression) const;
std::optional<LinearExpression> parseFactor(smtutil::Expression const& _expression, std::map<std::string, LetBinding> _letBindings) const;
std::optional<LinearExpression> parseFactor(smtutil::Expression const& _expression, LetBindings _letBindings) const;
bool tryAddDirectBounds(Constraint const& _constraint);
void addUpperBound(size_t _index, RationalWithDelta _value);
@ -123,7 +121,7 @@ private:
size_t addConditionalConstraint(Constraint _constraint);
void addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, std::map<std::string, LetBinding> _letBindings);
void addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, LetBindings _letBindings);
//std::string toString(std::vector<SolvingState::Bounds> const& _bounds) const;
std::string toString(Clause const& _clause) const;

View File

@ -298,7 +298,7 @@ public:
void addLowerBound(size_t _variable, RationalWithDelta _bound, std::optional<size_t> _reason = std::nullopt);
void addUpperBound(size_t _variable, RationalWithDelta _bound, std::optional<size_t> _reason = std::nullopt);
std::pair<LPResult, ReasonSet>check();
std::pair<LPResult, ReasonSet> check();
std::string toString() const;
std::map<std::string, rational> model() const;