From 87bffbba98ba92a48a1ef5233e2a8569280b68ee Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 27 Jun 2022 15:04:19 +0200 Subject: [PATCH] Avoid copying let bindings. --- libsolutil/BooleanLP.cpp | 160 +++++++++++++++++++++------------------ libsolutil/BooleanLP.h | 29 +++---- 2 files changed, 97 insertions(+), 92 deletions(-) diff --git a/libsolutil/BooleanLP.cpp b/libsolutil/BooleanLP.cpp index 595de5080..dea26412e 100644 --- a/libsolutil/BooleanLP.cpp +++ b/libsolutil/BooleanLP.cpp @@ -1,4 +1,4 @@ -/* +/* This file is part of solidity. solidity is free software: you can redistribute it and/or modify @@ -232,7 +232,7 @@ string BooleanLPSolver::toString() const return result; } -void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBindings) +void BooleanLPSolver::addAssertion(Expression const& _expr) { #ifdef DEBUG cerr << "adding assertion" << endl; @@ -247,12 +247,12 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind solAssert(false, "Adding false as top-level assertion."); size_t varIndex = 0; - if (_letBindings->count(_expr.name)) + if (m_letBindings.count(_expr.name)) { - LetBinding binding = _letBindings->at(_expr.name); + LetBinding binding = m_letBindings.at(_expr.name); if (holds_alternative(binding)) { - addAssertion(std::get(binding), move(_letBindings)); + addAssertion(std::get(binding)); return; } else @@ -266,8 +266,9 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind } else if (_expr.name == "let") { - addLetBindings(_expr, _letBindings); - addAssertion(_expr.arguments.back(), move(_letBindings)); + auto newBindings = addLetBindings(_expr); + addAssertion(_expr.arguments.back()); + removeLetBindings(newBindings); } else if (_expr.name == "=") { @@ -276,21 +277,21 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind if (_expr.arguments.at(0).sort->kind == Kind::Bool) { if (_expr.arguments.at(0).arguments.empty() && isBooleanVariable(_expr.arguments.at(0).name)) - addBooleanEquality(*parseLiteral(_expr.arguments.at(0), _letBindings), _expr.arguments.at(1), _letBindings); + addBooleanEquality(*parseLiteral(_expr.arguments.at(0)), _expr.arguments.at(1)); else if (_expr.arguments.at(1).arguments.empty() && isBooleanVariable(_expr.arguments.at(1).name)) - addBooleanEquality(*parseLiteral(_expr.arguments.at(1), _letBindings), _expr.arguments.at(0), _letBindings); + addBooleanEquality(*parseLiteral(_expr.arguments.at(1)), _expr.arguments.at(0)); else { - Literal newBoolean = *parseLiteral(declareInternalVariable(true), make_shared>()); - addBooleanEquality(newBoolean, _expr.arguments.at(0), _letBindings); - addBooleanEquality(newBoolean, _expr.arguments.at(1), _letBindings); + Literal newBoolean = *parseLiteral(declareInternalVariable(true)); + addBooleanEquality(newBoolean, _expr.arguments.at(0)); + addBooleanEquality(newBoolean, _expr.arguments.at(1)); } } else if (_expr.arguments.at(0).sort->kind == Kind::Int || _expr.arguments.at(0).sort->kind == Kind::Real) { // Try to see if both sides are linear. - optional left = parseLinearSum(_expr.arguments.at(0), _letBindings); - optional right = parseLinearSum(_expr.arguments.at(1), _letBindings); + optional left = parseLinearSum(_expr.arguments.at(0)); + optional right = parseLinearSum(_expr.arguments.at(1)); if (left && right) { LinearExpression data = *left - *right; @@ -314,7 +315,7 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind } else if (_expr.name == "and") for (auto const& arg: _expr.arguments) - addAssertion(arg, _letBindings); + addAssertion(arg); else if (_expr.name == "or") { if (_expr.arguments.size() == 1) @@ -324,7 +325,7 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind vector literals; // We could try to parse a full clause here instead. for (auto const& arg: _expr.arguments) - literals.emplace_back(parseLiteralOrReturnEqualBoolean(arg, _letBindings)); + literals.emplace_back(parseLiteralOrReturnEqualBoolean(arg)); state().clauses.emplace_back(Clause{move(literals)}); } } @@ -338,19 +339,19 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind { solAssert(_expr.arguments.size() == 1); // TODO can we still try to add a fixed constraint? - Literal l = negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0), move(_letBindings))); + Literal l = negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0))); state().clauses.emplace_back(Clause{vector{l}}); } else if (_expr.name == "=>") { solAssert(_expr.arguments.size() == 2); - addAssertion(!_expr.arguments.at(0) || _expr.arguments.at(1), move(_letBindings)); + addAssertion(!_expr.arguments.at(0) || _expr.arguments.at(1)); } else if (_expr.name == "<=" || _expr.name == "<") { solAssert(_expr.arguments.size() == 2); - optional left = parseLinearSum(_expr.arguments.at(0), _letBindings); - optional right = parseLinearSum(_expr.arguments.at(1), _letBindings); + optional left = parseLinearSum(_expr.arguments.at(0)); + optional right = parseLinearSum(_expr.arguments.at(1)); solAssert(left && right); LinearExpression data = *left - *right; data[0] *= -1; @@ -362,12 +363,12 @@ void BooleanLPSolver::addAssertion(Expression const& _expr, LetBindings _letBind else if (_expr.name == ">=") { solAssert(_expr.arguments.size() == 2); - addAssertion(_expr.arguments.at(1) <= _expr.arguments.at(0), move(_letBindings)); + addAssertion(_expr.arguments.at(1) <= _expr.arguments.at(0)); } else if (_expr.name == ">") { solAssert(_expr.arguments.size() == 2); - addAssertion(_expr.arguments.at(1) < _expr.arguments.at(0), move(_letBindings)); + addAssertion(_expr.arguments.at(1) < _expr.arguments.at(0)); } else { @@ -392,7 +393,7 @@ void BooleanLPSolver::declareVariable(string const& _name, bool _boolean) resizeAndSet(state().isBooleanVariable, index, _boolean); } -void BooleanLPSolver::addLetBindings(Expression const& _let, LetBindings& _letBindings) +map BooleanLPSolver::addLetBindings(Expression const& _let) { map newBindings; solAssert(_let.name == "let"); @@ -406,30 +407,38 @@ void BooleanLPSolver::addLetBindings(Expression const& _let, LetBindings& _letBi { Expression var = declareInternalVariable(isBool); newBindings.insert({binding.name, state().variables.at(var.name)}); - addAssertion(var == binding.arguments.at(0), _letBindings); + addAssertion(var == binding.arguments.at(0)); } } - _letBindings = make_shared>(*_letBindings); for (auto& [name, value]: newBindings) - _letBindings->insert({name, move(value)}); + m_letBindings.insert({name, move(value)}); + return newBindings; } -optional BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr, LetBindings _letBindings) +void BooleanLPSolver::removeLetBindings(map const& _bindings) +{ + for (auto& [name, value]: _bindings) + m_letBindings.erase(name); +} + +optional BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr) { if (_expr.name == "let") { - addLetBindings(_expr, _letBindings); - return parseLiteral(_expr.arguments.back(), move(_letBindings)); + map newBindings = addLetBindings(_expr); + optional literal = parseLiteral(_expr.arguments.back()); + removeLetBindings(newBindings); + return literal; } if (_expr.arguments.empty()) { size_t varIndex = 0; - if (_letBindings->count(_expr.name)) + if (m_letBindings.count(_expr.name)) { - LetBinding binding = _letBindings->at(_expr.name); + LetBinding binding = m_letBindings.at(_expr.name); if (holds_alternative(binding)) - return parseLiteral(std::get(binding), move(_letBindings)); + return parseLiteral(std::get(binding)); else varIndex = std::get(binding); } @@ -441,8 +450,8 @@ optional BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr if (!state().trueConstant) { Expression var = declareInternalVariable(true); - addAssertion(var, make_shared>()); - state().trueConstant = parseLiteral(var, make_shared>())->variable; + addAssertion(var); + state().trueConstant = parseLiteral(var)->variable; } return Literal{_expr.name == "true", *state().trueConstant}; } @@ -452,11 +461,11 @@ optional BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr return Literal{true, varIndex}; } else if (_expr.name == "not") - return negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0), move(_letBindings))); + return negate(parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0))); else if (_expr.name == "<=" || _expr.name == "<" || _expr.name == "=") { - optional left = parseLinearSum(_expr.arguments.at(0), _letBindings); - optional right = parseLinearSum(_expr.arguments.at(1), _letBindings); + optional left = parseLinearSum(_expr.arguments.at(0)); + optional right = parseLinearSum(_expr.arguments.at(1)); if (!left || !right) return {}; @@ -471,9 +480,9 @@ optional BooleanLPSolver::parseLiteral(smtutil::Expression const& _expr return Literal{true, addConditionalConstraint(Constraint{move(data), kind})}; } else if (_expr.name == ">=") - return parseLiteral(_expr.arguments.at(1) <= _expr.arguments.at(0), move(_letBindings)); + return parseLiteral(_expr.arguments.at(1) <= _expr.arguments.at(0)); else if (_expr.name == ">") - return parseLiteral(_expr.arguments.at(1) < _expr.arguments.at(0), move(_letBindings)); + return parseLiteral(_expr.arguments.at(1) < _expr.arguments.at(0)); return {}; } @@ -516,7 +525,7 @@ Literal BooleanLPSolver::negate(Literal const& _lit) gt.data *= -1; Literal gtL{true, addConditionalConstraint(gt)}; - Literal equalBoolean = *parseLiteral(declareInternalVariable(true), make_shared>()); + Literal equalBoolean = *parseLiteral(declareInternalVariable(true)); // a = or(x, y) <=> (-a \/ x \/ y) /\ (a \/ -x) /\ (a \/ -y) state().clauses.emplace_back(Clause{vector{negate(equalBoolean), ltL, gtL}}); state().clauses.emplace_back(Clause{vector{equalBoolean, negate(ltL)}}); @@ -546,37 +555,39 @@ Literal BooleanLPSolver::negate(Literal const& _lit) return ~_lit; } -Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr, LetBindings _letBindings) +Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr) { if (_expr.sort->kind != Kind::Bool) cerr << "expected bool: " << _expr.toString() << endl; solAssert(_expr.sort->kind == Kind::Bool); // TODO when can this fail? - if (optional literal = parseLiteral(_expr, _letBindings)) + if (optional literal = parseLiteral(_expr)) return *literal; else { - Literal newBoolean = *parseLiteral(declareInternalVariable(true), _letBindings); - addBooleanEquality(newBoolean, _expr, _letBindings); + Literal newBoolean = *parseLiteral(declareInternalVariable(true)); + addBooleanEquality(newBoolean, _expr); return newBoolean; } } -optional BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr, LetBindings _letBindings) +optional BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr) { if (_expr.name == "let") { - addLetBindings(_expr, _letBindings); - return parseLinearSum(_expr.arguments.back(), move(_letBindings)); + auto newBindings = addLetBindings(_expr); + auto result = parseLinearSum(_expr.arguments.back()); + removeLetBindings(newBindings); + return result; } if (_expr.arguments.empty()) - return parseFactor(_expr, move(_letBindings)); + return parseFactor(_expr); else if (_expr.name == "+") { optional expr = LinearExpression::constant(0); for (auto const& arg: _expr.arguments) - if (optional summand = parseLinearSum(arg, _letBindings)) + if (optional summand = parseLinearSum(arg)) *expr += move(*summand); else return std::nullopt; @@ -588,13 +599,13 @@ optional BooleanLPSolver::parseLinearSum(smtutil::Expression c optional right; if (_expr.arguments.size() == 2) { - left = parseLinearSum(_expr.arguments.at(0), _letBindings); - right = parseLinearSum(_expr.arguments.at(1), _letBindings); + left = parseLinearSum(_expr.arguments.at(0)); + right = parseLinearSum(_expr.arguments.at(1)); } else if (_expr.arguments.size() == 1) { left = LinearExpression::constant(0); - right = parseLinearSum(_expr.arguments.at(0), _letBindings); + right = parseLinearSum(_expr.arguments.at(0)); } else solAssert(false); @@ -608,13 +619,13 @@ optional BooleanLPSolver::parseLinearSum(smtutil::Expression c // TODO this can also have more than to args solAssert(_expr.arguments.size() == 2); // This will result in nullopt unless one of them is a constant. - return parseLinearSum(_expr.arguments.at(0), _letBindings) * parseLinearSum(_expr.arguments.at(1), _letBindings); + return parseLinearSum(_expr.arguments.at(0)) * parseLinearSum(_expr.arguments.at(1)); } else if (_expr.name == "/" || _expr.name == "div") { solAssert(_expr.arguments.size() == 2); - optional left = parseLinearSum(_expr.arguments.at(0), _letBindings); - optional right = parseLinearSum(_expr.arguments.at(1), move(_letBindings)); + optional left = parseLinearSum(_expr.arguments.at(0)); + optional right = parseLinearSum(_expr.arguments.at(1)); if (!left || !right || !right->isConstant()) return std::nullopt; *left /= right->get(0); @@ -624,9 +635,9 @@ optional BooleanLPSolver::parseLinearSum(smtutil::Expression c { solAssert(_expr.arguments.size() == 3); Expression result = declareInternalVariable(false); - addAssertion(!_expr.arguments.at(0) || (result == _expr.arguments.at(1)), _letBindings); - addAssertion(_expr.arguments.at(0) || (result == _expr.arguments.at(2)), _letBindings); - return parseLinearSum(result, make_shared>()); + addAssertion(!_expr.arguments.at(0) || (result == _expr.arguments.at(1))); + addAssertion(_expr.arguments.at(0) || (result == _expr.arguments.at(2))); + return parseLinearSum(result); } else { @@ -673,7 +684,7 @@ bool BooleanLPSolver::isLiteral(smtutil::Expression const& _expr) const _expr.name == "false"; } -optional BooleanLPSolver::parseFactor(smtutil::Expression const& _expr, LetBindings _letBindings) const +optional BooleanLPSolver::parseFactor(smtutil::Expression const& _expr) const { solAssert(_expr.arguments.empty(), ""); solAssert(!_expr.name.empty(), ""); @@ -687,11 +698,11 @@ optional BooleanLPSolver::parseFactor(smtutil::Expression cons return LinearExpression::constant(0); size_t varIndex = 0; - if (_letBindings->count(_expr.name)) + if (m_letBindings.count(_expr.name)) { - LetBinding binding = _letBindings->at(_expr.name); + LetBinding binding = m_letBindings.at(_expr.name); if (holds_alternative(binding)) - return parseFactor(std::get(binding), move(_letBindings)); + return parseFactor(std::get(binding)); else varIndex = std::get(binding); } @@ -771,10 +782,10 @@ size_t BooleanLPSolver::addConditionalConstraint(Constraint _constraint) return index; } -void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, LetBindings _letBindings) +void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right) { solAssert(_right.sort->kind == Kind::Bool); - if (optional right = parseLiteral(_right, _letBindings)) + if (optional right = parseLiteral(_right)) { // includes: not, <=, <, >=, >, =, boolean variables. // a = b <=> (-a \/ b) /\ (a \/ -b) @@ -784,15 +795,14 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi state().clauses.emplace_back(Clause{vector{_left, negRight}}); } // TODO This parses twice - else if (_right.name == "=" && parseLinearSum(_right.arguments.at(0), _letBindings) && parseLinearSum(_right.arguments.at(1), _letBindings)) + else if (_right.name == "=" && parseLinearSum(_right.arguments.at(0)) && parseLinearSum(_right.arguments.at(1))) { solAssert(false, "This should be covered by the case above"); // a = (x = y) <=> a = (x <= y && x >= y) addBooleanEquality( _left, _right.arguments.at(0) <= _right.arguments.at(1) && - _right.arguments.at(1) <= _right.arguments.at(0), - move(_letBindings) + _right.arguments.at(1) <= _right.arguments.at(0) ); } else if (_right.name == "ite") @@ -809,9 +819,9 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi // (-c || _left = x) && (c || _left = y) // (-c || ((-_left || x) && (_left || -x))) && ... // (-c || -_left || x) && (-c || _left || -x) && ... - Literal c = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0), _letBindings); - Literal x = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1), _letBindings); - Literal y = parseLiteralOrReturnEqualBoolean(_right.arguments.at(2), _letBindings); + Literal c = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0)); + Literal x = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1)); + Literal y = parseLiteralOrReturnEqualBoolean(_right.arguments.at(2)); state().clauses.emplace_back(Clause{vector{negate(c), negate(_left), x}}); state().clauses.emplace_back(Clause{vector{negate(c), _left, negate(x)}}); state().clauses.emplace_back(Clause{vector{c, negate(_left), y}}); @@ -819,7 +829,7 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi } else { - Literal a = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0), _letBindings); + Literal a = parseLiteralOrReturnEqualBoolean(_right.arguments.at(0)); Literal b; if (_right.arguments.size() > 2) { @@ -827,10 +837,10 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi // Reduce "a and b and c and ..." to "a and (b and c and ...)" smtutil::Expression rightSuffix = _right; rightSuffix.arguments.erase(rightSuffix.arguments.begin()); - b = parseLiteralOrReturnEqualBoolean(rightSuffix, _letBindings); + b = parseLiteralOrReturnEqualBoolean(rightSuffix); } else - b = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1), _letBindings); + b = parseLiteralOrReturnEqualBoolean(_right.arguments.at(1)); if (_right.name == "and") { @@ -850,7 +860,7 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi { solAssert(_right.arguments.size() == 2); // a = (x => y) <=> a = or(-x, y) - addBooleanEquality(_left, !_right.arguments.at(0) || _right.arguments.at(1), move(_letBindings)); + addBooleanEquality(_left, !_right.arguments.at(0) || _right.arguments.at(1)); } else if (_right.name == "=") { @@ -863,7 +873,7 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi else if (_right.name == "xor") { solAssert(_right.arguments.size() == 2); - addBooleanEquality(negate(_left), _right.arguments.at(0) == _right.arguments.at(1), move(_letBindings)); + addBooleanEquality(negate(_left), _right.arguments.at(0) == _right.arguments.at(1)); } else solAssert(false, "Unsupported operation: " + _right.name); diff --git a/libsolutil/BooleanLP.h b/libsolutil/BooleanLP.h index 11bbe6cbd..1e7d8fc17 100644 --- a/libsolutil/BooleanLP.h +++ b/libsolutil/BooleanLP.h @@ -76,10 +76,7 @@ public: void declareVariable(std::string const& _name, smtutil::SortPointer const& _sort) override; - void addAssertion(smtutil::Expression const& _expr) override - { - addAssertion(_expr, std::make_shared>()); - } + void addAssertion(smtutil::Expression const& _expr); std::pair> check(std::vector const& _expressionsToEvaluate) override; @@ -89,33 +86,28 @@ public: private: using rational = boost::rational; using LetBinding = std::variant; - using LetBindings = std::shared_ptr>; - - void addAssertion( - smtutil::Expression const& _expr, - LetBindings _letBindings - ); smtutil::Expression declareInternalVariable(bool _boolean); void declareVariable(std::string const& _name, bool _boolean); - /// Handles a "let" expression and adds the bindings to @a _letBindings. - void addLetBindings(smtutil::Expression const& _let, LetBindings& _letBindings); + /// Handles a "let" expression and adds the bindings to @a m_letBindings. + std::map addLetBindings(smtutil::Expression const& _let); + void removeLetBindings(std::map const& _toRemove); /// Parses an expression of sort bool and returns a literal. - std::optional parseLiteral(smtutil::Expression const& _expr, LetBindings _letBindings); + std::optional parseLiteral(smtutil::Expression const& _expr); Literal negate(Literal const& _lit); - Literal parseLiteralOrReturnEqualBoolean(smtutil::Expression const& _expr, LetBindings _letBindings); + Literal parseLiteralOrReturnEqualBoolean(smtutil::Expression const& _expr); /// Parses the expression and expects a linear sum of variables. /// Returns a vector with the first element being the constant and the /// other elements the factors for the respective variables. /// If the expression cannot be properly parsed or is not linear, /// returns an empty vector. - std::optional parseLinearSum(smtutil::Expression const& _expression, LetBindings _letBindings); + std::optional parseLinearSum(smtutil::Expression const& _expression); bool isLiteral(smtutil::Expression const& _expression) const; - std::optional parseFactor(smtutil::Expression const& _expression, LetBindings _letBindings) const; + std::optional parseFactor(smtutil::Expression const& _expression) const; bool tryAddDirectBounds(Constraint const& _constraint); void addUpperBound(size_t _index, RationalWithDelta _value); @@ -123,7 +115,7 @@ private: size_t addConditionalConstraint(Constraint _constraint); - void addBooleanEquality(Literal const& _left, smtutil::Expression const& _right, LetBindings _letBindings); + void addBooleanEquality(Literal const& _left, smtutil::Expression const& _right); //std::string toString(std::vector const& _bounds) const; std::string toString(Clause const& _clause) const; @@ -142,6 +134,9 @@ private: /// Stack of state, to allow for push()/pop(). std::vector m_state{{State{}}}; + + /// Current let bindings - only valid/used while parsing a statement. + std::map m_letBindings; };