diff --git a/libsolutil/CMakeLists.txt b/libsolutil/CMakeLists.txt index 077d9f1e6..32aaf5264 100644 --- a/libsolutil/CMakeLists.txt +++ b/libsolutil/CMakeLists.txt @@ -32,6 +32,7 @@ set(sources Numeric.cpp Numeric.h picosha2.h + LinearExpression.cpp LinearExpression.h Result.h SetOnce.h diff --git a/libsolutil/LinearExpression.cpp b/libsolutil/LinearExpression.cpp new file mode 100644 index 000000000..487a2b43a --- /dev/null +++ b/libsolutil/LinearExpression.cpp @@ -0,0 +1,124 @@ +/* + This file is part of solidity. + + solidity is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + solidity is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with solidity. If not, see . +*/ +// SPDX-License-Identifier: GPL-3.0 + +#include + +using solidity::util; + +void SparseMatrix::multiplyRowByFactor(size_t _row, rational const& _factor) +{ + Entry* e = m_row_start[_row]; + while (e) + { + e->value *= _factor; + Entry* next = e->next_in_row; + if (!e->value) + remove(*e); + e = next; + } +} + +void SparseMatrix::addMultipleOfRow(size_t _sourceRow, size_t _targetRow, rational const& _factor) +{ + Entry* source = m_row_start[_sourceRow]; + Entry* target = m_row_start[_targetRow]; + + while (source) + { + while (target && target->col < source->col) + target = target->next_in_row; + if (target && target->col == source->col) + { + target->value += _factor * source->value; + if (!target->value) + { + Entry* next = target->next_in_row; + remove(*target); + target = next; + } + } + else if (!target) + target = appendToRow(_targetRow, source->col, _factor * source->value); + else if (target->col > source->col) + target = prependInRow(target, source->col, _factor * source->value); + + source = source->next_in_row; + } +} + + +void SparseMatrix::appendRow(LinearExpression const& _entries) +{ + Entry* prev = nullptr; + m_row_start.push(nullptr); + m_row_end.push(nullptr); + size_t row_nr = m_row_start.size() - 1; + for (auto&& [i, v]: _entries.enumerate()) { + if (!v) + continue; + prev = appendToRow(row_nr, i, move(v)); + + prev = curr; + } +} + +void SparseMatrix::remove(Entry& _e) +{ + if (_e.prev_in_row) + _e.prev_in_row->next_in_row = _e.next_in_row; + else + m_row_start[_e.row] = _e.next_in_row; + if (_e.next_in_row) + _e.next_in_row->prev_in_row = _e.prev_in_row; + else + m_row_end[_e.row] = _e.prev_in_row; + if (_e.prev_in_col) + _e.prev_in_col->next_in_col = _e.next_in_col; + else + m_col_start[_e.col] = _e.next_in_col; + if (_e.next_in_col) + _e.next_in_col->prev_in_col = _e.prev_in_col; + else + m_col_end[_e.col] = _e.prev_in_col; +} + +void SparseMatrix::appendToRow(size_t _row, size_t _column, rational _value) +{ + m_elements.emplace(make_unique( + move(_value), + _row, + _column, + m_row_end[_row], + nullptr, + m_column_end[i], + nullptr + )); + Entry const* e = m_elements.back().get(); + if (m_row_end[_row]) + m_row_end[_row]->next_in_row = e; + if (!m_row_start[_row]) + m_row_start[_row] = e; + if (i >= m_col_start.size()) + m_col_start.resize(i + 1); + if (!m_col_start[i]) + m_col_start[i] = e; + if (i >= m_col_end.size()) + m_col_end.resize(i + 1); + if (!m_col_end[i]) + m_col_end[i] = e; +} diff --git a/libsolutil/LinearExpression.h b/libsolutil/LinearExpression.h index 0d5045421..75719e1ff 100644 --- a/libsolutil/LinearExpression.h +++ b/libsolutil/LinearExpression.h @@ -185,4 +185,41 @@ private: std::vector factors; }; +class SparseMatrix +{ +public: + struct Entry + { + rational value; + // TOOD make it 32 bit as well + size_t row; + size_t col; + // TODO maybe better to use 32-bit indices instead of 64-bit pointers + Entry* prev_in_row; + Entry* next_in_row; + Entry* prev_in_col; + Entry* next_in_col; + }; + /// @returns (i, v) for all non-zero v in the column _column + void enumerateColumn(size_t _column) const; + /// @returns (i, v) for all non-zero v in the row _row + void enumerateRow(size_t _row) const; + void multiplyRowByFactor(size_t _row, rational const& _factor); + void addMultipleOfRow(size_t _sourceRow, size_t _targetRow, rational const& _factor); + rational entry(size_t _row, size_t _column) const; + + void appendRow(LinearExpression const& _entries); + +private: + + void remove(Entry& _entry); + Entry* appendToRow(size_t _row, size_t _column, rational _value); + + std::vector> m_elements; + std::vector m_row_start; + std::vector m_col_start; + std::vector m_row_end; + std::vector m_col_end; +}; + }