feat(collections): implement Iteration (#14222)

Co-authored-by: testinginprod <testinginprod@somewhere.idk>
Co-authored-by: Aaron Craelius <aaron@regen.network>
This commit is contained in:
testinginprod 2022-12-14 18:02:40 +01:00 committed by GitHub
parent 2410b846e3
commit 7050eb91f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 588 additions and 26 deletions

333
collections/iter.go Normal file
View File

@ -0,0 +1,333 @@
package collections
import (
"context"
"errors"
"fmt"
storetypes "github.com/cosmos/cosmos-sdk/store/types"
)
// ErrInvalidIterator is returned when an Iterate call resulted in an invalid iterator.
var ErrInvalidIterator = errors.New("collections: invalid iterator")
// Order defines the key order.
type Order uint8
const (
// OrderAscending instructs the Iterator to provide keys from the smallest to the greatest.
OrderAscending Order = 0
// OrderDescending instructs the Iterator to provide keys from the greatest to the smallest.
OrderDescending Order = 1
)
// BoundInclusive creates a Bound of the provided key K
// which is inclusive. Meaning, if it is used as Ranger.RangeValues start,
// the provided key will be included if it exists in the Iterator range.
func BoundInclusive[K any](key K) *Bound[K] {
return &Bound[K]{
value: key,
inclusive: true,
}
}
// BoundExclusive creates a Bound of the provided key K
// which is exclusive. Meaning, if it is used as Ranger.RangeValues start,
// the provided key will be excluded if it exists in the Iterator range.
func BoundExclusive[K any](key K) *Bound[K] {
return &Bound[K]{
value: key,
inclusive: false,
}
}
// Bound defines key bounds for Start and Ends of iterator ranges.
type Bound[K any] struct {
value K
inclusive bool
}
// Ranger defines a generic interface that provides a range of keys.
type Ranger[K any] interface {
// RangeValues is defined by Ranger implementers.
// It provides instructions to generate an Iterator instance.
// If prefix is not nil, then the Iterator will return only the keys which start
// with the given prefix.
// If start is not nil, then the Iterator will return only keys which are greater than the provided start
// or greater equal depending on the bound is inclusive or exclusive.
// If end is not nil, then the Iterator will return only keys which are smaller than the provided end
// or smaller equal depending on the bound is inclusive or exclusive.
RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error)
}
// Range is a Ranger implementer.
type Range[K any] struct {
prefix *K
start *Bound[K]
end *Bound[K]
order Order
}
// Prefix sets a fixed prefix for the key range.
func (r *Range[K]) Prefix(key K) *Range[K] {
r.prefix = &key
return r
}
// StartInclusive makes the range contain only keys which are bigger or equal to the provided start K.
func (r *Range[K]) StartInclusive(start K) *Range[K] {
r.start = BoundInclusive(start)
return r
}
// StartExclusive makes the range contain only keys which are bigger to the provided start K.
func (r *Range[K]) StartExclusive(start K) *Range[K] {
r.start = BoundExclusive(start)
return r
}
// EndInclusive makes the range contain only keys which are smaller or equal to the provided end K.
func (r *Range[K]) EndInclusive(end K) *Range[K] {
r.end = BoundInclusive(end)
return r
}
// EndExclusive makes the range contain only keys which are smaller to the provided end K.
func (r *Range[K]) EndExclusive(end K) *Range[K] {
r.end = BoundExclusive(end)
return r
}
func (r *Range[K]) Descending() *Range[K] {
r.order = OrderDescending
return r
}
// test sentinel error
var errRange = errors.New("collections: range error")
var errOrder = errors.New("collections: invalid order")
func (r *Range[K]) RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error) {
if r.prefix != nil && (r.end != nil || r.start != nil) {
return nil, nil, nil, order, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange)
}
return r.prefix, r.start, r.end, r.order, nil
}
// iteratorFromRanger generates an Iterator instance, with the proper prefixing and ranging.
// a nil Ranger can be seen as an ascending iteration over all the possible keys.
func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K]) (iter Iterator[K, V], err error) {
var (
prefix *K
start *Bound[K]
end *Bound[K]
order = OrderAscending
)
// if Ranger is specified then we override the defaults
if r != nil {
prefix, start, end, order, err = r.RangeValues()
if err != nil {
return iter, err
}
}
if prefix != nil && (start != nil || end != nil) {
return iter, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange)
}
// compute start and end bytes
var startBytes, endBytes []byte
if prefix != nil {
startBytes, endBytes, err = prefixStartEndBytes(m, *prefix)
if err != nil {
return iter, err
}
} else {
startBytes, endBytes, err = rangeStartEndBytes(m, start, end)
if err != nil {
return iter, err
}
}
// get store
store, err := m.getStore(ctx)
if err != nil {
return iter, err
}
// create iter
var storeIter storetypes.Iterator
switch order {
case OrderAscending:
storeIter = store.Iterator(startBytes, endBytes)
case OrderDescending:
storeIter = store.ReverseIterator(startBytes, endBytes)
default:
return iter, fmt.Errorf("%w: %d", errOrder, order)
}
// check if valid
if !storeIter.Valid() {
return iter, ErrInvalidIterator
}
// all good
iter.kc = m.kc
iter.vc = m.vc
iter.prefixLength = len(m.prefix)
iter.iter = storeIter
return iter, nil
}
// rangeStartEndBytes computes a range's start and end bytes to be passed to the store's iterator.
func rangeStartEndBytes[K, V any](m Map[K, V], start, end *Bound[K]) (startBytes, endBytes []byte, err error) {
startBytes = m.prefix
if start != nil {
startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, start.value)
if err != nil {
return startBytes, endBytes, err
}
// the start of iterators is by default inclusive,
// in order to make it exclusive we extend the start
// by one single byte.
if !start.inclusive {
startBytes = extendOneByte(startBytes)
}
}
if end != nil {
endBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, end.value)
if err != nil {
return startBytes, endBytes, err
}
// the end of iterators is by default exclusive
// in order to make it inclusive we extend the end
// by one single byte.
if end.inclusive {
endBytes = extendOneByte(endBytes)
}
} else {
// if end is not specified then we simply are
// inclusive up to the last key of the Prefix
// of the collection.
endBytes = storetypes.PrefixEndBytes(m.prefix)
}
return startBytes, endBytes, nil
}
// prefixStartEndBytes returns the start and end bytes to be provided to the store's iterator, considering we're prefixing
// over a specific key.
func prefixStartEndBytes[K, V any](m Map[K, V], prefix K) (startBytes, endBytes []byte, err error) {
startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, prefix)
if err != nil {
return
}
return startBytes, storetypes.PrefixEndBytes(startBytes), nil
}
// Iterator defines a generic wrapper around an sdk.Iterator.
// This iterator provides automatic key and value encoding,
// it assumes all the keys and values contained within the sdk.Iterator
// range are the same.
type Iterator[K, V any] struct {
kc KeyCodec[K]
vc ValueCodec[V]
iter storetypes.Iterator
prefixLength int // prefixLength refers to the bytes provided by Prefix.Bytes, not Ranger.RangeValues() prefix.
}
// Value returns the current iterator value bytes decoded.
func (i Iterator[K, V]) Value() (V, error) {
return i.vc.Decode(i.iter.Value())
}
// Key returns the current sdk.Iterator decoded key.
func (i Iterator[K, V]) Key() (K, error) {
bytesKey := i.iter.Key()[i.prefixLength:] // strip prefix namespace
read, key, err := i.kc.Decode(bytesKey)
if err != nil {
var k K
return k, err
}
if read != len(bytesKey) {
var k K
return k, fmt.Errorf("%w: key decoder didn't fully consume the key: %T %x %d", ErrEncoding, i.kc, bytesKey, read)
}
return key, nil
}
// Values fully consumes the iterator and returns all the decoded values contained within the range.
func (i Iterator[K, V]) Values() ([]V, error) {
defer i.Close()
var values []V
for ; i.iter.Valid(); i.iter.Next() {
value, err := i.Value()
if err != nil {
return nil, err
}
values = append(values, value)
}
return values, nil
}
// Keys fully consumes the iterator and returns all the decoded keys contained within the range.
func (i Iterator[K, V]) Keys() ([]K, error) {
defer i.Close()
var keys []K
for ; i.iter.Valid(); i.iter.Next() {
key, err := i.Key()
if err != nil {
return nil, err
}
keys = append(keys, key)
}
return keys, nil
}
// KeyValue returns the current key and value decoded.
func (i Iterator[K, V]) KeyValue() (kv KeyValue[K, V], err error) {
key, err := i.Key()
if err != nil {
return kv, err
}
value, err := i.Value()
if err != nil {
return kv, err
}
kv.Key = key
kv.Value = value
return kv, nil
}
// KeyValues fully consumes the iterator and returns the list of key and values within the iterator range.
func (i Iterator[K, V]) KeyValues() ([]KeyValue[K, V], error) {
defer i.Close()
var kvs []KeyValue[K, V]
for ; i.iter.Valid(); i.iter.Next() {
kv, err := i.KeyValue()
if err != nil {
return nil, err
}
kvs = append(kvs, kv)
}
return kvs, nil
}
func (i Iterator[K, V]) Close() error { return i.iter.Close() }
func (i Iterator[K, V]) Next() { i.iter.Next() }
func (i Iterator[K, V]) Valid() bool { return i.iter.Valid() }
// KeyValue represent a Key and Value pair of an iteration.
type KeyValue[K, V any] struct {
Key K
Value V
}
func extendOneByte(b []byte) []byte {
return append(b, 0)
}

