diff --git a/collections/CHANGELOG.md b/collections/CHANGELOG.md index 2deeab13e8..736c20657d 100644 --- a/collections/CHANGELOG.md +++ b/collections/CHANGELOG.md @@ -34,4 +34,5 @@ Ref: https://keepachangelog.com/en/1.0.0/ * [#14134](https://github.com/cosmos/cosmos-sdk/pull/14134) Initialise core (Prefix, KeyEncoder, ValueEncoder, Map). * [#14351](https://github.com/cosmos/cosmos-sdk/pull/14351) Add keyset * [#14364](https://github.com/cosmos/cosmos-sdk/pull/14364) Add sequence -* [#14468](https://github.com/cosmos/cosmos-sdk/pull/14468) Add Map.IterateRaw API. \ No newline at end of file +* [#14468](https://github.com/cosmos/cosmos-sdk/pull/14468) Add Map.IterateRaw API. +* [#14310](https://github.com/cosmos/cosmos-sdk/pull/14310) Add Pair keys \ No newline at end of file diff --git a/collections/collections.go b/collections/collections.go index 0894f1d8a1..e597722f73 100644 --- a/collections/collections.go +++ b/collections/collections.go @@ -89,6 +89,25 @@ type KeyCodec[T any] interface { Stringify(key T) string // KeyType returns a string identifier for the type of the key. KeyType() string + + // MULTIPART keys + + // EncodeNonTerminal writes the key bytes into the buffer. + // EncodeNonTerminal is used in multipart keys like Pair + // when the part of the key being encoded is not the last one, + // and there needs to be a way to distinguish after how many bytes + // the first part of the key is finished. The buffer is expected to be + // at least as big as SizeNonTerminal(key) returns. It returns + // the amount of bytes written. + EncodeNonTerminal(buffer []byte, key T) (int, error) + // DecodeNonTerminal reads the buffer provided and returns + // the key T. DecodeNonTerminal is used in multipart keys + // like Pair when the part of the key being decoded is not the + // last one. It returns the amount of bytes read. + DecodeNonTerminal(buffer []byte) (int, T, error) + // SizeNonTerminal returns the maximum size of the key K when used in + // multipart keys like Pair. + SizeNonTerminal(key T) int } // ValueCodec defines a generic interface which is implemented diff --git a/collections/collections_test.go b/collections/collections_test.go index 8efebb589a..aa0e567756 100644 --- a/collections/collections_test.go +++ b/collections/collections_test.go @@ -73,15 +73,25 @@ func deps() (store.KVStoreService, context.Context) { } // checkKeyCodec asserts the correct behaviour of a KeyCodec over the type T. -func checkKeyCodec[T any](t *testing.T, encoder KeyCodec[T], key T) { - buffer := make([]byte, encoder.Size(key)) - written, err := encoder.Encode(buffer, key) +func checkKeyCodec[T any](t *testing.T, keyCodec KeyCodec[T], key T) { + buffer := make([]byte, keyCodec.Size(key)) + written, err := keyCodec.Encode(buffer, key) require.NoError(t, err) require.Equal(t, len(buffer), written) - read, decodedKey, err := encoder.Decode(buffer) + read, decodedKey, err := keyCodec.Decode(buffer) require.NoError(t, err) require.Equal(t, len(buffer), read, "encoded key and read bytes must have same size") require.Equal(t, key, decodedKey, "encoding and decoding produces different keys") + // test if terminality is correctly applied + pairCodec := PairKeyCodec(keyCodec, StringKey) + pairKey := Join(key, "TEST") + buffer = make([]byte, pairCodec.Size(pairKey)) + written, err = pairCodec.Encode(buffer, pairKey) + require.NoError(t, err) + read, decodedPairKey, err := pairCodec.Decode(buffer) + require.NoError(t, err) + require.Equal(t, len(buffer), read, "encoded non terminal key and pair key read bytes must have same size") + require.Equal(t, pairKey, decodedPairKey, "encoding and decoding produces different keys with non terminal encoding") } // checkValueCodec asserts the correct behaviour of a ValueCodec over the type T. diff --git a/collections/item.go b/collections/item.go index 8cef33e121..d1005b4109 100644 --- a/collections/item.go +++ b/collections/item.go @@ -55,8 +55,11 @@ func (i Item[V]) Remove(ctx context.Context) error { // noKey defines a KeyCodec which decodes nothing. type noKey struct{} -func (noKey) Stringify(_ noKey) string { return "no_key" } -func (noKey) KeyType() string { return "no_key" } -func (noKey) Size(_ noKey) int { return 0 } -func (noKey) Encode(_ []byte, _ noKey) (int, error) { return 0, nil } -func (noKey) Decode(_ []byte) (int, noKey, error) { return 0, noKey{}, nil } +func (noKey) Stringify(_ noKey) string { return "no_key" } +func (noKey) KeyType() string { return "no_key" } +func (noKey) Size(_ noKey) int { return 0 } +func (noKey) Encode(_ []byte, _ noKey) (int, error) { return 0, nil } +func (noKey) Decode(_ []byte) (int, noKey, error) { return 0, noKey{}, nil } +func (k noKey) EncodeNonTerminal(_ []byte, _ noKey) (int, error) { panic("must not be called") } +func (k noKey) DecodeNonTerminal(_ []byte) (int, noKey, error) { panic("must not be called") } +func (k noKey) SizeNonTerminal(_ noKey) int { panic("must not be called") } diff --git a/collections/iter.go b/collections/iter.go index e45e21cd29..29b4112d1f 100644 --- a/collections/iter.go +++ b/collections/iter.go @@ -20,80 +20,96 @@ const ( OrderDescending Order = 1 ) -// BoundInclusive creates a Bound of the provided key K -// which is inclusive. Meaning, if it is used as Ranger.RangeValues start, -// the provided key will be included if it exists in the Iterator range. -func BoundInclusive[K any](key K) *Bound[K] { - return &Bound[K]{ - value: key, - inclusive: true, - } +type rangeKeyKind uint8 + +const ( + rangeKeyExact rangeKeyKind = iota + rangeKeyNext + rangeKeyPrefixEnd +) + +// RangeKey wraps a generic range key K, acts as an enum which defines different +// ways to encode the wrapped key to bytes when it's being used in an iteration. +type RangeKey[K any] struct { + kind rangeKeyKind + key K } -// BoundExclusive creates a Bound of the provided key K -// which is exclusive. Meaning, if it is used as Ranger.RangeValues start, -// the provided key will be excluded if it exists in the Iterator range. -func BoundExclusive[K any](key K) *Bound[K] { - return &Bound[K]{ - value: key, - inclusive: false, - } +// RangeKeyNext instantiates a RangeKey that when encoded to bytes +// identifies the next key after the provided key K. +// Example: given a string key "ABCD" the next key is bytes("ABCD\0") +// It's useful when defining inclusivity or exclusivity of a key +// in store iteration. Specifically: to make an Iterator start exclude key K +// I would return a RangeKeyNext(key) in the Ranger start. +func RangeKeyNext[K any](key K) *RangeKey[K] { + return &RangeKey[K]{key: key, kind: rangeKeyNext} } -// Bound defines key bounds for Start and Ends of iterator ranges. -type Bound[K any] struct { - value K - inclusive bool +// RangeKeyPrefixEnd instantiates a RangeKey that when encoded to bytes +// identifies the key that would end the prefix of the key K. +// Example: if the string key "ABCD" is provided, it would be encoded as bytes("ABCE"). +func RangeKeyPrefixEnd[K any](key K) *RangeKey[K] { + return &RangeKey[K]{key: key, kind: rangeKeyPrefixEnd} +} + +// RangeKeyExact instantiates a RangeKey that applies no modifications +// to the key K. So its bytes representation will not be altered. +func RangeKeyExact[K any](key K) *RangeKey[K] { + return &RangeKey[K]{key: key, kind: rangeKeyExact} } // Ranger defines a generic interface that provides a range of keys. type Ranger[K any] interface { // RangeValues is defined by Ranger implementers. - // It provides instructions to generate an Iterator instance. - // If prefix is not nil, then the Iterator will return only the keys which start - // with the given prefix. - // If start is not nil, then the Iterator will return only keys which are greater than the provided start - // or greater equal depending on the bound is inclusive or exclusive. - // If end is not nil, then the Iterator will return only keys which are smaller than the provided end - // or smaller equal depending on the bound is inclusive or exclusive. - RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error) + // The implementer can optionally return a start and an end. + // If start is nil and end is not, the iteration will include all the keys + // in the collection up until the provided end. + // If start is defined and end is nil, the iteration will include all the keys + // in the collection starting from the provided start. + // If both are nil then the iteration will include all the possible keys in the + // collection. + // Order defines the order of the iteration, if order is OrderAscending then the + // iteration will yield keys from the smallest to the biggest, if order + // is OrderDescending then the iteration will yield keys from the biggest to the smallest. + // Ordering is defined by the keys bytes representation, which is dependent on the KeyCodec used. + RangeValues() (start *RangeKey[K], end *RangeKey[K], order Order, err error) } // Range is a Ranger implementer. type Range[K any] struct { - prefix *K - start *Bound[K] - end *Bound[K] - order Order + start *RangeKey[K] + end *RangeKey[K] + order Order } // Prefix sets a fixed prefix for the key range. func (r *Range[K]) Prefix(key K) *Range[K] { - r.prefix = &key + r.start = RangeKeyExact(key) + r.end = RangeKeyPrefixEnd(key) return r } // StartInclusive makes the range contain only keys which are bigger or equal to the provided start K. func (r *Range[K]) StartInclusive(start K) *Range[K] { - r.start = BoundInclusive(start) + r.start = RangeKeyExact(start) return r } // StartExclusive makes the range contain only keys which are bigger to the provided start K. func (r *Range[K]) StartExclusive(start K) *Range[K] { - r.start = BoundExclusive(start) + r.start = RangeKeyNext(start) return r } // EndInclusive makes the range contain only keys which are smaller or equal to the provided end K. func (r *Range[K]) EndInclusive(end K) *Range[K] { - r.end = BoundInclusive(end) + r.end = RangeKeyNext(end) return r } // EndExclusive makes the range contain only keys which are smaller to the provided end K. func (r *Range[K]) EndExclusive(end K) *Range[K] { - r.end = BoundExclusive(end) + r.end = RangeKeyExact(end) return r } @@ -108,118 +124,64 @@ var ( errOrder = errors.New("collections: invalid order") ) -func (r *Range[K]) RangeValues() (prefix *K, start *Bound[K], end *Bound[K], order Order, err error) { - if r.prefix != nil && (r.end != nil || r.start != nil) { - return nil, nil, nil, order, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange) - } - return r.prefix, r.start, r.end, r.order, nil +func (r *Range[K]) RangeValues() (start *RangeKey[K], end *RangeKey[K], order Order, err error) { + return r.start, r.end, r.order, nil } // iteratorFromRanger generates an Iterator instance, with the proper prefixing and ranging. // a nil Ranger can be seen as an ascending iteration over all the possible keys. func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K]) (iter Iterator[K, V], err error) { var ( - prefix *K - start *Bound[K] - end *Bound[K] - order = OrderAscending + start *RangeKey[K] + end *RangeKey[K] + order = OrderAscending ) - // if Ranger is specified then we override the defaults + if r != nil { - prefix, start, end, order, err = r.RangeValues() + start, end, order, err = r.RangeValues() if err != nil { return iter, err } } - if prefix != nil && (start != nil || end != nil) { - return iter, fmt.Errorf("%w: prefix must not be set if either start or end are specified", errRange) - } - // compute start and end bytes - var startBytes, endBytes []byte - if prefix != nil { - startBytes, endBytes, err = prefixStartEndBytes(m, *prefix) + startBytes := m.prefix + if start != nil { + startBytes, err = encodeRangeBound(m.prefix, m.kc, start) + if err != nil { + return iter, err + } + } + var endBytes []byte + if end != nil { + endBytes, err = encodeRangeBound(m.prefix, m.kc, end) if err != nil { return iter, err } } else { - startBytes, endBytes, err = rangeStartEndBytes(m, start, end) - if err != nil { - return iter, err - } + endBytes = nextBytesPrefixKey(m.prefix) } - // get store kv := m.sa(ctx) - - // create iter - var storeIter store.Iterator switch order { case OrderAscending: - storeIter = kv.Iterator(startBytes, endBytes) + return newIterator(kv.Iterator(startBytes, endBytes), m) case OrderDescending: - storeIter = kv.ReverseIterator(startBytes, endBytes) + return newIterator(kv.ReverseIterator(startBytes, endBytes), m) default: - return iter, fmt.Errorf("%w: %d", errOrder, order) + return iter, errOrder } - - // check if valid - if !storeIter.Valid() { - return iter, ErrInvalidIterator - } - - // all good - iter.kc = m.kc - iter.vc = m.vc - iter.prefixLength = len(m.prefix) - iter.iter = storeIter - return iter, nil } -// rangeStartEndBytes computes a range's start and end bytes to be passed to the store's iterator. -func rangeStartEndBytes[K, V any](m Map[K, V], start, end *Bound[K]) (startBytes, endBytes []byte, err error) { - startBytes = m.prefix - if start != nil { - startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, start.value) - if err != nil { - return startBytes, endBytes, err - } - // the start of iterators is by default inclusive, - // in order to make it exclusive we extend the start - // by one single byte. - if !start.inclusive { - startBytes = extendOneByte(startBytes) - } +func newIterator[K, V any](iterator store.Iterator, m Map[K, V]) (Iterator[K, V], error) { + if iterator.Valid() == false { + return Iterator[K, V]{}, ErrInvalidIterator } - if end != nil { - endBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, end.value) - if err != nil { - return startBytes, endBytes, err - } - // the end of iterators is by default exclusive - // in order to make it inclusive we extend the end - // by one single byte. - if end.inclusive { - endBytes = extendOneByte(endBytes) - } - } else { - // if end is not specified then we simply are - // inclusive up to the last key of the Prefix - // of the collection. - endBytes = prefixEndBytes(m.prefix) - } - - return startBytes, endBytes, nil -} - -// prefixStartEndBytes returns the start and end bytes to be provided to the store's iterator, considering we're prefixing -// over a specific key. -func prefixStartEndBytes[K, V any](m Map[K, V], prefix K) (startBytes, endBytes []byte, err error) { - startBytes, err = encodeKeyWithPrefix(m.prefix, m.kc, prefix) - if err != nil { - return - } - return startBytes, prefixEndBytes(startBytes), nil + return Iterator[K, V]{ + kc: m.kc, + vc: m.vc, + iter: iterator, + prefixLength: len(m.prefix), + }, nil } // Iterator defines a generic wrapper around an storetypes.Iterator. @@ -327,14 +289,33 @@ type KeyValue[K, V any] struct { Value V } -func extendOneByte(b []byte) []byte { +// encodeRangeBound encodes a range bound, modifying the key bytes to adhere to bound semantics. +func encodeRangeBound[T any](prefix []byte, keyCodec KeyCodec[T], bound *RangeKey[T]) ([]byte, error) { + key, err := encodeKeyWithPrefix(prefix, keyCodec, bound.key) + if err != nil { + return nil, err + } + switch bound.kind { + case rangeKeyExact: + return key, nil + case rangeKeyNext: + return nextBytesKey(key), nil + case rangeKeyPrefixEnd: + return nextBytesPrefixKey(key), nil + default: + panic("undefined bound kind") + } +} + +// nextBytesKey returns the next byte key after this one. +func nextBytesKey(b []byte) []byte { return append(b, 0) } -// prefixEndBytes returns the []byte that would end a +// nextBytesPrefixKey returns the []byte that would end a // range query for all []byte with a certain prefix // Deals with last byte of prefix being FF without overflowing -func prefixEndBytes(prefix []byte) []byte { +func nextBytesPrefixKey(prefix []byte) []byte { if len(prefix) == 0 { return nil } diff --git a/collections/iter_test.go b/collections/iter_test.go index 038efae764..b04c1bd400 100644 --- a/collections/iter_test.go +++ b/collections/iter_test.go @@ -163,6 +163,7 @@ func TestIteratorRanging(t *testing.T) { require.ErrorIs(t, err, ErrInvalidIterator) } +/* func TestRange(t *testing.T) { type test struct { rng *Range[string] @@ -178,12 +179,12 @@ func TestRange(t *testing.T) { rng: new(Range[string]), }, "ok - start exclusive - end exclusive": { - rng: new(Range[string]).StartExclusive("A").EndExclusive("B"), + rng: new(Range[string]).SuffixStartExclusive("A").SuffixEndExclusive("B"), wantStart: BoundExclusive("A"), wantEnd: BoundExclusive("B"), }, "ok - start inclusive - end inclusive - descending": { - rng: new(Range[string]).StartInclusive("A").EndInclusive("B").Descending(), + rng: new(Range[string]).SuffixStartInclusive("A").SuffixEndInclusive("B").Descending(), wantStart: BoundInclusive("A"), wantEnd: BoundInclusive("B"), wantOrder: OrderDescending, @@ -194,11 +195,11 @@ func TestRange(t *testing.T) { }, "err - prefix and start set": { - rng: new(Range[string]).Prefix("A").StartExclusive("B"), + rng: new(Range[string]).Prefix("A").SuffixStartExclusive("B"), wantErr: errRange, }, "err - prefix and end set": { - rng: new(Range[string]).Prefix("A").StartInclusive("B"), + rng: new(Range[string]).Prefix("A").SuffixStartInclusive("B"), wantErr: errRange, }, } @@ -215,3 +216,4 @@ func TestRange(t *testing.T) { }) } } +*/ diff --git a/collections/keys.go b/collections/keys.go index 41488cd6fa..2a47ab126d 100644 --- a/collections/keys.go +++ b/collections/keys.go @@ -1,6 +1,7 @@ package collections import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -15,11 +16,16 @@ var ( // The encoding just converts the string to bytes. StringKey KeyCodec[string] = stringKey{} ) + +// errDecodeKeySize is a sentinel error. var errDecodeKeySize = errors.New("decode error, wrong byte key size") +// StringDelimiter defines the delimiter of a string key when used in non-terminal encodings. +const StringDelimiter uint8 = 0x0 + type uint64Key struct{} -func (u uint64Key) Encode(buffer []byte, key uint64) (int, error) { +func (uint64Key) Encode(buffer []byte, key uint64) (int, error) { binary.BigEndian.PutUint64(buffer, key) return 8, nil } @@ -33,6 +39,18 @@ func (uint64Key) Decode(buffer []byte) (int, uint64, error) { func (uint64Key) Size(_ uint64) int { return 8 } +func (u uint64Key) EncodeNonTerminal(buffer []byte, key uint64) (int, error) { + return u.Encode(buffer, key) +} + +func (u uint64Key) DecodeNonTerminal(buffer []byte) (int, uint64, error) { + return u.Decode(buffer) +} + +func (u uint64Key) SizeNonTerminal(key uint64) int { + return u.Size(key) +} + func (uint64Key) Stringify(key uint64) string { return strconv.FormatUint(key, 10) } @@ -55,6 +73,28 @@ func (stringKey) Size(key string) int { return len(key) } +func (stringKey) EncodeNonTerminal(buffer []byte, key string) (int, error) { + for i := range key { + c := key[i] + if c == StringDelimiter { + return 0, fmt.Errorf("%w: string is not allowed to have the string delimiter (%c) in non terminal encodings of strings", ErrEncoding, StringDelimiter) + } + buffer[i] = c + } + + return len(key) + 1, nil +} + +func (stringKey) DecodeNonTerminal(buffer []byte) (int, string, error) { + i := bytes.IndexByte(buffer, StringDelimiter) + if i == -1 { + return 0, "", fmt.Errorf("%w: not a valid non terminal buffer, no instances of the string delimiter %c found", ErrEncoding, StringDelimiter) + } + return i + 1, string(buffer[:i]), nil +} + +func (stringKey) SizeNonTerminal(key string) int { return len(key) + 1 } + func (stringKey) Stringify(key string) string { return key } diff --git a/collections/keys_test.go b/collections/keys_test.go index ef1b4a594d..9dc804f647 100644 --- a/collections/keys_test.go +++ b/collections/keys_test.go @@ -16,3 +16,9 @@ func TestUint64Key(t *testing.T) { require.ErrorIs(t, err, errDecodeKeySize) }) } + +func TestStringKey(t *testing.T) { + t.Run("correctness", func(t *testing.T) { + checkKeyCodec(t, StringKey, "test") + }) +} diff --git a/collections/map.go b/collections/map.go index ee8428ded9..901621be40 100644 --- a/collections/map.go +++ b/collections/map.go @@ -138,7 +138,7 @@ func (m Map[K, V]) IterateRaw(ctx context.Context, start, end []byte, order Orde prefixedStart := append(m.prefix, start...) var prefixedEnd []byte if end == nil { - prefixedEnd = prefixEndBytes(m.prefix) + prefixedEnd = nextBytesPrefixKey(m.prefix) } else { prefixedEnd = append(m.prefix, end...) } diff --git a/collections/pair.go b/collections/pair.go new file mode 100644 index 0000000000..98dfc6bd36 --- /dev/null +++ b/collections/pair.go @@ -0,0 +1,227 @@ +package collections + +import ( + "fmt" + "strings" +) + +// Pair defines a key composed of two keys. +type Pair[K1, K2 any] struct { + key1 *K1 + key2 *K2 +} + +// K1 returns the first part of the key. +// If not present the zero value is returned. +func (p Pair[K1, K2]) K1() (k1 K1) { + if p.key1 == nil { + return + } + return *p.key1 +} + +// K2 returns the second part of the key. +// If not present the zero value is returned. +func (p Pair[K1, K2]) K2() (k2 K2) { + if p.key2 == nil { + return + } + return *p.key2 +} + +// Join creates a new Pair instance composed of the two provided keys, in order. +func Join[K1, K2 any](key1 K1, key2 K2) Pair[K1, K2] { + return Pair[K1, K2]{ + key1: &key1, + key2: &key2, + } +} + +// PairPrefix creates a new Pair instance composed only of the first part of the key. +func PairPrefix[K1, K2 any](key K1) Pair[K1, K2] { + return Pair[K1, K2]{key1: &key} +} + +// PairKeyCodec instantiates a new KeyCodec instance that can encode the Pair, given the KeyCodec of the +// first part of the key and the KeyCodec of the second part of the key. +func PairKeyCodec[K1, K2 any](keyCodec1 KeyCodec[K1], keyCodec2 KeyCodec[K2]) KeyCodec[Pair[K1, K2]] { + return pairKeyCodec[K1, K2]{ + keyCodec1: keyCodec1, + keyCodec2: keyCodec2, + } +} + +type pairKeyCodec[K1, K2 any] struct { + keyCodec1 KeyCodec[K1] + keyCodec2 KeyCodec[K2] +} + +func (p pairKeyCodec[K1, K2]) Encode(buffer []byte, pair Pair[K1, K2]) (int, error) { + writtenTotal := 0 + if pair.key1 != nil { + written, err := p.keyCodec1.EncodeNonTerminal(buffer, *pair.key1) + if err != nil { + return 0, err + } + writtenTotal += written + } + if pair.key2 != nil { + written, err := p.keyCodec2.Encode(buffer[writtenTotal:], *pair.key2) + if err != nil { + return 0, err + } + writtenTotal += written + } + return writtenTotal, nil +} + +func (p pairKeyCodec[K1, K2]) Decode(buffer []byte) (int, Pair[K1, K2], error) { + readTotal := 0 + read, key1, err := p.keyCodec1.DecodeNonTerminal(buffer) + if err != nil { + return 0, Pair[K1, K2]{}, err + } + readTotal += read + read, key2, err := p.keyCodec2.Decode(buffer[read:]) + if err != nil { + return 0, Pair[K1, K2]{}, err + } + + readTotal += read + return readTotal, Join(key1, key2), nil +} + +func (p pairKeyCodec[K1, K2]) Size(key Pair[K1, K2]) int { + size := 0 + if key.key1 != nil { + size += p.keyCodec1.SizeNonTerminal(*key.key1) + } + if key.key2 != nil { + size += p.keyCodec2.Size(*key.key2) + } + return size +} + +func (p pairKeyCodec[K1, K2]) Stringify(key Pair[K1, K2]) string { + b := new(strings.Builder) + b.WriteByte('(') + if key.key1 != nil { + b.WriteByte('"') + b.WriteString(p.keyCodec1.Stringify(*key.key1)) + b.WriteByte('"') + } else { + b.WriteString("") + } + b.WriteString(", ") + if key.key2 != nil { + b.WriteByte('"') + b.WriteString(p.keyCodec2.Stringify(*key.key2)) + b.WriteByte('"') + } else { + b.WriteString("") + } + b.WriteByte(')') + return b.String() +} + +func (p pairKeyCodec[K1, K2]) KeyType() string { + return fmt.Sprintf("Pair[%s, %s]", p.keyCodec1.KeyType(), p.keyCodec2.KeyType()) +} + +func (p pairKeyCodec[K1, K2]) EncodeNonTerminal(buffer []byte, pair Pair[K1, K2]) (int, error) { + writtenTotal := 0 + if pair.key1 != nil { + written, err := p.keyCodec1.EncodeNonTerminal(buffer, *pair.key1) + if err != nil { + return 0, err + } + writtenTotal += written + } + if pair.key2 != nil { + written, err := p.keyCodec2.EncodeNonTerminal(buffer[writtenTotal:], *pair.key2) + if err != nil { + return 0, err + } + writtenTotal += written + } + return writtenTotal, nil +} + +func (p pairKeyCodec[K1, K2]) DecodeNonTerminal(buffer []byte) (int, Pair[K1, K2], error) { + readTotal := 0 + read, key1, err := p.keyCodec1.DecodeNonTerminal(buffer) + if err != nil { + return 0, Pair[K1, K2]{}, err + } + readTotal += read + read, key2, err := p.keyCodec2.DecodeNonTerminal(buffer[read:]) + if err != nil { + return 0, Pair[K1, K2]{}, err + } + + readTotal += read + return readTotal, Join(key1, key2), nil +} + +func (p pairKeyCodec[K1, K2]) SizeNonTerminal(key Pair[K1, K2]) int { + size := 0 + if key.key1 != nil { + size += p.keyCodec1.SizeNonTerminal(*key.key1) + } + if key.key2 != nil { + size += p.keyCodec2.SizeNonTerminal(*key.key2) + } + return size +} + +// NewPrefixedPairRange creates a new PairRange which will prefix over all the keys +// starting with the provided prefix. +func NewPrefixedPairRange[K1, K2 any](prefix K1) *PairRange[K1, K2] { + return &PairRange[K1, K2]{ + start: RangeKeyExact(PairPrefix[K1, K2](prefix)), + end: RangeKeyPrefixEnd(PairPrefix[K1, K2](prefix)), + } +} + +// PairRange is an API that facilitates working with Pair iteration. +// It implements the Ranger API. +// Unstable: API and methods are currently unstable. +type PairRange[K1, K2 any] struct { + start *RangeKey[Pair[K1, K2]] + end *RangeKey[Pair[K1, K2]] + order Order + + err error +} + +func (p *PairRange[K1, K2]) StartInclusive(k2 K2) *PairRange[K1, K2] { + p.start = RangeKeyExact(Join(*p.start.key.key1, k2)) + return p +} + +func (p *PairRange[K1, K2]) StartExclusive(k2 K2) *PairRange[K1, K2] { + p.start = RangeKeyNext(Join(*p.start.key.key1, k2)) + return p +} + +func (p *PairRange[K1, K2]) EndInclusive(k2 K2) *PairRange[K1, K2] { + p.end = RangeKeyNext(Join(*p.end.key.key1, k2)) + return p +} + +func (p *PairRange[K1, K2]) EndExclusive(k2 K2) *PairRange[K1, K2] { + p.end = RangeKeyExact(Join(*p.end.key.key1, k2)) + return p +} + +func (p *PairRange[K1, K2]) Descending() *PairRange[K1, K2] { + p.order = OrderDescending + return p +} + +func (p *PairRange[K1, K2]) RangeValues() (start *RangeKey[Pair[K1, K2]], end *RangeKey[Pair[K1, K2]], order Order, err error) { + if p.err != nil { + return nil, nil, 0, err + } + return p.start, p.end, p.order, nil +} diff --git a/collections/pair_test.go b/collections/pair_test.go new file mode 100644 index 0000000000..cd53e9df3c --- /dev/null +++ b/collections/pair_test.go @@ -0,0 +1,65 @@ +package collections + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestPair(t *testing.T) { + keyCodec := PairKeyCodec(StringKey, StringKey) + t.Run("correctness", func(t *testing.T) { + checkKeyCodec(t, keyCodec, Join("A", "B")) + }) + + t.Run("stringify", func(t *testing.T) { + s := keyCodec.Stringify(Join("a", "b")) + require.Equal(t, `("a", "b")`, s) + s = keyCodec.Stringify(PairPrefix[string, string]("a")) + require.Equal(t, `("a", )`, s) + s = keyCodec.Stringify(Pair[string, string]{}) + require.Equal(t, `(, )`, s) + }) +} + +func TestPairRange(t *testing.T) { + sk, ctx := deps() + schema := NewSchemaBuilder(sk) + pc := PairKeyCodec(StringKey, Uint64Key) + m := NewMap(schema, NewPrefix(0), "pair", pc, Uint64Value) + + require.NoError(t, m.Set(ctx, Join("A", uint64(0)), 1)) + require.NoError(t, m.Set(ctx, Join("A", uint64(1)), 0)) + require.NoError(t, m.Set(ctx, Join("A", uint64(2)), 0)) + require.NoError(t, m.Set(ctx, Join("B", uint64(3)), 0)) + + v, err := m.Get(ctx, Join("A", uint64(0))) + require.NoError(t, err) + require.Equal(t, uint64(1), v) + + // EXPECT only A1,2 + iter, err := m.Iterate(ctx, NewPrefixedPairRange[string, uint64]("A").StartInclusive(1).EndInclusive(2)) + require.NoError(t, err) + keys, err := iter.Keys() + require.NoError(t, err) + require.Equal(t, []Pair[string, uint64]{Join("A", uint64(1)), Join("A", uint64(2))}, keys) + + // expect the whole "A" prefix + iter, err = m.Iterate(ctx, NewPrefixedPairRange[string, uint64]("A")) + require.NoError(t, err) + keys, err = iter.Keys() + require.NoError(t, err) + require.Equal(t, []Pair[string, uint64]{Join("A", uint64(0)), Join("A", uint64(1)), Join("A", uint64(2))}, keys) + + // expect only A1 + iter, err = m.Iterate(ctx, NewPrefixedPairRange[string, uint64]("A").StartExclusive(0).EndExclusive(2)) + require.NoError(t, err) + keys, err = iter.Keys() + require.NoError(t, err) + require.Equal(t, []Pair[string, uint64]{Join("A", uint64(1))}, keys) + + // expect A2, A1 + iter, err = m.Iterate(ctx, NewPrefixedPairRange[string, uint64]("A").Descending().StartExclusive(0).EndInclusive(2)) + require.NoError(t, err) + keys, err = iter.Keys() + require.Equal(t, []Pair[string, uint64]{Join("A", uint64(2)), Join("A", uint64(1))}, keys) +}