feat(collections): implement Iteration (#14222)
Co-authored-by: testinginprod <testinginprod@somewhere.idk> Co-authored-by: Aaron Craelius <aaron@regen.network>
This commit is contained in:
parent
2410b846e3
commit
7050eb91f2
333
collections/iter.go
Normal file
333
collections/iter.go
Normal file
@ -0,0 +1,333 @@
|
||||
package collections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
storetypes "github.com/cosmos/cosmos-sdk/store/types"
|
||||
)
|
||||
|
||||
// ErrInvalidIterator is returned when an Iterate call resulted in an invalid iterator.
|
||||
var ErrInvalidIterator = errors.New("collections: invalid iterator")
|
||||
|
||||
// Order defines the key order.
|
||||
type Order uint8
|
||||
|
||||
const (
|
||||
// OrderAscending instructs the Iterator to provide keys from the smallest to the greatest.
|
||||
OrderAscending Order = 0
|
||||
// OrderDescending instructs the Iterator to provide keys from the greatest to the smallest.
|
||||
OrderDescending Order = 1
|
||||
)
|
||||
|
||||
// BoundInclusive creates a Bound of the provided key K
|
||||
// which is inclusive. Meaning, if it is used as Ranger.RangeValues start,
|
||||
// the provided key will be included if it exists in the Iterator range.
|
||||
func BoundInclusive[K any](key K) *Bound[K] {
|
||||
return &Bound[K]{
|
||||
value: key,
|
||||
inclusive: true,
|
||||
}
|
||||
}
|
||||
|
||||
// BoundExclusive creates a Bound of the provided key K
|
||||
// which is exclusive. Meaning, if it is used as Ranger.RangeValues start,
|
||||
// the provided key will be excluded if it exists in the Iterator range.
|
||||
func BoundExclusive[K any](key K) *Bound[K] {
|
||||
return &Bound[K]{
|
||||
value: key,
|
||||
inclusive: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Bound defines key bounds for Start and Ends of iterator ranges.
|
||||
type Bound[K any] struct {
|
||||
value K
|
||||
inclusive bool
|
||||
}
|
||||
|
||||
// Ranger defines a generic interface that provides a range of keys.
|
||||
type Ranger[K any] interface {
|
||||
// RangeValues is defined by Ranger implementers.
|
||||
// It provides instructions to generate an Iterator instance.
|
||||
// If prefix is not nil, then the Iterator will return only the keys which start
|
||||
// with the given prefix.
|
||||
// If start is not nil, then the Iterator will return only keys which are greater than the provided start
|
||||
// or greater equal depending on the bound is inclusive or exclusive.
|
||||
// If end is not nil, then the Iterator will return only keys which are smaller than the provided end
|
||||
// or smaller equal depending on the bound is inclusive or exclusive.
|
||||
RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error)
|
||||
}
|
||||
|
||||
// Range is a Ranger implementer.
|
||||
type Range[K any] struct {
|
||||
prefix *K
|
||||
start *Bound[K]
|
||||
end *Bound[K]
|
||||
order Order
|
||||
}
|
||||
|
||||
// Prefix sets a fixed prefix for the key range.
|
||||
func (r *Range[K]) Prefix(key K) *Range[K] {
|
||||
r.prefix = &key
|
||||
return r
|
||||
}
|
||||
|
||||
// StartInclusive makes the range contain only keys which are bigger or equal to the provided start K.
|
||||
func (r *Range[K]) StartInclusive(start K) *Range[K] {
|
||||
r.start = BoundInclusive(start)
|
||||
return r
|
||||
}
|
||||
|
||||
// StartExclusive makes the range contain only keys which are bigger to the provided start K.
|
||||
func (r *Range[K]) StartExclusive(start K) *Range[K] {
|
||||
r.start = BoundExclusive(start)
|
||||
return r
|
||||
}
|
||||
|
||||
// EndInclusive makes the range contain only keys which are smaller or equal to the provided end K.
|
||||
func (r *Range[K]) EndInclusive(end K) *Range[K] {
|
||||
r.end = BoundInclusive(end)
|
||||
return r
|
||||
}
|
||||
|
||||
// EndExclusive makes the range contain only keys which are smaller to the provided end K.
|
||||
func (r *Range[K]) EndExclusive(end K) *Range[K] {
|
||||
r.end = BoundExclusive(end)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Range[K]) Descending() *Range[K] {
|
||||
r.order = OrderDescending
|
||||
return r
|
||||
}
|
||||
|
||||
// test sentinel error
|
||||
var errRange = errors.New("collections: range error")
|
||||
var errOrder = errors.New("collections: invalid order")
|
||||
|
||||
func (r *Range[K]) RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error) {
|
||||
if r.prefix != nil && (r.end != nil || r.start != nil) {
|
||||
return nil, nil, nil, order, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange)
|
||||
}
|
||||
return r.prefix, r.start, r.end, r.order, nil
|
||||
}
|
||||
|
||||
// iteratorFromRanger generates an Iterator instance, with the proper prefixing and ranging.
|
||||
// a nil Ranger can be seen as an ascending iteration over all the possible keys.
|
||||
func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K]) (iter Iterator[K, V], err error) {
|
||||
var (
|
||||
prefix *K
|
||||
start *Bound[K]
|
||||
end *Bound[K]
|
||||
order = OrderAscending
|
||||
)
|
||||
// if Ranger is specified then we override the defaults
|
||||
if r != nil {
|
||||
prefix, start, end, order, err = r.RangeValues()
|
||||
if err != nil {
|
||||
return iter, err
|
||||
}
|
||||
}
|
||||
if prefix != nil && (start != nil || end != nil) {
|
||||
return iter, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange)
|
||||
}
|
||||
|
||||
// compute start and end bytes
|
||||
var startBytes, endBytes []byte
|
||||
if prefix != nil {
|
||||
startBytes, endBytes, err = prefixStartEndBytes(m, *prefix)
|
||||
if err != nil {
|
||||
return iter, err
|
||||
}
|
||||
} else {
|
||||
startBytes, endBytes, err = rangeStartEndBytes(m, start, end)
|
||||
if err != nil {
|
||||
return iter, err
|
||||
}
|
||||
}
|
||||
|
||||
// get store
|
||||
store, err := m.getStore(ctx)
|
||||
if err != nil {
|
||||
return iter, err
|
||||
}
|
||||
|
||||
// create iter
|
||||
var storeIter storetypes.Iterator
|
||||
switch order {
|
||||
case OrderAscending:
|
||||
storeIter = store.Iterator(startBytes, endBytes)
|
||||
case OrderDescending:
|
||||
storeIter = store.ReverseIterator(startBytes, endBytes)
|
||||
default:
|
||||
return iter, fmt.Errorf("%w: %d", errOrder, order)
|
||||
}
|
||||
|
||||
// check if valid
|
||||
if !storeIter.Valid() {
|
||||
return iter, ErrInvalidIterator
|
||||
}
|
||||
|
||||
// all good
|
||||
iter.kc = m.kc
|
||||
iter.vc = m.vc
|
||||
iter.prefixLength = len(m.prefix)
|
||||
iter.iter = storeIter
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
// rangeStartEndBytes computes a range's start and end bytes to be passed to the store's iterator.
|
||||
func rangeStartEndBytes[K, V any](m Map[K, V], start, end *Bound[K]) (startBytes, endBytes []byte, err error) {
|
||||
startBytes = m.prefix
|
||||
if start != nil {
|
||||
startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, start.value)
|
||||
if err != nil {
|
||||
return startBytes, endBytes, err
|
||||
}
|
||||
// the start of iterators is by default inclusive,
|
||||
// in order to make it exclusive we extend the start
|
||||
// by one single byte.
|
||||
if !start.inclusive {
|
||||
startBytes = extendOneByte(startBytes)
|
||||
}
|
||||
}
|
||||
if end != nil {
|
||||
endBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, end.value)
|
||||
if err != nil {
|
||||
return startBytes, endBytes, err
|
||||
}
|
||||
// the end of iterators is by default exclusive
|
||||
// in order to make it inclusive we extend the end
|
||||
// by one single byte.
|
||||
if end.inclusive {
|
||||
endBytes = extendOneByte(endBytes)
|
||||
}
|
||||
} else {
|
||||
// if end is not specified then we simply are
|
||||
// inclusive up to the last key of the Prefix
|
||||
// of the collection.
|
||||
endBytes = storetypes.PrefixEndBytes(m.prefix)
|
||||
}
|
||||
|
||||
return startBytes, endBytes, nil
|
||||
}
|
||||
|
||||
// prefixStartEndBytes returns the start and end bytes to be provided to the store's iterator, considering we're prefixing
|
||||
// over a specific key.
|
||||
func prefixStartEndBytes[K, V any](m Map[K, V], prefix K) (startBytes, endBytes []byte, err error) {
|
||||
startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, prefix)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return startBytes, storetypes.PrefixEndBytes(startBytes), nil
|
||||
}
|
||||
|
||||
// Iterator defines a generic wrapper around an sdk.Iterator.
|
||||
// This iterator provides automatic key and value encoding,
|
||||
// it assumes all the keys and values contained within the sdk.Iterator
|
||||
// range are the same.
|
||||
type Iterator[K, V any] struct {
|
||||
kc KeyCodec[K]
|
||||
vc ValueCodec[V]
|
||||
|
||||
iter storetypes.Iterator
|
||||
|
||||
prefixLength int // prefixLength refers to the bytes provided by Prefix.Bytes, not Ranger.RangeValues() prefix.
|
||||
}
|
||||
|
||||
// Value returns the current iterator value bytes decoded.
|
||||
func (i Iterator[K, V]) Value() (V, error) {
|
||||
return i.vc.Decode(i.iter.Value())
|
||||
}
|
||||
|
||||
// Key returns the current sdk.Iterator decoded key.
|
||||
func (i Iterator[K, V]) Key() (K, error) {
|
||||
bytesKey := i.iter.Key()[i.prefixLength:] // strip prefix namespace
|
||||
|
||||
read, key, err := i.kc.Decode(bytesKey)
|
||||
if err != nil {
|
||||
var k K
|
||||
return k, err
|
||||
}
|
||||
if read != len(bytesKey) {
|
||||
var k K
|
||||
return k, fmt.Errorf("%w: key decoder didn't fully consume the key: %T %x %d", ErrEncoding, i.kc, bytesKey, read)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Values fully consumes the iterator and returns all the decoded values contained within the range.
|
||||
func (i Iterator[K, V]) Values() ([]V, error) {
|
||||
defer i.Close()
|
||||
|
||||
var values []V
|
||||
for ; i.iter.Valid(); i.iter.Next() {
|
||||
value, err := i.Value()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// Keys fully consumes the iterator and returns all the decoded keys contained within the range.
|
||||
func (i Iterator[K, V]) Keys() ([]K, error) {
|
||||
defer i.Close()
|
||||
|
||||
var keys []K
|
||||
for ; i.iter.Valid(); i.iter.Next() {
|
||||
key, err := i.Key()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// KeyValue returns the current key and value decoded.
|
||||
func (i Iterator[K, V]) KeyValue() (kv KeyValue[K, V], err error) {
|
||||
key, err := i.Key()
|
||||
if err != nil {
|
||||
return kv, err
|
||||
}
|
||||
value, err := i.Value()
|
||||
if err != nil {
|
||||
return kv, err
|
||||
}
|
||||
kv.Key = key
|
||||
kv.Value = value
|
||||
return kv, nil
|
||||
}
|
||||
|
||||
// KeyValues fully consumes the iterator and returns the list of key and values within the iterator range.
|
||||
func (i Iterator[K, V]) KeyValues() ([]KeyValue[K, V], error) {
|
||||
defer i.Close()
|
||||
|
||||
var kvs []KeyValue[K, V]
|
||||
for ; i.iter.Valid(); i.iter.Next() {
|
||||
kv, err := i.KeyValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kvs = append(kvs, kv)
|
||||
}
|
||||
|
||||
return kvs, nil
|
||||
}
|
||||
|
||||
func (i Iterator[K, V]) Close() error { return i.iter.Close() }
|
||||
func (i Iterator[K, V]) Next() { i.iter.Next() }
|
||||
func (i Iterator[K, V]) Valid() bool { return i.iter.Valid() }
|
||||
|
||||
// KeyValue represent a Key and Value pair of an iteration.
|
||||
type KeyValue[K, V any] struct {
|
||||
Key K
|
||||
Value V
|
||||
}
|
||||
|
||||
func extendOneByte(b []byte) []byte {
|
||||
return append(b, 0)
|
||||
}
|
||||
201
collections/iter_test.go
Normal file
201
collections/iter_test.go
Normal file
@ -0,0 +1,201 @@
|
||||
package collections
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIteratorBasic(t *testing.T) {
|
||||
sk, ctx := deps()
|
||||
m := NewMap(sk, NewPrefix("some super amazing prefix"), StringKey, Uint64Value)
|
||||
|
||||
for i := uint64(1); i <= 2; i++ {
|
||||
require.NoError(t, m.Set(ctx, fmt.Sprintf("%d", i), i))
|
||||
}
|
||||
|
||||
iter, err := m.Iterate(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
defer iter.Close()
|
||||
|
||||
// key codec
|
||||
key, err := iter.Key()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1", key)
|
||||
|
||||
// value codec
|
||||
value, err := iter.Value()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(1), value)
|
||||
|
||||
// assert expected prefixing on iter
|
||||
require.Equal(t, m.prefix, iter.iter.Key()[:len(m.prefix)])
|
||||
|
||||
// advance iter
|
||||
iter.Next()
|
||||
require.True(t, iter.Valid())
|
||||
|
||||
// key 2
|
||||
key, err = iter.Key()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2", key)
|
||||
|
||||
// value 2
|
||||
value, err = iter.Value()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(2), value)
|
||||
|
||||
// call next, invalid
|
||||
iter.Next()
|
||||
require.False(t, iter.Valid())
|
||||
// close no errors
|
||||
require.NoError(t, iter.Close())
|
||||
}
|
||||
|
||||
func TestIteratorKeyValues(t *testing.T) {
|
||||
sk, ctx := deps()
|
||||
m := NewMap(sk, NewPrefix("some super amazing prefix"), StringKey, Uint64Value)
|
||||
|
||||
for i := uint64(0); i <= 5; i++ {
|
||||
require.NoError(t, m.Set(ctx, fmt.Sprintf("%d", i), i))
|
||||
}
|
||||
|
||||
// test keys
|
||||
iter, err := m.Iterate(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
keys, err := iter.Keys()
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, key := range keys {
|
||||
require.Equal(t, fmt.Sprintf("%d", i), key)
|
||||
}
|
||||
require.NoError(t, iter.Close())
|
||||
require.False(t, iter.Valid())
|
||||
|
||||
// test values
|
||||
iter, err = m.Iterate(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
values, err := iter.Values()
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, value := range values {
|
||||
require.Equal(t, uint64(i), value)
|
||||
}
|
||||
require.NoError(t, iter.Close())
|
||||
require.False(t, iter.Valid())
|
||||
|
||||
// test key value pairings
|
||||
iter, err = m.Iterate(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
kvs, err := iter.KeyValues()
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, kv := range kvs {
|
||||
require.Equal(t, fmt.Sprintf("%d", i), kv.Key)
|
||||
require.Equal(t, uint64(i), kv.Value)
|
||||
}
|
||||
require.NoError(t, iter.Close())
|
||||
require.False(t, iter.Valid())
|
||||
}
|
||||
|
||||
func TestIteratorPrefixing(t *testing.T) {
|
||||
sk, ctx := deps()
|
||||
m := NewMap(sk, NewPrefix("cool"), StringKey, Uint64Value)
|
||||
|
||||
require.NoError(t, m.Set(ctx, "A1", 11))
|
||||
require.NoError(t, m.Set(ctx, "A2", 12))
|
||||
require.NoError(t, m.Set(ctx, "B1", 21))
|
||||
|
||||
iter, err := m.Iterate(ctx, new(Range[string]).Prefix("A"))
|
||||
require.NoError(t, err)
|
||||
keys, err := iter.Keys()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"A1", "A2"}, keys)
|
||||
}
|
||||
|
||||
func TestIteratorRanging(t *testing.T) {
|
||||
sk, ctx := deps()
|
||||
m := NewMap(sk, NewPrefix("cool"), Uint64Key, Uint64Value)
|
||||
|
||||
for i := uint64(0); i <= 7; i++ {
|
||||
require.NoError(t, m.Set(ctx, i, i))
|
||||
}
|
||||
|
||||
// let's range (1-5]; expected: 2..5
|
||||
iter, err := m.Iterate(ctx, (&Range[uint64]{}).StartExclusive(1).EndInclusive(5))
|
||||
require.NoError(t, err)
|
||||
result, err := iter.Keys()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []uint64{2, 3, 4, 5}, result)
|
||||
|
||||
// let's range [1-5); expected 1..4
|
||||
iter, err = m.Iterate(ctx, (&Range[uint64]{}).StartInclusive(1).EndExclusive(5))
|
||||
require.NoError(t, err)
|
||||
result, err = iter.Keys()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []uint64{1, 2, 3, 4}, result)
|
||||
|
||||
// let's range [1-5) descending; expected 4..1
|
||||
iter, err = m.Iterate(ctx, (&Range[uint64]{}).StartInclusive(1).EndExclusive(5).Descending())
|
||||
require.NoError(t, err)
|
||||
result, err = iter.Keys()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []uint64{4, 3, 2, 1}, result)
|
||||
|
||||
// test iterator invalid
|
||||
_, err = m.Iterate(ctx, new(Range[uint64]).StartInclusive(10).EndInclusive(1))
|
||||
require.ErrorIs(t, err, ErrInvalidIterator)
|
||||
}
|
||||
|
||||
func TestRange(t *testing.T) {
|
||||
type test struct {
|
||||
rng *Range[string]
|
||||
wantPrefix *string
|
||||
wantStart *Bound[string]
|
||||
wantEnd *Bound[string]
|
||||
wantOrder Order
|
||||
wantErr error
|
||||
}
|
||||
|
||||
cases := map[string]test{
|
||||
"ok - empty": {
|
||||
rng: new(Range[string]),
|
||||
},
|
||||
"ok - start exclusive - end exclusive": {
|
||||
rng: new(Range[string]).StartExclusive("A").EndExclusive("B"),
|
||||
wantStart: BoundExclusive("A"),
|
||||
wantEnd: BoundExclusive("B"),
|
||||
},
|
||||
"ok - start inclusive - end inclusive - descending": {
|
||||
rng: new(Range[string]).StartInclusive("A").EndInclusive("B").Descending(),
|
||||
wantStart: BoundInclusive("A"),
|
||||
wantEnd: BoundInclusive("B"),
|
||||
wantOrder: OrderDescending,
|
||||
},
|
||||
"ok - prefix": {
|
||||
rng: new(Range[string]).Prefix("A"),
|
||||
wantPrefix: func() *string { p := "A"; return &p }(),
|
||||
},
|
||||
|
||||
"err - prefix and start set": {
|
||||
rng: new(Range[string]).Prefix("A").StartExclusive("B"),
|
||||
wantErr: errRange,
|
||||
},
|
||||
"err - prefix and end set": {
|
||||
rng: new(Range[string]).Prefix("A").StartInclusive("B"),
|
||||
wantErr: errRange,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
gotPrefix, gotStart, gotEnd, gotOrder, gotErr := tc.rng.RangeValues()
|
||||
require.ErrorIs(t, gotErr, tc.wantErr)
|
||||
require.Equal(t, tc.wantPrefix, gotPrefix)
|
||||
require.Equal(t, tc.wantStart, gotStart)
|
||||
require.Equal(t, tc.wantEnd, gotEnd)
|
||||
require.Equal(t, tc.wantOrder, gotOrder)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -7,10 +7,14 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Uint64Key can be used to encode uint64 keys.
|
||||
// Encoding is big endian to retain ordering.
|
||||
var Uint64Key KeyCodec[uint64] = uint64Key{}
|
||||
|
||||
var (
|
||||
// Uint64Key can be used to encode uint64 keys.
|
||||
// Encoding is big endian to retain ordering.
|
||||
Uint64Key KeyCodec[uint64] = uint64Key{}
|
||||
// StringKey can be used to encode string keys.
|
||||
// The encoding just converts the string to bytes.
|
||||
StringKey KeyCodec[string] = stringKey{}
|
||||
)
|
||||
var errDecodeKeySize = errors.New("decode error, wrong byte key size")
|
||||
|
||||
type uint64Key struct{}
|
||||
@ -20,19 +24,41 @@ func (u uint64Key) Encode(buffer []byte, key uint64) (int, error) {
|
||||
return 8, nil
|
||||
}
|
||||
|
||||
func (u uint64Key) Decode(buffer []byte) (int, uint64, error) {
|
||||
func (uint64Key) Decode(buffer []byte) (int, uint64, error) {
|
||||
if size := len(buffer); size < 8 {
|
||||
return 0, 0, fmt.Errorf("%w: wanted at least 8, got: %d", errDecodeKeySize, size)
|
||||
}
|
||||
return 8, binary.BigEndian.Uint64(buffer), nil
|
||||
}
|
||||
|
||||
func (u uint64Key) Size(_ uint64) int { return 8 }
|
||||
func (uint64Key) Size(_ uint64) int { return 8 }
|
||||
|
||||
func (u uint64Key) Stringify(key uint64) string {
|
||||
func (uint64Key) Stringify(key uint64) string {
|
||||
return strconv.FormatUint(key, 10)
|
||||
}
|
||||
|
||||
func (u uint64Key) KeyType() string {
|
||||
func (uint64Key) KeyType() string {
|
||||
return "uint64"
|
||||
}
|
||||
|
||||
type stringKey struct{}
|
||||
|
||||
func (stringKey) Encode(buffer []byte, key string) (int, error) {
|
||||
return copy(buffer, key), nil
|
||||
}
|
||||
|
||||
func (stringKey) Decode(buffer []byte) (int, string, error) {
|
||||
return len(buffer), string(buffer), nil
|
||||
}
|
||||
|
||||
func (stringKey) Size(key string) int {
|
||||
return len(key)
|
||||
}
|
||||
|
||||
func (stringKey) Stringify(key string) string {
|
||||
return key
|
||||
}
|
||||
|
||||
func (stringKey) KeyType() string {
|
||||
return "string"
|
||||
}
|
||||
|
||||
@ -34,7 +34,8 @@ type Map[K, V any] struct {
|
||||
// Set maps the provided value to the provided key in the store.
|
||||
// Errors with ErrEncoding if key or value encoding fails.
|
||||
func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
|
||||
keyBytes, err := m.encodeKey(key)
|
||||
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -48,7 +49,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store.Set(keyBytes, valueBytes)
|
||||
store.Set(bytesKey, valueBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -56,7 +57,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
|
||||
// errors with ErrNotFound if the key does not exist, or
|
||||
// with ErrEncoding if the key or value decoding fails.
|
||||
func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
|
||||
keyBytes, err := m.encodeKey(key)
|
||||
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
|
||||
if err != nil {
|
||||
var v V
|
||||
return v, err
|
||||
@ -67,7 +68,7 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
|
||||
var v V
|
||||
return v, err
|
||||
}
|
||||
valueBytes := store.Get(keyBytes)
|
||||
valueBytes := store.Get(bytesKey)
|
||||
if valueBytes == nil {
|
||||
var v V
|
||||
return v, fmt.Errorf("%w: key '%s' of type %s", ErrNotFound, m.kc.Stringify(key), m.vc.ValueType())
|
||||
@ -75,7 +76,6 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
|
||||
|
||||
v, err := m.vc.Decode(valueBytes)
|
||||
if err != nil {
|
||||
var v V
|
||||
return v, fmt.Errorf("%w: value decode: %s", ErrEncoding, err) // TODO: use multi err wrapping in go1.20: https://github.com/golang/go/issues/53435
|
||||
}
|
||||
return v, nil
|
||||
@ -84,7 +84,7 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
|
||||
// Has reports whether the key is present in storage or not.
|
||||
// Errors with ErrEncoding if key encoding fails.
|
||||
func (m Map[K, V]) Has(ctx context.Context, key K) (bool, error) {
|
||||
bytesKey, err := m.encodeKey(key)
|
||||
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -99,7 +99,7 @@ func (m Map[K, V]) Has(ctx context.Context, key K) (bool, error) {
|
||||
// Errors with ErrEncoding if key encoding fails.
|
||||
// If the key does not exist then this is a no-op.
|
||||
func (m Map[K, V]) Remove(ctx context.Context, key K) error {
|
||||
bytesKey, err := m.encodeKey(key)
|
||||
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -111,6 +111,12 @@ func (m Map[K, V]) Remove(ctx context.Context, key K) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Iterate provides an Iterator over K and V. It accepts a Ranger interface.
|
||||
// A nil ranger equals to iterate over all the keys in ascending order.
|
||||
func (m Map[K, V]) Iterate(ctx context.Context, ranger Ranger[K]) (Iterator[K, V], error) {
|
||||
return iteratorFromRanger(ctx, m, ranger)
|
||||
}
|
||||
|
||||
func (m Map[K, V]) getStore(ctx context.Context) (storetypes.KVStore, error) {
|
||||
provider, ok := ctx.(StorageProvider)
|
||||
if !ok {
|
||||
@ -119,14 +125,14 @@ func (m Map[K, V]) getStore(ctx context.Context) (storetypes.KVStore, error) {
|
||||
return provider.KVStore(m.sk), nil
|
||||
}
|
||||
|
||||
func (m Map[K, V]) encodeKey(key K) ([]byte, error) {
|
||||
prefixLen := len(m.prefix)
|
||||
func encodeKeyWithPrefix[K any](prefix []byte, kc KeyCodec[K], key K) ([]byte, error) {
|
||||
prefixLen := len(prefix)
|
||||
// preallocate buffer
|
||||
keyBytes := make([]byte, prefixLen+m.kc.Size(key))
|
||||
keyBytes := make([]byte, prefixLen+kc.Size(key))
|
||||
// put prefix
|
||||
copy(keyBytes, m.prefix)
|
||||
copy(keyBytes, prefix)
|
||||
// put key
|
||||
_, err := m.kc.Encode(keyBytes[prefixLen:], key)
|
||||
_, err := kc.Encode(keyBytes[prefixLen:], key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: key encode: %s", ErrEncoding, err) // TODO: use multi err wrapping in go1.20: https://github.com/golang/go/issues/53435
|
||||
}
|
||||
|
||||
@ -3,8 +3,6 @@ package collections
|
||||
import (
|
||||
"testing"
|
||||
|
||||
storetypes "github.com/cosmos/cosmos-sdk/store/types"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -35,14 +33,12 @@ func TestMap(t *testing.T) {
|
||||
require.False(t, has)
|
||||
}
|
||||
|
||||
func TestMap_encodeKey(t *testing.T) {
|
||||
func Test_encodeKey(t *testing.T) {
|
||||
prefix := "prefix"
|
||||
number := []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}
|
||||
expectedKey := append([]byte(prefix), number...)
|
||||
|
||||
m := NewMap(storetypes.NewKVStoreKey("test"), NewPrefix(prefix), Uint64Key, Uint64Value)
|
||||
|
||||
gotKey, err := m.encodeKey(0)
|
||||
gotKey, err := encodeKeyWithPrefix(NewPrefix(prefix).Bytes(), Uint64Key, 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedKey, gotKey)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user