201
collections/iter_test.go Normal file
View File

@ -0,0 +1,201 @@
package collections
import (
"fmt"
"github.com/stretchr/testify/require"
"testing"
)
func TestIteratorBasic(t *testing.T) {
sk, ctx := deps()
m := NewMap(sk, NewPrefix("some super amazing prefix"), StringKey, Uint64Value)
for i := uint64(1); i <= 2; i++ {
require.NoError(t, m.Set(ctx, fmt.Sprintf("%d", i), i))
}
iter, err := m.Iterate(ctx, nil)
require.NoError(t, err)
defer iter.Close()
// key codec
key, err := iter.Key()
require.NoError(t, err)
require.Equal(t, "1", key)
// value codec
value, err := iter.Value()
require.NoError(t, err)
require.Equal(t, uint64(1), value)
// assert expected prefixing on iter
require.Equal(t, m.prefix, iter.iter.Key()[:len(m.prefix)])
// advance iter
iter.Next()
require.True(t, iter.Valid())
// key 2
key, err = iter.Key()
require.NoError(t, err)
require.Equal(t, "2", key)
// value 2
value, err = iter.Value()
require.NoError(t, err)
require.Equal(t, uint64(2), value)
// call next, invalid
iter.Next()
require.False(t, iter.Valid())
// close no errors
require.NoError(t, iter.Close())
}
func TestIteratorKeyValues(t *testing.T) {
sk, ctx := deps()
m := NewMap(sk, NewPrefix("some super amazing prefix"), StringKey, Uint64Value)
for i := uint64(0); i <= 5; i++ {
require.NoError(t, m.Set(ctx, fmt.Sprintf("%d", i), i))
}
// test keys
iter, err := m.Iterate(ctx, nil)
require.NoError(t, err)
keys, err := iter.Keys()
require.NoError(t, err)
for i, key := range keys {
require.Equal(t, fmt.Sprintf("%d", i), key)
}
require.NoError(t, iter.Close())
require.False(t, iter.Valid())
// test values
iter, err = m.Iterate(ctx, nil)
require.NoError(t, err)
values, err := iter.Values()
require.NoError(t, err)
for i, value := range values {
require.Equal(t, uint64(i), value)
}
require.NoError(t, iter.Close())
require.False(t, iter.Valid())
// test key value pairings
iter, err = m.Iterate(ctx, nil)
require.NoError(t, err)
kvs, err := iter.KeyValues()
require.NoError(t, err)
for i, kv := range kvs {
require.Equal(t, fmt.Sprintf("%d", i), kv.Key)
require.Equal(t, uint64(i), kv.Value)
}
require.NoError(t, iter.Close())
require.False(t, iter.Valid())
}
func TestIteratorPrefixing(t *testing.T) {
sk, ctx := deps()
m := NewMap(sk, NewPrefix("cool"), StringKey, Uint64Value)
require.NoError(t, m.Set(ctx, "A1", 11))
require.NoError(t, m.Set(ctx, "A2", 12))
require.NoError(t, m.Set(ctx, "B1", 21))
iter, err := m.Iterate(ctx, new(Range[string]).Prefix("A"))
require.NoError(t, err)
keys, err := iter.Keys()
require.NoError(t, err)
require.Equal(t, []string{"A1", "A2"}, keys)
}
func TestIteratorRanging(t *testing.T) {
sk, ctx := deps()
m := NewMap(sk, NewPrefix("cool"), Uint64Key, Uint64Value)
for i := uint64(0); i <= 7; i++ {
require.NoError(t, m.Set(ctx, i, i))
}
// let's range (1-5]; expected: 2..5
iter, err := m.Iterate(ctx, (&Range[uint64]{}).StartExclusive(1).EndInclusive(5))
require.NoError(t, err)
result, err := iter.Keys()
require.NoError(t, err)
require.Equal(t, []uint64{2, 3, 4, 5}, result)
// let's range [1-5); expected 1..4
iter, err = m.Iterate(ctx, (&Range[uint64]{}).StartInclusive(1).EndExclusive(5))
require.NoError(t, err)
result, err = iter.Keys()
require.NoError(t, err)
require.Equal(t, []uint64{1, 2, 3, 4}, result)
// let's range [1-5) descending; expected 4..1
iter, err = m.Iterate(ctx, (&Range[uint64]{}).StartInclusive(1).EndExclusive(5).Descending())
require.NoError(t, err)
result, err = iter.Keys()
require.NoError(t, err)
require.Equal(t, []uint64{4, 3, 2, 1}, result)
// test iterator invalid
_, err = m.Iterate(ctx, new(Range[uint64]).StartInclusive(10).EndInclusive(1))
require.ErrorIs(t, err, ErrInvalidIterator)
}
func TestRange(t *testing.T) {
type test struct {
rng *Range[string]
wantPrefix *string
wantStart *Bound[string]
wantEnd *Bound[string]
wantOrder Order
wantErr error
}
cases := map[string]test{
"ok - empty": {
rng: new(Range[string]),
},
"ok - start exclusive - end exclusive": {
rng: new(Range[string]).StartExclusive("A").EndExclusive("B"),
wantStart: BoundExclusive("A"),
wantEnd: BoundExclusive("B"),
},
"ok - start inclusive - end inclusive - descending": {
rng: new(Range[string]).StartInclusive("A").EndInclusive("B").Descending(),
wantStart: BoundInclusive("A"),
wantEnd: BoundInclusive("B"),
wantOrder: OrderDescending,
},
"ok - prefix": {
rng: new(Range[string]).Prefix("A"),
wantPrefix: func() *string { p := "A"; return &p }(),
},
"err - prefix and start set": {
rng: new(Range[string]).Prefix("A").StartExclusive("B"),
wantErr: errRange,
},
"err - prefix and end set": {
rng: new(Range[string]).Prefix("A").StartInclusive("B"),
wantErr: errRange,
},
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
gotPrefix, gotStart, gotEnd, gotOrder, gotErr := tc.rng.RangeValues()
require.ErrorIs(t, gotErr, tc.wantErr)
require.Equal(t, tc.wantPrefix, gotPrefix)
require.Equal(t, tc.wantStart, gotStart)
require.Equal(t, tc.wantEnd, gotEnd)
require.Equal(t, tc.wantOrder, gotOrder)
})
}
}

