Simplifications on LinearExpression.

This commit is contained in:
chriseth 2022-02-03 09:55:23 +01:00
parent e6c67924b0
commit 00a277c0f5
2 changed files with 15 additions and 14 deletions

View File

@ -353,7 +353,7 @@ void eraseIndices(T& _data, vector<bool> const& _indices)
T result;
for (size_t i = 0; i < _data.size(); i++)
if (!_indices[i])
result.emplace_back(move(_data[i]));
result.push_back(move(_data[i]));
_data = move(result);
}

View File

@ -39,6 +39,7 @@ using rational = boost::rational<bigint>;
/**
* A linear expression of the form
* factors[0] + factors[1] * X1 + factors[2] * X2 + ...
* where the variables X_i are implicit.
*/
struct LinearExpression
{
@ -72,7 +73,9 @@ struct LinearExpression
auto begin() const { return factors.begin(); }
auto end() const { return factors.end(); }
void emplace_back(rational _value) { factors.emplace_back(move(_value)); }
void push_back(rational _value) { factors.push_back(move(_value)); }
size_t size() const { return factors.size(); }
void resize(size_t _size)
{
@ -88,15 +91,13 @@ struct LinearExpression
bool isConstant() const
{
return ranges::all_of(factors | ranges::views::tail, [](rational const& _v) { return _v.numerator() == 0; });
return ranges::all_of(factors | ranges::views::tail, [](rational const& _v) { return !_v; });
}
size_t size() const { return factors.size(); }
LinearExpression& operator/=(rational const& _divisor)
{
for (rational& x: factors)
if (x.numerator())
if (x)
x /= _divisor;
return *this;
}
@ -104,7 +105,7 @@ struct LinearExpression
LinearExpression& operator*=(rational const& _factor)
{
for (rational& x: factors)
if (x.numerator())
if (x)
x *= _factor;
return *this;
}
@ -112,7 +113,7 @@ struct LinearExpression
friend LinearExpression operator*(rational const& _factor, LinearExpression _expr)
{
for (rational& x: _expr.factors)
if (x.numerator())
if (x)
x *= _factor;
return _expr;
}
@ -120,10 +121,10 @@ struct LinearExpression
LinearExpression& operator-=(LinearExpression const& _y)
{
if (size() < _y.size())
factors.resize(_y.size());
resize(_y.size());
for (size_t i = 0; i < size(); ++i)
if (_y.factors[i].numerator())
factors[i] -= _y.factors[i];
if (i < _y.size() && _y[i])
(*this)[i] -= _y[i];
return *this;
}
@ -137,10 +138,10 @@ struct LinearExpression
LinearExpression& operator+=(LinearExpression const& _y)
{
if (size() < _y.size())
factors.resize(_y.size());
resize(_y.size());
for (size_t i = 0; i < size(); ++i)
if (_y.factors[i].numerator())
factors[i] += _y.factors[i];
if (i < _y.size() && _y[i])
(*this)[i] += _y[i];
return *this;
}