Implement LinearExpression using a sparse vector.

This commit is contained in:
chriseth 2022-03-22 13:39:37 +01:00
parent 72ae0f6a1a
commit 14284ef1b1
3 changed files with 166 additions and 135 deletions

View File

@ -501,8 +501,8 @@ bool BooleanLPSolver::tryAddDirectBounds(Constraint const& _constraint)
{ {
// 0 <= b or 0 = b // 0 <= b or 0 = b
if ( if (
_constraint.data.front() < 0 || _constraint.data.constantFactor() < 0 ||
(_constraint.equality && _constraint.data.front() != 0) (_constraint.equality && _constraint.data.constantFactor() == 0)
) )
{ {
// cout << "SETTING INF" << endl; // cout << "SETTING INF" << endl;
@ -665,7 +665,7 @@ string BooleanLPSolver::toString(Constraint const& _constraint) const
return return
joinHumanReadable(line, " + ") + joinHumanReadable(line, " + ") +
(_constraint.equality ? " = " : " <= ") + (_constraint.equality ? " = " : " <= ") +
::toString(_constraint.data.front()); ::toString(_constraint.data.constantFactor());
} }
Constraint const& BooleanLPSolver::conditionalConstraint(size_t _index) const Constraint const& BooleanLPSolver::conditionalConstraint(size_t _index) const

View File

@ -65,6 +65,8 @@ inline std::vector<bool>& operator|=(std::vector<bool>& _x, std::vector<bool> co
*/ */
struct Tableau struct Tableau
{ {
size_t columns;
size_t rows;
/// The factors of the objective function (first row of the tableau) /// The factors of the objective function (first row of the tableau)
LinearExpression objective; LinearExpression objective;
/// The tableau matrix (equational form). /// The tableau matrix (equational form).
@ -115,43 +117,45 @@ string toString(Tableau const& _tableau)
/// Adds slack variables to remove non-equality costraints from a set of constraints /// Adds slack variables to remove non-equality costraints from a set of constraints
/// and returns the data part of the tableau / constraints. /// and returns the data part of the tableau / constraints.
/// The second return variable is true if the original input had any equality constraints. /// The second return variable is true if the original input had any equality constraints.
pair<vector<LinearExpression>, bool> toEquationalForm(vector<Constraint> _constraints) pair<Tableau, bool> toEquationalForm(vector<Constraint> _constraints)
{ {
size_t varsNeeded = static_cast<size_t>(ranges::count_if(_constraints, [](Constraint const& _c) { return !_c.equality; })); Tableau t;
if (varsNeeded > 0)
t.columns = 0;
for (Constraint const& constraint: _constraints)
t.columns = max(constraint.data.maxIndex() + 1, t.columns);
t.rows = _constraints.size();
size_t varsAdded = 0;
bool hadEquality = false;
for (Constraint& constraint: _constraints)
{ {
size_t columns = _constraints.at(0).data.size(); if (constraint.equality)
size_t varsAdded = 0; hadEquality = true;
for (Constraint& constraint: _constraints) else
{ {
solAssert(constraint.data.size() == columns, ""); constraint.data[t.columns + varsAdded] = bigint(1);
constraint.data.resize(columns + varsNeeded); varsAdded++;
if (!constraint.equality)
{
constraint.equality = true;
constraint.data[columns + varsAdded] = bigint(1);
varsAdded++;
}
} }
solAssert(varsAdded == varsNeeded); t.data.emplace_back(move(constraint.data));
} }
t.columns += varsAdded;
vector<LinearExpression> data; return make_pair(move(t), hadEquality);
for (Constraint& c: _constraints)
data.emplace_back(move(c.data));
return make_pair(move(data), varsNeeded < _constraints.size());
} }
/// Finds the simplex pivot column: The column with the largest positive objective factor. /// Finds the simplex pivot column: The column with the largest positive objective factor.
/// If all objective factors are zero or negative, the optimum has been found and nullopt is returned. /// If all objective factors are zero or negative, the optimum has been found and nullopt is returned.
optional<size_t> findPivotColumn(Tableau const& _tableau) optional<size_t> findPivotColumn(Tableau const& _tableau)
{ {
auto&& [maxColumn, maxValue] = ranges::max( size_t maxColumn = 0;
_tableau.objective.enumerateTail(), rational maxValue{};
{}, for (auto const& [column, value]: _tableau.objective.enumerateTail())
[](std::pair<size_t, rational> const& _x) { return _x.second; } if (value > maxValue)
); {
maxColumn = column;
maxValue = value;
}
if (maxValue <= rational{0}) if (maxValue <= rational{0})
return nullopt; // found optimum return nullopt; // found optimum
@ -218,10 +222,8 @@ void selectLastVectorsAsBasis(Tableau& _tableau)
{ {
// We might skip the operation for a column if it is already the correct // We might skip the operation for a column if it is already the correct
// unit vector and its objective coefficient is zero. // unit vector and its objective coefficient is zero.
size_t columns = _tableau.objective.size(); for (size_t i = 0; i < _tableau.rows; ++i)
size_t rows = _tableau.data.size(); performPivot(_tableau, i, _tableau.columns - _tableau.rows + i);
for (size_t i = 0; i < rows; ++i)
performPivot(_tableau, i, columns - rows + i);
} }
/// If column @a _column inside tableau is a basis vector /// If column @a _column inside tableau is a basis vector
@ -249,8 +251,8 @@ optional<size_t> basisIndex(Tableau const& _tableau, size_t _column)
vector<rational> solutionVector(Tableau const& _tableau) vector<rational> solutionVector(Tableau const& _tableau)
{ {
vector<rational> result; vector<rational> result;
vector<bool> rowsSeen(_tableau.data.size(), false); vector<bool> rowsSeen(_tableau.rows, false);
for (size_t j = 1; j < _tableau.objective.size(); j++) for (size_t j = 1; j < _tableau.columns; j++)
{ {
optional<size_t> row = basisIndex(_tableau, j); optional<size_t> row = basisIndex(_tableau, j);
if (row && rowsSeen[*row]) if (row && rowsSeen[*row])
@ -270,7 +272,7 @@ vector<rational> solutionVector(Tableau const& _tableau)
/// Tries for a number of iterations and then gives up. /// Tries for a number of iterations and then gives up.
pair<LPResult, Tableau> simplexEq(Tableau _tableau) pair<LPResult, Tableau> simplexEq(Tableau _tableau)
{ {
size_t const iterations = min<size_t>(60, 50 + _tableau.objective.size() * 2); size_t const iterations = min<size_t>(60, 50 + _tableau.columns * 2);
for (size_t step = 0; step <= iterations; ++step) for (size_t step = 0; step <= iterations; ++step)
{ {
optional<size_t> pivotColumn = findPivotColumn(_tableau); optional<size_t> pivotColumn = findPivotColumn(_tableau);
@ -298,18 +300,16 @@ pair<LPResult, Tableau> simplexPhaseI(Tableau _tableau)
{ {
LinearExpression originalObjective = _tableau.objective; LinearExpression originalObjective = _tableau.objective;
size_t rows = _tableau.data.size(); size_t const columns = _tableau.columns;
size_t columns = _tableau.objective.size();
for (size_t i = 0; i < rows; ++i)
{
if (_tableau.data[i][0] < 0)
_tableau.data[i] *= -1;
_tableau.data[i].resize(columns + rows);
_tableau.data[i][columns + i] = 1;
}
_tableau.objective = {}; _tableau.objective = {};
_tableau.objective.resize(columns); _tableau.columns += _tableau.rows;
_tableau.objective.resize(columns + rows, rational{-1}); for (size_t i = 0; i < _tableau.rows; ++i)
{
if (_tableau.data[i].constantFactor() < 0)
_tableau.data[i] *= -1;
_tableau.data[i][columns + i] = 1;
_tableau.objective[columns + i] = rational{-1};
}
// This sets the objective factors of the slack variables // This sets the objective factors of the slack variables
// to zero (and thus selects a basic feasible solution). // to zero (and thus selects a basic feasible solution).
@ -332,7 +332,7 @@ pair<LPResult, Tableau> simplexPhaseI(Tableau _tableau)
// Restore original objective and remove slack variables. // Restore original objective and remove slack variables.
_tableau.objective = move(originalObjective); _tableau.objective = move(originalObjective);
for (auto& row: _tableau.data) for (auto& row: _tableau.data)
row.resize(columns); row.eraseIndicesGE(columns);
return make_pair(LPResult::Feasible, move(_tableau)); return make_pair(LPResult::Feasible, move(_tableau));
} }
@ -341,7 +341,7 @@ pair<LPResult, Tableau> simplexPhaseI(Tableau _tableau)
bool needsPhaseI(Tableau const& _tableau) bool needsPhaseI(Tableau const& _tableau)
{ {
for (auto const& row: _tableau.data) for (auto const& row: _tableau.data)
if (row[0] < 0) if (row.constantFactor() < 0)
return true; return true;
return false; return false;
} }
@ -350,10 +350,9 @@ bool needsPhaseI(Tableau const& _tableau)
pair<LPResult, vector<rational>> simplex(vector<Constraint> _constraints, LinearExpression _objectives) pair<LPResult, vector<rational>> simplex(vector<Constraint> _constraints, LinearExpression _objectives)
{ {
Tableau tableau; Tableau tableau;
tableau.objective = move(_objectives);
bool hasEquations = false; bool hasEquations = false;
tie(tableau.data, hasEquations) = toEquationalForm(_constraints); tie(tableau, hasEquations) = toEquationalForm(_constraints);
tableau.objective.resize(tableau.data.at(0).size()); tableau.objective = move(_objectives);
if (hasEquations || needsPhaseI(tableau)) if (hasEquations || needsPhaseI(tableau))
{ {
@ -375,8 +374,6 @@ pair<LPResult, vector<rational>> simplex(vector<Constraint> _constraints, Linear
/// @returns false if the bounds make the state infeasible. /// @returns false if the bounds make the state infeasible.
optional<ReasonSet> boundsToConstraints(SolvingState& _state) optional<ReasonSet> boundsToConstraints(SolvingState& _state)
{ {
size_t columns = _state.variableNames.size();
// Bound zero should not exist because the variable zero does not exist. // Bound zero should not exist because the variable zero does not exist.
for (auto const& [varIndex, bounds]: _state.bounds | ranges::views::enumerate | ranges::views::tail) for (auto const& [varIndex, bounds]: _state.bounds | ranges::views::enumerate | ranges::views::tail)
{ {
@ -387,7 +384,6 @@ optional<ReasonSet> boundsToConstraints(SolvingState& _state)
if (*bounds.lower == *bounds.upper) if (*bounds.lower == *bounds.upper)
{ {
LinearExpression c; LinearExpression c;
c.resize(columns);
c[0] = *bounds.lower; c[0] = *bounds.lower;
c[varIndex] = bigint(1); c[varIndex] = bigint(1);
_state.constraints.emplace_back(Constraint{move(c), true, bounds.lowerReasons + bounds.upperReasons}); _state.constraints.emplace_back(Constraint{move(c), true, bounds.lowerReasons + bounds.upperReasons});
@ -397,7 +393,6 @@ optional<ReasonSet> boundsToConstraints(SolvingState& _state)
if (bounds.lower && *bounds.lower > 0) if (bounds.lower && *bounds.lower > 0)
{ {
LinearExpression c; LinearExpression c;
c.resize(columns);
c[0] = -*bounds.lower; c[0] = -*bounds.lower;
c[varIndex] = bigint(-1); c[varIndex] = bigint(-1);
_state.constraints.emplace_back(Constraint{move(c), false, move(bounds.lowerReasons)}); _state.constraints.emplace_back(Constraint{move(c), false, move(bounds.lowerReasons)});
@ -405,7 +400,6 @@ optional<ReasonSet> boundsToConstraints(SolvingState& _state)
if (bounds.upper) if (bounds.upper)
{ {
LinearExpression c; LinearExpression c;
c.resize(columns);
c[0] = *bounds.upper; c[0] = *bounds.upper;
c[varIndex] = bigint(1); c[varIndex] = bigint(1);
_state.constraints.emplace_back(Constraint{move(c), false, move(bounds.upperReasons)}); _state.constraints.emplace_back(Constraint{move(c), false, move(bounds.upperReasons)});
@ -426,12 +420,11 @@ void eraseIndices(T& _data, vector<bool> const& _indicesToRemove)
_data = move(result); _data = move(result);
} }
void removeColumns(SolvingState& _state, vector<bool> const& _columnsToRemove) void removeColumns(SolvingState& _state, vector<bool> const& _columnsToRemove)
{ {
eraseIndices(_state.bounds, _columnsToRemove); eraseIndices(_state.bounds, _columnsToRemove);
for (Constraint& constraint: _state.constraints) for (Constraint& constraint: _state.constraints)
eraseIndices(constraint.data, _columnsToRemove); constraint.data.eraseIndices(_columnsToRemove);
eraseIndices(_state.variableNames, _columnsToRemove); eraseIndices(_state.variableNames, _columnsToRemove);
} }
@ -440,7 +433,7 @@ auto nonZeroEntriesInColumn(SolvingState const& _state, size_t _column)
return return
_state.constraints | _state.constraints |
ranges::views::enumerate | ranges::views::enumerate |
ranges::views::filter([=](auto const& _entry) { return _entry.second.data[_column]; }) | ranges::views::filter([=](auto const& _entry) -> bool { return _entry.second.data[_column]; }) |
ranges::views::transform([](auto const& _entry) { return _entry.first; }); ranges::views::transform([](auto const& _entry) { return _entry.first; });
} }
@ -482,12 +475,8 @@ pair<vector<bool>, vector<bool>> connectedComponent(SolvingState const& _state,
void normalizeRowLengths(SolvingState& _state) void normalizeRowLengths(SolvingState& _state)
{ {
size_t vars = max(_state.variableNames.size(), _state.bounds.size()); size_t vars = max(_state.variableNames.size(), _state.bounds.size());
for (Constraint const& c: _state.constraints)
vars = max(vars, c.data.size());
_state.variableNames.resize(vars); _state.variableNames.resize(vars);
_state.bounds.resize(vars); _state.bounds.resize(vars);
for (Constraint& c: _state.constraints)
c.data.resize(vars);
} }
} }
@ -498,11 +487,7 @@ bool Constraint::operator<(Constraint const& _other) const
if (equality != _other.equality) if (equality != _other.equality)
return equality < _other.equality; return equality < _other.equality;
for (size_t i = 0; i < max(data.size(), _other.data.size()); ++i) return data < _other.data;
if (rational diff = data.get(i) - _other.data.get(i))
return diff < 0;
return false;
} }
bool Constraint::operator==(Constraint const& _other) const bool Constraint::operator==(Constraint const& _other) const
@ -510,10 +495,7 @@ bool Constraint::operator==(Constraint const& _other) const
if (equality != _other.equality) if (equality != _other.equality)
return false; return false;
for (size_t i = 0; i < max(data.size(), _other.data.size()); ++i) return data == _other.data;
if (data.get(i) != _other.data.get(i))
return false;
return true;
} }
bool SolvingState::Compare::operator()(SolvingState const& _a, SolvingState const& _b) const bool SolvingState::Compare::operator()(SolvingState const& _a, SolvingState const& _b) const
@ -551,7 +533,7 @@ string SolvingState::toString() const
reasonToString(constraint.reasons, reasonLength) + reasonToString(constraint.reasons, reasonLength) +
joinHumanReadable(line, " + ") + joinHumanReadable(line, " + ") +
(constraint.equality ? " = " : " <= ") + (constraint.equality ? " = " : " <= ") +
::toString(constraint.data.front()) + ::toString(constraint.data.constantFactor()) +
"\n"; "\n";
} }
result += "Bounds:\n"; result += "Bounds:\n";
@ -617,7 +599,7 @@ optional<ReasonSet> SolvingStateSimplifier::removeFixedVariables()
if (constraint.data[index]) if (constraint.data[index])
{ {
constraint.data[0] -= constraint.data[index] * lower; constraint.data[0] -= constraint.data[index] * lower;
constraint.data[index] = 0; constraint.data.erase(index);
constraint.reasons += reasons; constraint.reasons += reasons;
} }
} }
@ -645,8 +627,8 @@ optional<ReasonSet> SolvingStateSimplifier::extractDirectConstraints()
{ {
// 0 <= b or 0 = b // 0 <= b or 0 = b
if ( if (
constraint.data.front().numerator() < 0 || constraint.data.constantFactor().numerator() < 0 ||
(constraint.equality && constraint.data.front()) (constraint.equality && constraint.data.constantFactor())
) )
return constraint.reasons; return constraint.reasons;
} }
@ -654,7 +636,7 @@ optional<ReasonSet> SolvingStateSimplifier::extractDirectConstraints()
{ {
auto&& [varIndex, factor] = nonzeroCoefficients.front(); auto&& [varIndex, factor] = nonzeroCoefficients.front();
// a * x <= b // a * x <= b
rational bound = constraint.data[0] / factor; rational bound = constraint.data.constantFactor() / factor;
if ( if (
(factor >= 0 || constraint.equality) && (factor >= 0 || constraint.equality) &&
(!m_state.bounds[varIndex].upper || bound < m_state.bounds[varIndex].upper) (!m_state.bounds[varIndex].upper || bound < m_state.bounds[varIndex].upper)
@ -760,9 +742,15 @@ SolvingState ProblemSplitter::next()
// to undefined behaviour for connectedComponent // to undefined behaviour for connectedComponent
Constraint const& constraint = m_state.constraints[i]; Constraint const& constraint = m_state.constraints[i];
Constraint splitRow{{}, constraint.equality, constraint.reasons}; Constraint splitRow{{}, constraint.equality, constraint.reasons};
for (size_t j = 0; j < constraint.data.size(); j++) splitRow.data[0] = constraint.data.constantFactor();
if (j == 0 || includedColumns[j]) size_t j = 1;
splitRow.data.push_back(constraint.data[j]); for (auto&& [i, included]: includedColumns | ranges::views::enumerate | ranges::views::tail)
if (included)
{
if (rational const& x = constraint.data.get(i))
splitRow.data[j] = x;
j++;
}
splitOff.constraints.push_back(move(splitRow)); splitOff.constraints.push_back(move(splitRow));
} }
@ -818,8 +806,8 @@ pair<LPResult, variant<Model, ReasonSet>> LPSolver::check(SolvingState _state)
else else
{ {
LinearExpression objectives; LinearExpression objectives;
objectives.resize(1); for (size_t i = 1; i < split.variableNames.size(); i++)
objectives.resize(split.constraints.front().data.size(), rational(bigint(1))); objectives[i] = rational(bigint(1));
tie(lpResult, solution) = simplex(split.constraints, move(objectives)); tie(lpResult, solution) = simplex(split.constraints, move(objectives));
// If we do not support models, do not store it in the cache because // If we do not support models, do not store it in the cache because

View File

@ -51,7 +51,6 @@ public:
static LinearExpression factorForVariable(size_t _index, rational _factor) static LinearExpression factorForVariable(size_t _index, rational _factor)
{ {
LinearExpression result; LinearExpression result;
result.resize(_index + 1);
result[_index] = std::move(_factor); result[_index] = std::move(_factor);
return result; return result;
} }
@ -59,20 +58,28 @@ public:
static LinearExpression constant(rational _factor) static LinearExpression constant(rational _factor)
{ {
LinearExpression result; LinearExpression result;
result.resize(1);
result[0] = std::move(_factor); result[0] = std::move(_factor);
return result; return result;
} }
rational const& constantFactor() const
{
return get(0);
}
rational const& get(size_t _index) const rational const& get(size_t _index) const
{ {
static rational const zero; static rational const zero;
return _index < factors.size() ? factors[_index] : zero; auto it = factors.find(_index);
if (it == factors.end())
return zero;
else
return it->second;
} }
rational const& operator[](size_t _index) const rational const& operator[](size_t _index) const
{ {
return factors[_index]; return get(_index);
} }
rational& operator[](size_t _index) rational& operator[](size_t _index)
@ -80,75 +87,106 @@ public:
return factors[_index]; return factors[_index];
} }
auto enumerate() const { return factors | ranges::view::enumerate; } auto const& enumerate() const { return factors; }
// leave out the zero if exists // leave out the constant factor if exists
auto enumerateTail() const { return factors | ranges::view::enumerate | ranges::view::tail; } auto enumerateTail() const
rational const& front() const { return factors.front(); }
void push_back(rational _value) { factors.push_back(std::move(_value)); }
size_t size() const { return factors.size(); }
void resize(size_t _size, rational _default = {})
{ {
factors.resize(_size, std::move(_default)); auto it = factors.begin();
if (it != factors.end() && !it->first)
++it;
return ranges::subrange(it, factors.end());
}
void eraseIndices(std::vector<bool> const& _indices)
{
for (auto it = factors.begin(); it != factors.end();)
{
size_t i = it->first;
if (i < _indices.size() && _indices[i])
it = factors.erase(it);
else
++it;
}
}
/// Erases all indices greater or equal to _index.
void eraseIndicesGE(size_t _index)
{
auto it = factors.begin();
while (it != factors.end() && it->first < _index) ++it;
factors.erase(it, factors.end());
}
void erase(size_t _index) { factors.erase(_index); }
size_t maxIndex() const
{
if (factors.empty())
return 0;
else
return factors.rbegin()->first;
} }
/// @returns true if all factors of variables are zero. /// @returns true if all factors of variables are zero.
bool isConstant() const bool isConstant() const
{ {
return ranges::all_of(factors | ranges::views::tail, [](rational const& _v) { return !_v; }); return ranges::all_of(enumerateTail(), [](auto const& _item) -> bool { return !_item.second; });
}
bool operator<(LinearExpression const& _other) const
{
// "The comparison igrones the map's ordering"
return factors < _other.factors;
}
bool operator==(LinearExpression const& _other) const
{
// TODO this might be wrong if there are stray zeros.
return factors == _other.factors;
} }
LinearExpression& operator/=(rational const& _divisor) LinearExpression& operator/=(rational const& _divisor)
{ {
for (rational& x: factors) for (auto& item: factors)
if (x) item.second /= _divisor;
x /= _divisor;
return *this; return *this;
} }
LinearExpression& operator*=(rational const& _factor) LinearExpression& operator*=(rational const& _factor)
{ {
for (rational& x: factors) for (auto& item: factors)
if (x) item.second *= _factor;
x *= _factor;
return *this; return *this;
} }
friend LinearExpression operator*(rational const& _factor, LinearExpression _expr) friend LinearExpression operator*(rational const& _factor, LinearExpression _expr)
{ {
for (rational& x: _expr.factors) for (auto& item: _expr.factors)
if (x) item.second *= _factor;
x *= _factor;
return _expr; return _expr;
} }
LinearExpression& operator-=(LinearExpression const& _y)
{
if (size() < _y.size())
resize(_y.size());
for (size_t i = 0; i < size(); ++i)
if (i < _y.size() && _y[i])
(*this)[i] -= _y[i];
return *this;
}
LinearExpression operator-(LinearExpression const& _y) const
{
LinearExpression result = *this;
result -= _y;
return result;
}
LinearExpression& operator+=(LinearExpression const& _y) LinearExpression& operator+=(LinearExpression const& _y)
{ {
if (size() < _y.size()) for (auto const& [i, x]: _y.enumerate())
resize(_y.size()); {
for (size_t i = 0; i < size(); ++i) // TODO this could be even more efficient.
if (i < _y.size() && _y[i]) if (rational v = get(i) + x)
(*this)[i] += _y[i]; factors[i] = move(v);
else
factors.erase(i);
}
return *this;
}
LinearExpression& operator-=(LinearExpression const& _y)
{
for (auto const& [i, x]: _y.enumerate())
{
// TODO this could be even more efficient.
if (rational v = get(i) - x)
factors[i] = move(v);
else
factors.erase(i);
}
return *this; return *this;
} }
@ -159,6 +197,13 @@ public:
return result; return result;
} }
LinearExpression operator-(LinearExpression const& _y) const
{
LinearExpression result = *this;
result -= _y;
return result;
}
/// Multiply two linear expression. This only works if at least one of them is a constant. /// Multiply two linear expression. This only works if at least one of them is a constant.
/// Returns nullopt otherwise. /// Returns nullopt otherwise.
@ -174,15 +219,13 @@ public:
if (!_y->isConstant()) if (!_y->isConstant())
return std::nullopt; return std::nullopt;
rational const& factor = _y->get(0); *_x *= _y->constantFactor();
for (rational& element: _x->factors)
element *= factor;
return _x; return _x;
} }
private: private:
std::vector<rational> factors; // TODO maybe a vector of pairs could be more efficient.
std::map<size_t, rational> factors;
}; };
} }