View File

@ -7,10 +7,14 @@ import (
"strconv"
)
// Uint64Key can be used to encode uint64 keys.
// Encoding is big endian to retain ordering.
var Uint64Key KeyCodec[uint64] = uint64Key{}
var (
// Uint64Key can be used to encode uint64 keys.
// Encoding is big endian to retain ordering.
Uint64Key KeyCodec[uint64] = uint64Key{}
// StringKey can be used to encode string keys.
// The encoding just converts the string to bytes.
StringKey KeyCodec[string] = stringKey{}
)
var errDecodeKeySize = errors.New("decode error, wrong byte key size")
type uint64Key struct{}
@ -20,19 +24,41 @@ func (u uint64Key) Encode(buffer []byte, key uint64) (int, error) {
return 8, nil
}
func (u uint64Key) Decode(buffer []byte) (int, uint64, error) {
func (uint64Key) Decode(buffer []byte) (int, uint64, error) {
if size := len(buffer); size < 8 {
return 0, 0, fmt.Errorf("%w: wanted at least 8, got: %d", errDecodeKeySize, size)
}
return 8, binary.BigEndian.Uint64(buffer), nil
}
func (u uint64Key) Size(_ uint64) int { return 8 }
func (uint64Key) Size(_ uint64) int { return 8 }
func (u uint64Key) Stringify(key uint64) string {
func (uint64Key) Stringify(key uint64) string {
return strconv.FormatUint(key, 10)
}
func (u uint64Key) KeyType() string {
func (uint64Key) KeyType() string {
return "uint64"
}
type stringKey struct{}
func (stringKey) Encode(buffer []byte, key string) (int, error) {
return copy(buffer, key), nil
}
func (stringKey) Decode(buffer []byte) (int, string, error) {
return len(buffer), string(buffer), nil
}
func (stringKey) Size(key string) int {
return len(key)
}
func (stringKey) Stringify(key string) string {
return key
}
func (stringKey) KeyType() string {
return "string"
}

View File

@ -34,7 +34,8 @@ type Map[K, V any] struct {
// Set maps the provided value to the provided key in the store.
// Errors with ErrEncoding if key or value encoding fails.
func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
keyBytes, err := m.encodeKey(key)
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
if err != nil {
return err
}
@ -48,7 +49,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
if err != nil {
return err
}
store.Set(keyBytes, valueBytes)
store.Set(bytesKey, valueBytes)
return nil
}
@ -56,7 +57,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
// errors with ErrNotFound if the key does not exist, or
// with ErrEncoding if the key or value decoding fails.
func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
keyBytes, err := m.encodeKey(key)
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
if err != nil {
var v V
return v, err
@ -67,7 +68,7 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
var v V
return v, err
}
valueBytes := store.Get(keyBytes)
valueBytes := store.Get(bytesKey)
if valueBytes == nil {
var v V
return v, fmt.Errorf("%w: key '%s' of type %s", ErrNotFound, m.kc.Stringify(key), m.vc.ValueType())
@ -75,7 +76,6 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
v, err := m.vc.Decode(valueBytes)
if err != nil {
var v V
return v, fmt.Errorf("%w: value decode: %s", ErrEncoding, err) // TODO: use multi err wrapping in go1.20: https://github.com/golang/go/issues/53435
}
return v, nil
@ -84,7 +84,7 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
// Has reports whether the key is present in storage or not.
// Errors with ErrEncoding if key encoding fails.
func (m Map[K, V]) Has(ctx context.Context, key K) (bool, error) {
bytesKey, err := m.encodeKey(key)
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
if err != nil {
return false, err
}
@ -99,7 +99,7 @@ func (m Map[K, V]) Has(ctx context.Context, key K) (bool, error) {
// Errors with ErrEncoding if key encoding fails.
// If the key does not exist then this is a no-op.
func (m Map[K, V]) Remove(ctx context.Context, key K) error {
bytesKey, err := m.encodeKey(key)
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
if err != nil {
return err
}
@ -111,6 +111,12 @@ func (m Map[K, V]) Remove(ctx context.Context, key K) error {
return nil
}
// Iterate provides an Iterator over K and V. It accepts a Ranger interface.
// A nil ranger equals to iterate over all the keys in ascending order.
func (m Map[K, V]) Iterate(ctx context.Context, ranger Ranger[K]) (Iterator[K, V], error) {
return iteratorFromRanger(ctx, m, ranger)
}
func (m Map[K, V]) getStore(ctx context.Context) (storetypes.KVStore, error) {
provider, ok := ctx.(StorageProvider)
if !ok {
@ -119,14 +125,14 @@ func (m Map[K, V]) getStore(ctx context.Context) (storetypes.KVStore, error) {
return provider.KVStore(m.sk), nil
}
func (m Map[K, V]) encodeKey(key K) ([]byte, error) {
prefixLen := len(m.prefix)
func encodeKeyWithPrefix[K any](prefix []byte, kc KeyCodec[K], key K) ([]byte, error) {
prefixLen := len(prefix)
// preallocate buffer
keyBytes := make([]byte, prefixLen+m.kc.Size(key))
keyBytes := make([]byte, prefixLen+kc.Size(key))
// put prefix
copy(keyBytes, m.prefix)
copy(keyBytes, prefix)
// put key
_, err := m.kc.Encode(keyBytes[prefixLen:], key)
_, err := kc.Encode(keyBytes[prefixLen:], key)
if err != nil {
return nil, fmt.Errorf("%w: key encode: %s", ErrEncoding, err) // TODO: use multi err wrapping in go1.20: https://github.com/golang/go/issues/53435
}

View File

@ -3,8 +3,6 @@ package collections
import (
"testing"
storetypes "github.com/cosmos/cosmos-sdk/store/types"
"github.com/stretchr/testify/require"
)
@ -35,14 +33,12 @@ func TestMap(t *testing.T) {
require.False(t, has)
}
func TestMap_encodeKey(t *testing.T) {
func Test_encodeKey(t *testing.T) {
prefix := "prefix"
number := []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}
expectedKey := append([]byte(prefix), number...)
m := NewMap(storetypes.NewKVStoreKey("test"), NewPrefix(prefix), Uint64Key, Uint64Value)
gotKey, err := m.encodeKey(0)
gotKey, err := encodeKeyWithPrefix(NewPrefix(prefix).Bytes(), Uint64Key, 0)
require.NoError(t, err)
require.Equal(t, expectedKey, gotKey)
}