cosmos-sdk/store/storage/sqlite/iterator.go
Marko 3166ebbf91
store: keys as bytes (#19775)
Co-authored-by: sontrinh16 <trinhleson2000@gmail.com>
Co-authored-by: cool-developer <51834436+cool-develope@users.noreply.github.com>
Co-authored-by: Aleksandr Bezobchuk <alexanderbez@users.noreply.github.com>
Co-authored-by: Matt Kocubinski <mkocubinski@gmail.com>
2024-03-20 10:48:16 +00:00

184 lines
3.6 KiB
Go

package sqlite
import (
"bytes"
"database/sql"
"fmt"
"slices"
"strings"
corestore "cosmossdk.io/core/store"
)
// Compile-time check that *iterator satisfies the core store iterator contract.
var _ corestore.Iterator = (*iterator)(nil)

// iterator walks the result set of a prepared SQLite query. Close releases the
// prepared statement it holds.
type iterator struct {
	// statement is the prepared query; closed by Close.
	statement *sql.Stmt
	// rows is the live result set produced by statement.
	rows *sql.Rows

	// key and val hold the most recently scanned row.
	key []byte
	val []byte

	// start and end are the requested domain, reported back by Domain.
	start []byte
	end   []byte

	// valid tracks whether the iterator is positioned on a usable row.
	valid bool
	// err records a non-driver error (e.g. empty result, scan failure).
	err error
}
func newIterator(db *Database, storeKey []byte, targetVersion uint64, start, end []byte, reverse bool) (*iterator, error) {
if targetVersion < db.earliestVersion {
return &iterator{
start: start,
end: end,
valid: false,
}, nil
}
var (
keyClause = []string{"store_key = ?", "version <= ?"}
queryArgs []any
)
switch {
case len(start) > 0 && len(end) > 0:
keyClause = append(keyClause, "key >= ?", "key < ?")
queryArgs = []any{storeKey, targetVersion, start, end, targetVersion}
case len(start) > 0 && len(end) == 0:
keyClause = append(keyClause, "key >= ?")
queryArgs = []any{storeKey, targetVersion, start, targetVersion}
case len(start) == 0 && len(end) > 0:
keyClause = append(keyClause, "key < ?")
queryArgs = []any{storeKey, targetVersion, end, targetVersion}
default:
queryArgs = []any{storeKey, targetVersion, targetVersion}
}
orderBy := "ASC"
if reverse {
orderBy = "DESC"
}
// Note, this is not susceptible to SQL injection because placeholders are used
// for parts of the query outside the store's direct control.
stmt, err := db.storage.Prepare(fmt.Sprintf(`
SELECT x.key, x.value
FROM (
SELECT key, value, version, tombstone,
row_number() OVER (PARTITION BY key ORDER BY version DESC) AS _rn
FROM state_storage WHERE %s
) x
WHERE x._rn = 1 AND (x.tombstone = 0 OR x.tombstone > ?) ORDER BY x.key %s;
`, strings.Join(keyClause, " AND "), orderBy))
if err != nil {
return nil, fmt.Errorf("failed to prepare SQL statement: %w", err)
}
rows, err := stmt.Query(queryArgs...)
if err != nil {
_ = stmt.Close()
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
}
itr := &iterator{
statement: stmt,
rows: rows,
start: start,
end: end,
valid: rows.Next(),
}
if !itr.valid {
itr.err = fmt.Errorf("iterator invalid: %w", sql.ErrNoRows)
return itr, nil
}
// read the first row
itr.parseRow()
if !itr.valid {
return itr, nil
}
return itr, nil
}
// Close releases the result set and the prepared statement, then marks the
// iterator invalid. It returns the first error encountered. Calling it again
// is a no-op that returns nil.
func (itr *iterator) Close() (err error) {
	// Close the rows before the statement. An open *sql.Rows keeps its
	// connection checked out, and database/sql postpones the statement's real
	// close until those rows are released. Closing only the statement would
	// leak the connection whenever iteration stops before exhaustion.
	if itr.rows != nil {
		err = itr.rows.Close()
	}
	if itr.statement != nil {
		if stmtErr := itr.statement.Close(); err == nil {
			err = stmtErr
		}
	}

	itr.valid = false
	itr.statement = nil
	itr.rows = nil

	return err
}
// Domain reports the [start, end) bounds the iterator was created with.
// The returned slices are shared with the iterator; callers must treat them
// as read-only.
func (itr *iterator) Domain() ([]byte, []byte) {
	lo, hi := itr.start, itr.end
	return lo, hi
}
// Key returns a copy of the current row's key. It panics if the iterator is
// not valid.
func (itr *iterator) Key() []byte {
	itr.assertIsValid()
	out := slices.Clone(itr.key)
	return out
}
// Value returns a copy of the current row's value. It panics if the iterator
// is not valid.
func (itr *iterator) Value() []byte {
	itr.assertIsValid()
	out := slices.Clone(itr.val)
	return out
}
// Valid reports whether the iterator is positioned on a row inside its
// domain. It returns false, and records that on the iterator, when the
// iterator was already invalid, has no result set, the result set reported an
// error, or the current key is at or past end.
func (itr *iterator) Valid() bool {
	if !itr.valid || itr.rows == nil || itr.rows.Err() != nil {
		itr.valid = false
		return false
	}

	// The query already bounds keys by end; this is a defensive re-check.
	// Use len() so an empty, non-nil end means "unbounded", matching how
	// newIterator builds the query. Compare against itr.key directly to avoid
	// the copy that Key() would allocate.
	if len(itr.end) > 0 && bytes.Compare(itr.end, itr.key) <= 0 {
		itr.valid = false
		return false
	}

	return true
}
// Next advances to the following row and loads its key and value. When the
// result set is exhausted, or the iterator has no result set (it was closed,
// or newIterator returned early), the iterator becomes invalid instead of
// dereferencing a nil *sql.Rows.
func (itr *iterator) Next() {
	if itr.rows != nil && itr.rows.Next() {
		itr.parseRow()
		return
	}
	itr.valid = false
}
// Error returns the result set's error if there is one, otherwise any error
// recorded on the iterator (e.g. an empty result or a scan failure). It is
// safe to call on an iterator that has no result set, including after Close.
func (itr *iterator) Error() error {
	if itr.rows != nil {
		if err := itr.rows.Err(); err != nil {
			return err
		}
	}
	return itr.err
}
// parseRow scans the current row's key and value columns into the iterator.
// On failure it records the wrapped scan error and marks the iterator invalid.
func (itr *iterator) parseRow() {
	var key, value []byte
	if err := itr.rows.Scan(&key, &value); err != nil {
		// %w keeps the driver error inspectable via errors.Is/As.
		itr.err = fmt.Errorf("failed to scan row: %w", err)
		itr.valid = false
		return
	}

	itr.key = key
	itr.val = value
}
// assertIsValid panics when the iterator is not positioned on a usable row.
// Key and Value call it before returning row data.
func (itr *iterator) assertIsValid() {
	if itr.valid {
		return
	}
	panic("iterator is invalid")
}