diff --git a/collections/iter.go b/collections/iter.go new file mode 100644 index 0000000000..ac48e7a398 --- /dev/null +++ b/collections/iter.go @@ -0,0 +1,333 @@ +package collections + +import ( + "context" + "errors" + "fmt" + storetypes "github.com/cosmos/cosmos-sdk/store/types" +) + +// ErrInvalidIterator is returned when an Iterate call resulted in an invalid iterator. +var ErrInvalidIterator = errors.New("collections: invalid iterator") + +// Order defines the key order. +type Order uint8 + +const ( + // OrderAscending instructs the Iterator to provide keys from the smallest to the greatest. + OrderAscending Order = 0 + // OrderDescending instructs the Iterator to provide keys from the greatest to the smallest. + OrderDescending Order = 1 +) + +// BoundInclusive creates a Bound of the provided key K +// which is inclusive. Meaning, if it is used as Ranger.RangeValues start, +// the provided key will be included if it exists in the Iterator range. +func BoundInclusive[K any](key K) *Bound[K] { + return &Bound[K]{ + value: key, + inclusive: true, + } +} + +// BoundExclusive creates a Bound of the provided key K +// which is exclusive. Meaning, if it is used as Ranger.RangeValues start, +// the provided key will be excluded if it exists in the Iterator range. +func BoundExclusive[K any](key K) *Bound[K] { + return &Bound[K]{ + value: key, + inclusive: false, + } +} + +// Bound defines key bounds for Start and Ends of iterator ranges. +type Bound[K any] struct { + value K + inclusive bool +} + +// Ranger defines a generic interface that provides a range of keys. +type Ranger[K any] interface { + // RangeValues is defined by Ranger implementers. + // It provides instructions to generate an Iterator instance. + // If prefix is not nil, then the Iterator will return only the keys which start + // with the given prefix. + // If start is not nil, then the Iterator will return only keys which are greater than the provided start + // or greater equal depending on the bound is inclusive or exclusive. + // If end is not nil, then the Iterator will return only keys which are smaller than the provided end + // or smaller equal depending on the bound is inclusive or exclusive. + RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error) +} + +// Range is a Ranger implementer. +type Range[K any] struct { + prefix *K + start *Bound[K] + end *Bound[K] + order Order +} + +// Prefix sets a fixed prefix for the key range. +func (r *Range[K]) Prefix(key K) *Range[K] { + r.prefix = &key + return r +} + +// StartInclusive makes the range contain only keys which are bigger or equal to the provided start K. +func (r *Range[K]) StartInclusive(start K) *Range[K] { + r.start = BoundInclusive(start) + return r +} + +// StartExclusive makes the range contain only keys which are bigger to the provided start K. +func (r *Range[K]) StartExclusive(start K) *Range[K] { + r.start = BoundExclusive(start) + return r +} + +// EndInclusive makes the range contain only keys which are smaller or equal to the provided end K. +func (r *Range[K]) EndInclusive(end K) *Range[K] { + r.end = BoundInclusive(end) + return r +} + +// EndExclusive makes the range contain only keys which are smaller to the provided end K. +func (r *Range[K]) EndExclusive(end K) *Range[K] { + r.end = BoundExclusive(end) + return r +} + +func (r *Range[K]) Descending() *Range[K] { + r.order = OrderDescending + return r +} + +// test sentinel error +var errRange = errors.New("collections: range error") +var errOrder = errors.New("collections: invalid order") + +func (r *Range[K]) RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error) { + if r.prefix != nil && (r.end != nil || r.start != nil) { + return nil, nil, nil, order, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange) + } + return r.prefix, r.start, r.end, r.order, nil +} + +// iteratorFromRanger generates an Iterator instance, with the proper prefixing and ranging. +// a nil Ranger can be seen as an ascending iteration over all the possible keys. +func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K]) (iter Iterator[K, V], err error) { + var ( + prefix *K + start *Bound[K] + end *Bound[K] + order = OrderAscending + ) + // if Ranger is specified then we override the defaults + if r != nil { + prefix, start, end, order, err = r.RangeValues() + if err != nil { + return iter, err + } + } + if prefix != nil && (start != nil || end != nil) { + return iter, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange) + } + + // compute start and end bytes + var startBytes, endBytes []byte + if prefix != nil { + startBytes, endBytes, err = prefixStartEndBytes(m, *prefix) + if err != nil { + return iter, err + } + } else { + startBytes, endBytes, err = rangeStartEndBytes(m, start, end) + if err != nil { + return iter, err + } + } + + // get store + store, err := m.getStore(ctx) + if err != nil { + return iter, err + } + + // create iter + var storeIter storetypes.Iterator + switch order { + case OrderAscending: + storeIter = store.Iterator(startBytes, endBytes) + case OrderDescending: + storeIter = store.ReverseIterator(startBytes, endBytes) + default: + return iter, fmt.Errorf("%w: %d", errOrder, order) + } + + // check if valid + if !storeIter.Valid() { + return iter, ErrInvalidIterator + } + + // all good + iter.kc = m.kc + iter.vc = m.vc + iter.prefixLength = len(m.prefix) + iter.iter = storeIter + return iter, nil +} + +// rangeStartEndBytes computes a range's start and end bytes to be passed to the store's iterator. +func rangeStartEndBytes[K, V any](m Map[K, V], start, end *Bound[K]) (startBytes, endBytes []byte, err error) { + startBytes = m.prefix + if start != nil { + startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, start.value) + if err != nil { + return startBytes, endBytes, err + } + // the start of iterators is by default inclusive, + // in order to make it exclusive we extend the start + // by one single byte. + if !start.inclusive { + startBytes = extendOneByte(startBytes) + } + } + if end != nil { + endBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, end.value) + if err != nil { + return startBytes, endBytes, err + } + // the end of iterators is by default exclusive + // in order to make it inclusive we extend the end + // by one single byte. + if end.inclusive { + endBytes = extendOneByte(endBytes) + } + } else { + // if end is not specified then we simply are + // inclusive up to the last key of the Prefix + // of the collection. + endBytes = storetypes.PrefixEndBytes(m.prefix) + } + + return startBytes, endBytes, nil +} + +// prefixStartEndBytes returns the start and end bytes to be provided to the store's iterator, considering we're prefixing +// over a specific key. +func prefixStartEndBytes[K, V any](m Map[K, V], prefix K) (startBytes, endBytes []byte, err error) { + startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, prefix) + if err != nil { + return + } + return startBytes, storetypes.PrefixEndBytes(startBytes), nil +} + +// Iterator defines a generic wrapper around an sdk.Iterator. +// This iterator provides automatic key and value encoding, +// it assumes all the keys and values contained within the sdk.Iterator +// range are the same. +type Iterator[K, V any] struct { + kc KeyCodec[K] + vc ValueCodec[V] + + iter storetypes.Iterator + + prefixLength int // prefixLength refers to the bytes provided by Prefix.Bytes, not Ranger.RangeValues() prefix. +} + +// Value returns the current iterator value bytes decoded. +func (i Iterator[K, V]) Value() (V, error) { + return i.vc.Decode(i.iter.Value()) +} + +// Key returns the current sdk.Iterator decoded key. +func (i Iterator[K, V]) Key() (K, error) { + bytesKey := i.iter.Key()[i.prefixLength:] // strip prefix namespace + + read, key, err := i.kc.Decode(bytesKey) + if err != nil { + var k K + return k, err + } + if read != len(bytesKey) { + var k K + return k, fmt.Errorf("%w: key decoder didn't fully consume the key: %T %x %d", ErrEncoding, i.kc, bytesKey, read) + } + return key, nil +} + +// Values fully consumes the iterator and returns all the decoded values contained within the range. +func (i Iterator[K, V]) Values() ([]V, error) { + defer i.Close() + + var values []V + for ; i.iter.Valid(); i.iter.Next() { + value, err := i.Value() + if err != nil { + return nil, err + } + values = append(values, value) + } + return values, nil +} + +// Keys fully consumes the iterator and returns all the decoded keys contained within the range. +func (i Iterator[K, V]) Keys() ([]K, error) { + defer i.Close() + + var keys []K + for ; i.iter.Valid(); i.iter.Next() { + key, err := i.Key() + if err != nil { + return nil, err + } + keys = append(keys, key) + } + return keys, nil +} + +// KeyValue returns the current key and value decoded. +func (i Iterator[K, V]) KeyValue() (kv KeyValue[K, V], err error) { + key, err := i.Key() + if err != nil { + return kv, err + } + value, err := i.Value() + if err != nil { + return kv, err + } + kv.Key = key + kv.Value = value + return kv, nil +} + +// KeyValues fully consumes the iterator and returns the list of key and values within the iterator range. +func (i Iterator[K, V]) KeyValues() ([]KeyValue[K, V], error) { + defer i.Close() + + var kvs []KeyValue[K, V] + for ; i.iter.Valid(); i.iter.Next() { + kv, err := i.KeyValue() + if err != nil { + return nil, err + } + kvs = append(kvs, kv) + } + + return kvs, nil +} + +func (i Iterator[K, V]) Close() error { return i.iter.Close() } +func (i Iterator[K, V]) Next() { i.iter.Next() } +func (i Iterator[K, V]) Valid() bool { return i.iter.Valid() } + +// KeyValue represent a Key and Value pair of an iteration. +type KeyValue[K, V any] struct { + Key K + Value V +} + +func extendOneByte(b []byte) []byte { + return append(b, 0) +} diff --git a/collections/iter_test.go b/collections/iter_test.go new file mode 100644 index 0000000000..9dfc3e3f0c --- /dev/null +++ b/collections/iter_test.go @@ -0,0 +1,201 @@ +package collections + +import ( + "fmt" + "github.com/stretchr/testify/require" + "testing" +) + +func TestIteratorBasic(t *testing.T) { + sk, ctx := deps() + m := NewMap(sk, NewPrefix("some super amazing prefix"), StringKey, Uint64Value) + + for i := uint64(1); i <= 2; i++ { + require.NoError(t, m.Set(ctx, fmt.Sprintf("%d", i), i)) + } + + iter, err := m.Iterate(ctx, nil) + require.NoError(t, err) + defer iter.Close() + + // key codec + key, err := iter.Key() + require.NoError(t, err) + require.Equal(t, "1", key) + + // value codec + value, err := iter.Value() + require.NoError(t, err) + require.Equal(t, uint64(1), value) + + // assert expected prefixing on iter + require.Equal(t, m.prefix, iter.iter.Key()[:len(m.prefix)]) + + // advance iter + iter.Next() + require.True(t, iter.Valid()) + + // key 2 + key, err = iter.Key() + require.NoError(t, err) + require.Equal(t, "2", key) + + // value 2 + value, err = iter.Value() + require.NoError(t, err) + require.Equal(t, uint64(2), value) + + // call next, invalid + iter.Next() + require.False(t, iter.Valid()) + // close no errors + require.NoError(t, iter.Close()) +} + +func TestIteratorKeyValues(t *testing.T) { + sk, ctx := deps() + m := NewMap(sk, NewPrefix("some super amazing prefix"), StringKey, Uint64Value) + + for i := uint64(0); i <= 5; i++ { + require.NoError(t, m.Set(ctx, fmt.Sprintf("%d", i), i)) + } + + // test keys + iter, err := m.Iterate(ctx, nil) + require.NoError(t, err) + keys, err := iter.Keys() + require.NoError(t, err) + + for i, key := range keys { + require.Equal(t, fmt.Sprintf("%d", i), key) + } + require.NoError(t, iter.Close()) + require.False(t, iter.Valid()) + + // test values + iter, err = m.Iterate(ctx, nil) + require.NoError(t, err) + values, err := iter.Values() + require.NoError(t, err) + + for i, value := range values { + require.Equal(t, uint64(i), value) + } + require.NoError(t, iter.Close()) + require.False(t, iter.Valid()) + + // test key value pairings + iter, err = m.Iterate(ctx, nil) + require.NoError(t, err) + kvs, err := iter.KeyValues() + require.NoError(t, err) + + for i, kv := range kvs { + require.Equal(t, fmt.Sprintf("%d", i), kv.Key) + require.Equal(t, uint64(i), kv.Value) + } + require.NoError(t, iter.Close()) + require.False(t, iter.Valid()) +} + +func TestIteratorPrefixing(t *testing.T) { + sk, ctx := deps() + m := NewMap(sk, NewPrefix("cool"), StringKey, Uint64Value) + + require.NoError(t, m.Set(ctx, "A1", 11)) + require.NoError(t, m.Set(ctx, "A2", 12)) + require.NoError(t, m.Set(ctx, "B1", 21)) + + iter, err := m.Iterate(ctx, new(Range[string]).Prefix("A")) + require.NoError(t, err) + keys, err := iter.Keys() + require.NoError(t, err) + require.Equal(t, []string{"A1", "A2"}, keys) +} + +func TestIteratorRanging(t *testing.T) { + sk, ctx := deps() + m := NewMap(sk, NewPrefix("cool"), Uint64Key, Uint64Value) + + for i := uint64(0); i <= 7; i++ { + require.NoError(t, m.Set(ctx, i, i)) + } + + // let's range (1-5]; expected: 2..5 + iter, err := m.Iterate(ctx, (&Range[uint64]{}).StartExclusive(1).EndInclusive(5)) + require.NoError(t, err) + result, err := iter.Keys() + require.NoError(t, err) + require.Equal(t, []uint64{2, 3, 4, 5}, result) + + // let's range [1-5); expected 1..4 + iter, err = m.Iterate(ctx, (&Range[uint64]{}).StartInclusive(1).EndExclusive(5)) + require.NoError(t, err) + result, err = iter.Keys() + require.NoError(t, err) + require.Equal(t, []uint64{1, 2, 3, 4}, result) + + // let's range [1-5) descending; expected 4..1 + iter, err = m.Iterate(ctx, (&Range[uint64]{}).StartInclusive(1).EndExclusive(5).Descending()) + require.NoError(t, err) + result, err = iter.Keys() + require.NoError(t, err) + require.Equal(t, []uint64{4, 3, 2, 1}, result) + + // test iterator invalid + _, err = m.Iterate(ctx, new(Range[uint64]).StartInclusive(10).EndInclusive(1)) + require.ErrorIs(t, err, ErrInvalidIterator) +} + +func TestRange(t *testing.T) { + type test struct { + rng *Range[string] + wantPrefix *string + wantStart *Bound[string] + wantEnd *Bound[string] + wantOrder Order + wantErr error + } + + cases := map[string]test{ + "ok - empty": { + rng: new(Range[string]), + }, + "ok - start exclusive - end exclusive": { + rng: new(Range[string]).StartExclusive("A").EndExclusive("B"), + wantStart: BoundExclusive("A"), + wantEnd: BoundExclusive("B"), + }, + "ok - start inclusive - end inclusive - descending": { + rng: new(Range[string]).StartInclusive("A").EndInclusive("B").Descending(), + wantStart: BoundInclusive("A"), + wantEnd: BoundInclusive("B"), + wantOrder: OrderDescending, + }, + "ok - prefix": { + rng: new(Range[string]).Prefix("A"), + wantPrefix: func() *string { p := "A"; return &p }(), + }, + + "err - prefix and start set": { + rng: new(Range[string]).Prefix("A").StartExclusive("B"), + wantErr: errRange, + }, + "err - prefix and end set": { + rng: new(Range[string]).Prefix("A").StartInclusive("B"), + wantErr: errRange, + }, + } + + for name, tc := range cases { + tc := tc + t.Run(name, func(t *testing.T) { + gotPrefix, gotStart, gotEnd, gotOrder, gotErr := tc.rng.RangeValues() + require.ErrorIs(t, gotErr, tc.wantErr) + require.Equal(t, tc.wantPrefix, gotPrefix) + require.Equal(t, tc.wantStart, gotStart) + require.Equal(t, tc.wantEnd, gotEnd) + require.Equal(t, tc.wantOrder, gotOrder) + }) + } +} diff --git a/collections/keys.go b/collections/keys.go index 5440442d56..41488cd6fa 100644 --- a/collections/keys.go +++ b/collections/keys.go @@ -7,10 +7,14 @@ import ( "strconv" ) -// Uint64Key can be used to encode uint64 keys. -// Encoding is big endian to retain ordering. -var Uint64Key KeyCodec[uint64] = uint64Key{} - +var ( + // Uint64Key can be used to encode uint64 keys. + // Encoding is big endian to retain ordering. + Uint64Key KeyCodec[uint64] = uint64Key{} + // StringKey can be used to encode string keys. + // The encoding just converts the string to bytes. + StringKey KeyCodec[string] = stringKey{} +) var errDecodeKeySize = errors.New("decode error, wrong byte key size") type uint64Key struct{} @@ -20,19 +24,41 @@ func (u uint64Key) Encode(buffer []byte, key uint64) (int, error) { return 8, nil } -func (u uint64Key) Decode(buffer []byte) (int, uint64, error) { +func (uint64Key) Decode(buffer []byte) (int, uint64, error) { if size := len(buffer); size < 8 { return 0, 0, fmt.Errorf("%w: wanted at least 8, got: %d", errDecodeKeySize, size) } return 8, binary.BigEndian.Uint64(buffer), nil } -func (u uint64Key) Size(_ uint64) int { return 8 } +func (uint64Key) Size(_ uint64) int { return 8 } -func (u uint64Key) Stringify(key uint64) string { +func (uint64Key) Stringify(key uint64) string { return strconv.FormatUint(key, 10) } -func (u uint64Key) KeyType() string { +func (uint64Key) KeyType() string { return "uint64" } + +type stringKey struct{} + +func (stringKey) Encode(buffer []byte, key string) (int, error) { + return copy(buffer, key), nil +} + +func (stringKey) Decode(buffer []byte) (int, string, error) { + return len(buffer), string(buffer), nil +} + +func (stringKey) Size(key string) int { + return len(key) +} + +func (stringKey) Stringify(key string) string { + return key +} + +func (stringKey) KeyType() string { + return "string" +} diff --git a/collections/map.go b/collections/map.go index 8d9f1881e7..22280c0fe3 100644 --- a/collections/map.go +++ b/collections/map.go @@ -34,7 +34,8 @@ type Map[K, V any] struct { // Set maps the provided value to the provided key in the store. // Errors with ErrEncoding if key or value encoding fails. func (m Map[K, V]) Set(ctx context.Context, key K, value V) error { - keyBytes, err := m.encodeKey(key) + bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key) + if err != nil { return err } @@ -48,7 +49,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error { if err != nil { return err } - store.Set(keyBytes, valueBytes) + store.Set(bytesKey, valueBytes) return nil } @@ -56,7 +57,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error { // errors with ErrNotFound if the key does not exist, or // with ErrEncoding if the key or value decoding fails. func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) { - keyBytes, err := m.encodeKey(key) + bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key) if err != nil { var v V return v, err @@ -67,7 +68,7 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) { var v V return v, err } - valueBytes := store.Get(keyBytes) + valueBytes := store.Get(bytesKey) if valueBytes == nil { var v V return v, fmt.Errorf("%w: key '%s' of type %s", ErrNotFound, m.kc.Stringify(key), m.vc.ValueType()) @@ -75,7 +76,6 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) { v, err := m.vc.Decode(valueBytes) if err != nil { - var v V return v, fmt.Errorf("%w: value decode: %s", ErrEncoding, err) // TODO: use multi err wrapping in go1.20: https://github.com/golang/go/issues/53435 } return v, nil @@ -84,7 +84,7 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) { // Has reports whether the key is present in storage or not. // Errors with ErrEncoding if key encoding fails. func (m Map[K, V]) Has(ctx context.Context, key K) (bool, error) { - bytesKey, err := m.encodeKey(key) + bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key) if err != nil { return false, err } @@ -99,7 +99,7 @@ func (m Map[K, V]) Has(ctx context.Context, key K) (bool, error) { // Errors with ErrEncoding if key encoding fails. // If the key does not exist then this is a no-op. func (m Map[K, V]) Remove(ctx context.Context, key K) error { - bytesKey, err := m.encodeKey(key) + bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key) if err != nil { return err } @@ -111,6 +111,12 @@ func (m Map[K, V]) Remove(ctx context.Context, key K) error { return nil } +// Iterate provides an Iterator over K and V. It accepts a Ranger interface. +// A nil ranger equals to iterate over all the keys in ascending order. +func (m Map[K, V]) Iterate(ctx context.Context, ranger Ranger[K]) (Iterator[K, V], error) { + return iteratorFromRanger(ctx, m, ranger) +} + func (m Map[K, V]) getStore(ctx context.Context) (storetypes.KVStore, error) { provider, ok := ctx.(StorageProvider) if !ok { @@ -119,14 +125,14 @@ func (m Map[K, V]) getStore(ctx context.Context) (storetypes.KVStore, error) { return provider.KVStore(m.sk), nil } -func (m Map[K, V]) encodeKey(key K) ([]byte, error) { - prefixLen := len(m.prefix) +func encodeKeyWithPrefix[K any](prefix []byte, kc KeyCodec[K], key K) ([]byte, error) { + prefixLen := len(prefix) // preallocate buffer - keyBytes := make([]byte, prefixLen+m.kc.Size(key)) + keyBytes := make([]byte, prefixLen+kc.Size(key)) // put prefix - copy(keyBytes, m.prefix) + copy(keyBytes, prefix) // put key - _, err := m.kc.Encode(keyBytes[prefixLen:], key) + _, err := kc.Encode(keyBytes[prefixLen:], key) if err != nil { return nil, fmt.Errorf("%w: key encode: %s", ErrEncoding, err) // TODO: use multi err wrapping in go1.20: https://github.com/golang/go/issues/53435 } diff --git a/collections/map_test.go b/collections/map_test.go index 4665fa9438..3002f11c5a 100644 --- a/collections/map_test.go +++ b/collections/map_test.go @@ -3,8 +3,6 @@ package collections import ( "testing" - storetypes "github.com/cosmos/cosmos-sdk/store/types" - "github.com/stretchr/testify/require" ) @@ -35,14 +33,12 @@ func TestMap(t *testing.T) { require.False(t, has) } -func TestMap_encodeKey(t *testing.T) { +func Test_encodeKey(t *testing.T) { prefix := "prefix" number := []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0} expectedKey := append([]byte(prefix), number...) - m := NewMap(storetypes.NewKVStoreKey("test"), NewPrefix(prefix), Uint64Key, Uint64Value) - - gotKey, err := m.encodeKey(0) + gotKey, err := encodeKeyWithPrefix(NewPrefix(prefix).Bytes(), Uint64Key, 0) require.NoError(t, err) require.Equal(t, expectedKey, gotKey) }