refactor(collections): Indexes perf improvements and simplification (#15552)

Co-authored-by: testinginprod <testinginprod@somewhere.idk>
This commit is contained in:
testinginprod 2023-03-30 22:11:16 +02:00 committed by GitHub
parent 0869411d5c
commit 182f3fd551
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 296 additions and 668 deletions

View File

@ -2,8 +2,6 @@ package collections
import (
"context"
"errors"
"fmt"
"cosmossdk.io/collections/codec"
)
@ -21,12 +19,12 @@ type Indexes[PrimaryKey, Value any] interface {
// Index represents an index of the Value indexed using the type PrimaryKey.
type Index[PrimaryKey, Value any] interface {
// Reference creates a reference between the provided primary key and value.
// If oldValue is not nil then the Index must update the references
// of the primary key associated with the new value and remove the
// old invalid references.
Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error
// It provides a lazyOldValue function that if called will attempt to fetch
// the previous old value, returns ErrNotFound if no value existed.
Reference(ctx context.Context, pk PrimaryKey, newValue Value, lazyOldValue func() (Value, error)) error
// Unreference removes the reference between the primary key and value.
Unreference(ctx context.Context, pk PrimaryKey, value Value) error
// If error is ErrNotFound then it means that the value did not exist before.
Unreference(ctx context.Context, pk PrimaryKey, lazyOldValue func() (Value, error)) error
}
// IndexedMap works like a Map but creates references between fields of Value and its PrimaryKey.
@ -40,10 +38,10 @@ type IndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]] struct {
}
// NewIndexedMap instantiates a new IndexedMap. Accepts a SchemaBuilder, a Prefix,
// a humanized name that defines the name of the collection, the primary key codec
// a humanised name that defines the name of the collection, the primary key codec
// which is basically what IndexedMap uses to encode the primary key to bytes,
// the value codec which is what the IndexedMap uses to encode the value.
// Then it expects the initialized indexes.
// Then it expects the initialised indexes.
func NewIndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]](
schema *SchemaBuilder,
prefix Prefix,
@ -76,26 +74,10 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) Has(ctx context.Context, pk Primary
// Set maps the value using the primary key. It will also iterate every index and instruct them to
// add or update the indexes.
func (m *IndexedMap[PrimaryKey, Value, Idx]) Set(ctx context.Context, pk PrimaryKey, value Value) error {
// we need to see if there was a previous instance of the value
oldValue, err := m.m.Get(ctx, pk)
switch {
// update indexes
case err == nil:
err = m.ref(ctx, pk, value, &oldValue)
if err != nil {
return fmt.Errorf("collections: indexing error: %w", err)
}
// create new indexes
case errors.Is(err, ErrNotFound):
err = m.ref(ctx, pk, value, nil)
if err != nil {
return fmt.Errorf("collections: indexing error: %w", err)
}
// cannot move forward error
default:
err := m.ref(ctx, pk, value)
if err != nil {
return err
}
return m.m.Set(ctx, pk, value)
}
@ -103,13 +85,7 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) Set(ctx context.Context, pk Primary
// it iterates over all the indexes and instructs them to remove all the references
// associated with the removed value.
func (m *IndexedMap[PrimaryKey, Value, Idx]) Remove(ctx context.Context, pk PrimaryKey) error {
oldValue, err := m.m.Get(ctx, pk)
if err != nil {
// TODO retain Map behavior? which does not error in case we remove a non-existing object
return err
}
err = m.unref(ctx, pk, oldValue)
err := m.unref(ctx, pk)
if err != nil {
return err
}
@ -134,9 +110,9 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) ValueCodec() codec.ValueCodec[Value
return m.m.ValueCodec()
}
func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk PrimaryKey, value Value, oldValue *Value) error {
func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk PrimaryKey, value Value) error {
for _, index := range m.Indexes.IndexesList() {
err := index.Reference(ctx, pk, value, oldValue)
err := index.Reference(ctx, pk, value, cachedGet[PrimaryKey, Value](m, ctx, pk))
if err != nil {
return err
}
@ -144,12 +120,34 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk Primary
return nil
}
func (m *IndexedMap[PrimaryKey, Value, Idx]) unref(ctx context.Context, pk PrimaryKey, value Value) error {
func (m *IndexedMap[PrimaryKey, Value, Idx]) unref(ctx context.Context, pk PrimaryKey) error {
for _, index := range m.Indexes.IndexesList() {
err := index.Unreference(ctx, pk, value)
err := index.Unreference(ctx, pk, cachedGet[PrimaryKey, Value](m, ctx, pk))
if err != nil {
return err
}
}
return nil
}
// cachedGet returns a function that gets the value V, given the key K but
// returns always the same result on multiple calls.
func cachedGet[K, V any, M interface {
Get(ctx context.Context, key K) (V, error)
}](m M, ctx context.Context, key K,
) func() (V, error) {
var (
value V
err error
calledOnce bool
)
return func() (V, error) {
if calledOnce {
return value, err
}
value, err = m.Get(ctx, key)
calledOnce = true
return value, err
}
}

View File

@ -5,6 +5,7 @@ import (
"cosmossdk.io/collections"
"cosmossdk.io/collections/colltest"
"cosmossdk.io/collections/indexes"
"github.com/stretchr/testify/require"
)
@ -17,11 +18,11 @@ type companyIndexes struct {
// City is an index of the company indexed map. It indexes a company
// given its city. The index is multi, meaning that there can be multiple
// companies from the same city.
City *collections.GenericMultiIndex[string, string, string, company]
City *indexes.Multi[string, string, company]
// Vat is an index of the company indexed map. It indexes a company
// given its VAT number. The index is unique, meaning that there can be
// only one VAT number for a company.
Vat *collections.GenericUniqueIndex[uint64, string, string, company]
Vat *indexes.Unique[uint64, string, company]
}
func (c companyIndexes) IndexesList() []collections.Index[string, company] {
@ -31,11 +32,11 @@ func (c companyIndexes) IndexesList() []collections.Index[string, company] {
func newTestIndexedMap(schema *collections.SchemaBuilder) *collections.IndexedMap[string, company, companyIndexes] {
return collections.NewIndexedMap(schema, collections.NewPrefix(0), "companies", collections.StringKey, colltest.MockValueCodec[company](),
companyIndexes{
City: collections.NewGenericMultiIndex(schema, collections.NewPrefix(1), "companies_by_city", collections.StringKey, collections.StringKey, func(pk string, value company) ([]collections.IndexReference[string, string], error) {
return []collections.IndexReference[string, string]{collections.NewIndexReference(value.City, pk)}, nil
City: indexes.NewMulti(schema, collections.NewPrefix(1), "companies_by_city", collections.StringKey, collections.StringKey, func(pk string, value company) (string, error) {
return value.City, nil
}),
Vat: collections.NewGenericUniqueIndex(schema, collections.NewPrefix(2), "companies_by_vat", collections.Uint64Key, collections.StringKey, func(pk string, v company) ([]collections.IndexReference[uint64, string], error) {
return []collections.IndexReference[uint64, string]{collections.NewIndexReference(v.Vat, pk)}, nil
Vat: indexes.NewUnique(schema, collections.NewPrefix(2), "companies_by_vat", collections.Uint64Key, collections.StringKey, func(pk string, value company) (uint64, error) {
return value.Vat, nil
}),
},
)
@ -66,7 +67,7 @@ func TestIndexedMap(t *testing.T) {
})
require.NoError(t, err)
pk, err := im.Indexes.Vat.Get(ctx, 1)
pk, err := im.Indexes.Vat.MatchExact(ctx, 1)
require.NoError(t, err)
require.Equal(t, "2", pk)
@ -77,17 +78,17 @@ func TestIndexedMap(t *testing.T) {
})
require.NoError(t, err)
pk, err = im.Indexes.Vat.Get(ctx, 2)
pk, err = im.Indexes.Vat.MatchExact(ctx, 2)
require.NoError(t, err)
require.Equal(t, "2", pk)
_, err = im.Indexes.Vat.Get(ctx, 1)
_, err = im.Indexes.Vat.MatchExact(ctx, 1)
require.ErrorIs(t, err, collections.ErrNotFound)
// test removal
err = im.Remove(ctx, "2")
require.NoError(t, err)
_, err = im.Indexes.Vat.Get(ctx, 2)
_, err = im.Indexes.Vat.MatchExact(ctx, 2)
require.ErrorIs(t, err, collections.ErrNotFound)
// test iteration

View File

@ -8,7 +8,7 @@ import (
)
func TestHelpers(t *testing.T) {
// uses MultiPair scenario.
// uses ReversePair scenario.
// We store balances as:
// Key: Pair[Address=string, Denom=string] => Value: Amount=uint64
@ -22,7 +22,7 @@ func TestHelpers(t *testing.T) {
keyCodec,
collections.Uint64Value,
balanceIndex{
Denom: NewMultiPair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec),
Denom: NewReversePair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec),
},
)

View File

@ -2,6 +2,7 @@ package indexes
import (
"context"
"errors"
"cosmossdk.io/collections"
"cosmossdk.io/collections/codec"
@ -10,7 +11,10 @@ import (
// Multi defines the most common index. It can be used to create a reference between
// a field of value and its primary key. Multiple primary keys can be mapped to the same
// reference key as the index does not enforce uniqueness constraints.
type Multi[ReferenceKey, PrimaryKey, Value any] collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value]
type Multi[ReferenceKey, PrimaryKey, Value any] struct {
getRefKey func(pk PrimaryKey, value Value) (ReferenceKey, error)
refKeys collections.KeySet[collections.Pair[ReferenceKey, PrimaryKey]]
}
// NewMulti instantiates a new Multi instance given a schema,
// a Prefix, the humanized name for the index, the reference key key codec
@ -24,32 +28,54 @@ func NewMulti[ReferenceKey, PrimaryKey, Value any](
pkCodec codec.KeyCodec[PrimaryKey],
getRefKeyFunc func(pk PrimaryKey, value Value) (ReferenceKey, error),
) *Multi[ReferenceKey, PrimaryKey, Value] {
i := collections.NewGenericMultiIndex(
schema, prefix, name, refCodec, pkCodec,
func(pk PrimaryKey, value Value) ([]collections.IndexReference[ReferenceKey, PrimaryKey], error) {
ref, err := getRefKeyFunc(pk, value)
if err != nil {
return nil, err
}
return []collections.IndexReference[ReferenceKey, PrimaryKey]{
collections.NewIndexReference(ref, pk),
}, nil
},
)
return (*Multi[ReferenceKey, PrimaryKey, Value])(i)
return &Multi[ReferenceKey, PrimaryKey, Value]{
getRefKey: getRefKeyFunc,
refKeys: collections.NewKeySet(schema, prefix, name, collections.PairKeyCodec(refCodec, pkCodec)),
}
}
func (m *Multi[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error {
return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Reference(ctx, pk, newValue, oldValue)
func (m *Multi[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, lazyOldValue func() (Value, error)) error {
oldValue, err := lazyOldValue()
switch {
// if no error it means the value existed, and we need to remove the old indexes
case err == nil:
err = m.unreference(ctx, pk, oldValue)
if err != nil {
return err
}
// if error is ErrNotFound, it means that the object does not exist, so we're creating indexes for the first time.
// we do nothing.
case errors.Is(err, collections.ErrNotFound):
// default case means that there was some other error
default:
return err
}
// create new indexes
refKey, err := m.getRefKey(pk, newValue)
if err != nil {
return err
}
return m.refKeys.Set(ctx, collections.Join(refKey, pk))
}
func (m *Multi[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, value Value) error {
return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Unreference(ctx, pk, value)
func (m *Multi[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, getValue func() (Value, error)) error {
value, err := getValue()
if err != nil {
return err
}
return m.unreference(ctx, pk, value)
}
func (m *Multi[ReferenceKey, PrimaryKey, Value]) unreference(ctx context.Context, pk PrimaryKey, value Value) error {
refKey, err := m.getRefKey(pk, value)
if err != nil {
return err
}
return m.refKeys.Remove(ctx, collections.Join(refKey, pk))
}
func (m *Multi[ReferenceKey, PrimaryKey, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[ReferenceKey, PrimaryKey]]) (MultiIterator[ReferenceKey, PrimaryKey], error) {
iter, err := (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Iterate(ctx, ranger)
iter, err := m.refKeys.Iterate(ctx, ranger)
return (MultiIterator[ReferenceKey, PrimaryKey])(iter), err
}
@ -58,7 +84,9 @@ func (m *Multi[ReferenceKey, PrimaryKey, Value]) Walk(
ranger collections.Ranger[collections.Pair[ReferenceKey, PrimaryKey]],
walkFunc func(indexingKey ReferenceKey, indexedKey PrimaryKey) bool,
) error {
return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Walk(ctx, ranger, walkFunc)
return m.refKeys.Walk(ctx, ranger, func(key collections.Pair[ReferenceKey, PrimaryKey]) bool {
return walkFunc(key.K1(), key.K2())
})
}
// MatchExact returns a MultiIterator containing all the primary keys referenced by the provided reference key.
@ -66,8 +94,8 @@ func (m *Multi[ReferenceKey, PrimaryKey, Value]) MatchExact(ctx context.Context,
return m.Iterate(ctx, collections.NewPrefixedPairRange[ReferenceKey, PrimaryKey](refKey))
}
func (i *MultiPair[K1, K2, Value]) KeyCodec() codec.KeyCodec[collections.Pair[K2, K1]] {
return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).KeyCodec()
func (m *Multi[K1, K2, Value]) KeyCodec() codec.KeyCodec[collections.Pair[K1, K2]] {
return m.refKeys.KeyCodec()
}
// MultiIterator is just a KeySetIterator with key as Pair[ReferenceKey, PrimaryKey].

View File

@ -1,132 +0,0 @@
package indexes
import (
"context"
"cosmossdk.io/collections"
"cosmossdk.io/collections/codec"
)
// MultiPair is an index that is used with collections.Pair keys. It indexes objects by their second part of the key.
// When the value is being indexed by collections.IndexedMap then MultiPair will create a relationship between
// the second part of the primary key and the first part.
type MultiPair[K1, K2, Value any] collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value]
// TODO(tip): this is an interface to cast a collections.KeyCodec
// to a pair codec. currently we return it as a KeyCodec[Pair[K1, K2]]
// to improve dev experience with type inference, which means we cannot
// get the concrete implementation which exposes KeyCodec1 and KeyCodec2.
type pairKeyCodec[K1, K2 any] interface {
KeyCodec1() codec.KeyCodec[K1]
KeyCodec2() codec.KeyCodec[K2]
}
// NewMultiPair instantiates a new MultiPair index.
// NOTE: when using this function you will need to type hint: doing NewMultiPair[Value]()
// Example: if the value of the indexed map is string, you need to do NewMultiPair[string](...)
func NewMultiPair[Value, K1, K2 any](
sb *collections.SchemaBuilder,
prefix collections.Prefix,
name string,
pairCodec codec.KeyCodec[collections.Pair[K1, K2]],
) *MultiPair[K1, K2, Value] {
pkc := pairCodec.(pairKeyCodec[K1, K2])
mi := collections.NewGenericMultiIndex(
sb,
prefix,
name,
pkc.KeyCodec2(),
pkc.KeyCodec1(),
func(pk collections.Pair[K1, K2], _ Value) ([]collections.IndexReference[K2, K1], error) {
return []collections.IndexReference[K2, K1]{
collections.NewIndexReference(pk.K2(), pk.K1()),
}, nil
},
)
return (*MultiPair[K1, K2, Value])(mi)
}
// Iterate exposes the raw iterator API.
func (i *MultiPair[K1, K2, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[K2, K1]]) (iter MultiPairIterator[K2, K1], err error) {
sIter, err := (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Iterate(ctx, ranger)
if err != nil {
return iter, err
}
return (MultiPairIterator[K2, K1])(sIter), nil
}
// MatchExact will return an iterator containing only the primary keys starting with the provided second part of the multipart pair key.
func (i *MultiPair[K1, K2, Value]) MatchExact(ctx context.Context, key K2) (MultiPairIterator[K2, K1], error) {
return i.Iterate(ctx, collections.NewPrefixedPairRange[K2, K1](key))
}
// Reference implements collections.Index
func (i *MultiPair[K1, K2, Value]) Reference(ctx context.Context, pk collections.Pair[K1, K2], value Value, oldValue *Value) error {
return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Reference(ctx, pk, value, oldValue)
}
// Unreference implements collections.Index
func (i *MultiPair[K1, K2, Value]) Unreference(ctx context.Context, pk collections.Pair[K1, K2], value Value) error {
return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Unreference(ctx, pk, value)
}
func (i *MultiPair[K1, K2, Value]) Walk(
ctx context.Context,
ranger collections.Ranger[collections.Pair[K2, K1]],
walkFunc func(indexingKey K2, indexedKey K1) bool,
) error {
return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Walk(ctx, ranger, walkFunc)
}
func (i *MultiPair[K1, K2, Value]) IterateRaw(
ctx context.Context, start, end []byte, order collections.Order,
) (
iter collections.Iterator[collections.Pair[K2, K1], collections.NoValue], err error,
) {
return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).IterateRaw(ctx, start, end, order)
}
// MultiPairIterator is a helper type around a collections.KeySetIterator when used to work
// with MultiPair indexes iterations.
type MultiPairIterator[K2, K1 any] collections.KeySetIterator[collections.Pair[K2, K1]]
// PrimaryKey returns the primary key from the index. The index is composed like a reverse
// pair key. So we just fetch the pair key from the index and return the reverse.
func (m MultiPairIterator[K2, K1]) PrimaryKey() (pair collections.Pair[K1, K2], err error) {
reversePair, err := m.FullKey()
if err != nil {
return pair, err
}
pair = collections.Join(reversePair.K2(), reversePair.K1())
return pair, nil
}
// PrimaryKeys returns all the primary keys contained in the iterator.
func (m MultiPairIterator[K2, K1]) PrimaryKeys() (pairs []collections.Pair[K1, K2], err error) {
defer m.Close()
for ; m.Valid(); m.Next() {
pair, err := m.PrimaryKey()
if err != nil {
return nil, err
}
pairs = append(pairs, pair)
}
return pairs, err
}
func (m MultiPairIterator[K2, K1]) FullKey() (p collections.Pair[K2, K1], err error) {
return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Key()
}
func (m MultiPairIterator[K2, K1]) Next() {
(collections.KeySetIterator[collections.Pair[K2, K1]])(m).Next()
}
func (m MultiPairIterator[K2, K1]) Valid() bool {
return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Valid()
}
func (m MultiPairIterator[K2, K1]) Close() error {
return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Close()
}

View File

@ -16,8 +16,8 @@ func TestMultiIndex(t *testing.T) {
})
// we crete two reference keys for primary key 1 and 2 associated with "milan"
require.NoError(t, mi.Reference(ctx, 1, company{City: "milan"}, nil))
require.NoError(t, mi.Reference(ctx, 2, company{City: "milan"}, nil))
require.NoError(t, mi.Reference(ctx, 1, company{City: "milan"}, func() (company, error) { return company{}, collections.ErrNotFound }))
require.NoError(t, mi.Reference(ctx, 2, company{City: "milan"}, func() (company, error) { return company{}, collections.ErrNotFound }))
iter, err := mi.MatchExact(ctx, "milan")
require.NoError(t, err)
@ -26,7 +26,7 @@ func TestMultiIndex(t *testing.T) {
require.Equal(t, []uint64{1, 2}, pks)
// replace
require.NoError(t, mi.Reference(ctx, 1, company{City: "new york"}, &company{City: "milan"}))
require.NoError(t, mi.Reference(ctx, 1, company{City: "new york"}, func() (company, error) { return company{City: "milan"}, nil }))
// assert after replace only company with id 2 is referenced by milan
iter, err = mi.MatchExact(ctx, "milan")

View File

@ -0,0 +1,131 @@
package indexes
import (
"context"
"cosmossdk.io/collections"
"cosmossdk.io/collections/codec"
)
// ReversePair is an index that is used with collections.Pair keys. It indexes objects by their second part of the key.
// When the value is being indexed by collections.IndexedMap then ReversePair will create a relationship between
// the second part of the primary key and the first part.
type ReversePair[K1, K2, Value any] struct {
refKeys collections.KeySet[collections.Pair[K2, K1]] // refKeys has the relationships between Join(K2, K1)
}
// TODO(tip): this is an interface to cast a collections.KeyCodec
// to a pair codec. currently we return it as a KeyCodec[Pair[K1, K2]]
// to improve dev experience with type inference, which means we cannot
// get the concrete implementation which exposes KeyCodec1 and KeyCodec2.
type pairKeyCodec[K1, K2 any] interface {
KeyCodec1() codec.KeyCodec[K1]
KeyCodec2() codec.KeyCodec[K2]
}
// NewReversePair instantiates a new ReversePair index.
// NOTE: when using this function you will need to type hint: doing NewReversePair[Value]()
// Example: if the value of the indexed map is string, you need to do NewReversePair[string](...)
func NewReversePair[Value any, K1, K2 any](
sb *collections.SchemaBuilder,
prefix collections.Prefix,
name string,
pairCodec codec.KeyCodec[collections.Pair[K1, K2]],
) *ReversePair[K1, K2, Value] {
pkc := pairCodec.(pairKeyCodec[K1, K2])
mi := &ReversePair[K1, K2, Value]{
refKeys: collections.NewKeySet(sb, prefix, name, collections.PairKeyCodec(pkc.KeyCodec2(), pkc.KeyCodec1())),
}
return mi
}
// Iterate exposes the raw iterator API.
func (i *ReversePair[K1, K2, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[K2, K1]]) (iter ReversePairIterator[K2, K1], err error) {
sIter, err := i.refKeys.Iterate(ctx, ranger)
if err != nil {
return
}
return (ReversePairIterator[K2, K1])(sIter), nil
}
// MatchExact will return an iterator containing only the primary keys starting with the provided second part of the multipart pair key.
func (i *ReversePair[K1, K2, Value]) MatchExact(ctx context.Context, key K2) (ReversePairIterator[K2, K1], error) {
return i.Iterate(ctx, collections.NewPrefixedPairRange[K2, K1](key))
}
// Reference implements collections.Index
func (i *ReversePair[K1, K2, Value]) Reference(ctx context.Context, pk collections.Pair[K1, K2], _ Value, _ func() (Value, error)) error {
return i.refKeys.Set(ctx, collections.Join(pk.K2(), pk.K1()))
}
// Unreference implements collections.Index
func (i *ReversePair[K1, K2, Value]) Unreference(ctx context.Context, pk collections.Pair[K1, K2], _ func() (Value, error)) error {
return i.refKeys.Remove(ctx, collections.Join(pk.K2(), pk.K1()))
}
func (i *ReversePair[K1, K2, Value]) Walk(
ctx context.Context,
ranger collections.Ranger[collections.Pair[K2, K1]],
walkFunc func(indexingKey K2, indexedKey K1) bool,
) error {
return i.refKeys.Walk(ctx, ranger, func(key collections.Pair[K2, K1]) bool {
return walkFunc(key.K1(), key.K2())
})
}
func (i *ReversePair[K1, K2, Value]) IterateRaw(
ctx context.Context, start, end []byte, order collections.Order,
) (
iter collections.Iterator[collections.Pair[K2, K1], collections.NoValue], err error,
) {
return i.refKeys.IterateRaw(ctx, start, end, order)
}
func (i *ReversePair[K1, K2, Value]) KeyCodec() codec.KeyCodec[collections.Pair[K2, K1]] {
return i.refKeys.KeyCodec()
}
// ReversePairIterator is a helper type around a collections.KeySetIterator when used to work
// with ReversePair indexes iterations.
type ReversePairIterator[K2, K1 any] collections.KeySetIterator[collections.Pair[K2, K1]]
// PrimaryKey returns the primary key from the index. The index is composed like a reverse
// pair key. So we just fetch the pair key from the index and return the reverse.
func (m ReversePairIterator[K2, K1]) PrimaryKey() (pair collections.Pair[K1, K2], err error) {
reversePair, err := m.FullKey()
if err != nil {
return pair, err
}
pair = collections.Join(reversePair.K2(), reversePair.K1())
return pair, nil
}
// PrimaryKeys returns all the primary keys contained in the iterator.
func (m ReversePairIterator[K2, K1]) PrimaryKeys() (pairs []collections.Pair[K1, K2], err error) {
defer m.Close()
for ; m.Valid(); m.Next() {
pair, err := m.PrimaryKey()
if err != nil {
return nil, err
}
pairs = append(pairs, pair)
}
return pairs, err
}
func (m ReversePairIterator[K2, K1]) FullKey() (p collections.Pair[K2, K1], err error) {
return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Key()
}
func (m ReversePairIterator[K2, K1]) Next() {
(collections.KeySetIterator[collections.Pair[K2, K1]])(m).Next()
}
func (m ReversePairIterator[K2, K1]) Valid() bool {
return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Valid()
}
func (m ReversePairIterator[K2, K1]) Close() error {
return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Close()
}

View File

@ -16,14 +16,14 @@ type (
// our balance index, allows us to efficiently create an index between the key that maps
// balances which is a collections.Pair[Address, Denom] and the Denom.
type balanceIndex struct {
Denom *MultiPair[Address, Denom, Amount]
Denom *ReversePair[Address, Denom, Amount]
}
func (b balanceIndex) IndexesList() []collections.Index[collections.Pair[Address, Denom], Amount] {
return []collections.Index[collections.Pair[Address, Denom], Amount]{b.Denom}
}
func TestMultiPair(t *testing.T) {
func TestReversePair(t *testing.T) {
sk, ctx := deps()
sb := collections.NewSchemaBuilder(sk)
// we create an indexed map that maps balances, which are saved as
@ -37,7 +37,7 @@ func TestMultiPair(t *testing.T) {
keyCodec,
collections.Uint64Value,
balanceIndex{
Denom: NewMultiPair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec),
Denom: NewReversePair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec),
},
)

View File

@ -2,6 +2,8 @@ package indexes
import (
"context"
"errors"
"fmt"
"cosmossdk.io/collections"
"cosmossdk.io/collections/codec"
@ -9,7 +11,10 @@ import (
// Unique identifies an index that imposes uniqueness constraints on the reference key.
// It creates relationships between reference and primary key of the value.
type Unique[ReferenceKey, PrimaryKey, Value any] collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value]
type Unique[ReferenceKey, PrimaryKey, Value any] struct {
getRefKey func(PrimaryKey, Value) (ReferenceKey, error)
refKeys collections.Map[ReferenceKey, PrimaryKey]
}
// NewUnique instantiates a new Unique index.
func NewUnique[ReferenceKey, PrimaryKey, Value any](
@ -20,34 +25,65 @@ func NewUnique[ReferenceKey, PrimaryKey, Value any](
pkCodec codec.KeyCodec[PrimaryKey],
getRefKeyFunc func(pk PrimaryKey, v Value) (ReferenceKey, error),
) *Unique[ReferenceKey, PrimaryKey, Value] {
i := collections.NewGenericUniqueIndex(schema, prefix, name, refCodec, pkCodec, func(pk PrimaryKey, value Value) ([]collections.IndexReference[ReferenceKey, PrimaryKey], error) {
ref, err := getRefKeyFunc(pk, value)
return &Unique[ReferenceKey, PrimaryKey, Value]{
getRefKey: getRefKeyFunc,
refKeys: collections.NewMap(schema, prefix, name, refCodec, codec.KeyToValueCodec(pkCodec)),
}
}
func (i *Unique[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, lazyOldValue func() (Value, error)) error {
oldValue, err := lazyOldValue()
switch {
// if no error it means the value existed, and we need to remove the old indexes
case err == nil:
err = i.unreference(ctx, pk, oldValue)
if err != nil {
return nil, err
return err
}
return []collections.IndexReference[ReferenceKey, PrimaryKey]{
collections.NewIndexReference(ref, pk),
}, nil
})
return (*Unique[ReferenceKey, PrimaryKey, Value])(i)
// if error is ErrNotFound, it means that the object does not exist, so we're creating indexes for the first time.
// we do nothing.
case errors.Is(err, collections.ErrNotFound):
// default case means that there was some other error
default:
return err
}
// create new indexes, asserting no uniqueness constraint violation
refKey, err := i.getRefKey(pk, newValue)
if err != nil {
return err
}
has, err := i.refKeys.Has(ctx, refKey)
if err != nil {
return err
}
if has {
return fmt.Errorf("%w: index uniqueness constrain violation: %s", collections.ErrConflict, i.refKeys.KeyCodec().Stringify(refKey))
}
return i.refKeys.Set(ctx, refKey, pk)
}
func (i *Unique[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error {
return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Reference(ctx, pk, newValue, oldValue)
func (i *Unique[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, getValue func() (Value, error)) error {
value, err := getValue()
if err != nil {
return err
}
return i.unreference(ctx, pk, value)
}
func (i *Unique[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, value Value) error {
return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Unreference(ctx, pk, value)
func (i *Unique[ReferenceKey, PrimaryKey, Value]) unreference(ctx context.Context, pk PrimaryKey, value Value) error {
refKey, err := i.getRefKey(pk, value)
if err != nil {
return err
}
return i.refKeys.Remove(ctx, refKey)
}
func (i *Unique[ReferenceKey, PrimaryKey, Value]) MatchExact(ctx context.Context, ref ReferenceKey) (PrimaryKey, error) {
return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Get(ctx, ref)
return i.refKeys.Get(ctx, ref)
}
func (i *Unique[ReferenceKey, PrimaryKey, Value]) Iterate(ctx context.Context, ranger collections.Ranger[ReferenceKey]) (UniqueIterator[ReferenceKey, PrimaryKey], error) {
iter, err := (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Iterate(ctx, ranger)
iter, err := i.refKeys.Iterate(ctx, ranger)
return (UniqueIterator[ReferenceKey, PrimaryKey])(iter), err
}
@ -56,11 +92,11 @@ func (i *Unique[ReferenceKey, PrimaryKey, Value]) Walk(
ranger collections.Ranger[ReferenceKey],
walkFunc func(indexingKey ReferenceKey, indexedKey PrimaryKey) bool,
) error {
return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Walk(ctx, ranger, walkFunc)
return i.refKeys.Walk(ctx, ranger, walkFunc)
}
func (i *Unique[ReferenceKey, PrimaryKey, Value]) IterateRaw(ctx context.Context, start, end []byte, order collections.Order) (u UniqueIterator[ReferenceKey, PrimaryKey], err error) {
iter, err := (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).IterateRaw(ctx, start, end, order)
iter, err := i.refKeys.IterateRaw(ctx, start, end, order)
if err != nil {
return
}

View File

@ -15,15 +15,15 @@ func TestUniqueIndex(t *testing.T) {
})
// map company with id 1 to vat 1_1
err := ui.Reference(ctx, 1, company{Vat: 1_1}, nil)
err := ui.Reference(ctx, 1, company{Vat: 1_1}, func() (company, error) { return company{}, collections.ErrNotFound })
require.NoError(t, err)
// map company with id 2 to vat 2_2
err = ui.Reference(ctx, 2, company{Vat: 2_2}, nil)
err = ui.Reference(ctx, 2, company{Vat: 2_2}, func() (company, error) { return company{}, collections.ErrNotFound })
require.NoError(t, err)
// mapping company 3 with vat 1_1 must yield to a ErrConflict
err = ui.Reference(ctx, 1, company{Vat: 1_1}, nil)
err = ui.Reference(ctx, 1, company{Vat: 1_1}, func() (company, error) { return company{}, collections.ErrNotFound })
require.ErrorIs(t, err, collections.ErrConflict)
// assert references are correct
@ -36,7 +36,7 @@ func TestUniqueIndex(t *testing.T) {
require.Equal(t, uint64(2), id)
// on reference updates, the new referencing key is created and the old is removed
err = ui.Reference(ctx, 1, company{Vat: 1_2}, &company{Vat: 1_1})
err = ui.Reference(ctx, 1, company{Vat: 1_2}, func() (company, error) { return company{Vat: 1_1}, nil })
require.NoError(t, err)
id, err = ui.MatchExact(ctx, 1_2) // assert a new reference is created
require.NoError(t, err)

View File

@ -1,157 +0,0 @@
package collections
import (
"context"
"cosmossdk.io/collections/codec"
)
func NewIndexReference[ReferencingKey, ReferencedKey any](referencing ReferencingKey, referenced ReferencedKey) IndexReference[ReferencingKey, ReferencedKey] {
return IndexReference[ReferencingKey, ReferencedKey]{
Referring: referencing,
Referred: referenced,
}
}
// IndexReference defines a generic index reference.
type IndexReference[ReferencingKey, ReferencedKey any] struct {
// Referring is the key that refers, points to the Referred key.
Referring ReferencingKey
// Referred is the key that is being pointed to by the Referring key.
Referred ReferencedKey
}
// GenericMultiIndex defines a generic Index type that given a primary key
// and the value associated with that primary key returns one or multiple IndexReference.
//
// The referencing key can be anything, usually it is either a part of the primary
// key when we deal with multipart keys, or a field of Value.
//
// The referenced key usually is the primary key, or it can be a part
// of the primary key in the context of multipart keys.
//
// The Referencing and Referenced keys are joined and saved as a Pair in a KeySet
// where the key is Pair[ReferencingKey, ReferencedKey].
// So if we wanted to get all the keys referenced by a generic (concrete) ReferencingKey
// we would just need to iterate over all the keys starting with bytes(ReferencingKey).
//
// Unless you're trying to build your generic multi index, you should be using the indexes package.
type GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct {
refs KeySet[Pair[ReferencingKey, ReferencedKey]]
getRefs func(pk PrimaryKey, v Value) ([]IndexReference[ReferencingKey, ReferencedKey], error)
}
// NewGenericMultiIndex instantiates a GenericMultiIndex, given
// schema, Prefix, humanized name, the key codec used to encode the referencing key
// to bytes, the key codec used to encode the referenced key to bytes and a function
// which given the primary key and a value of an object being saved or removed in IndexedMap
// returns all the possible IndexReference of that object.
//
// The IndexReference is usually just one. But in certain cases can be multiple,
// for example when the Value has an array field, and we want to create a relationship
// between the object and all the elements of the array contained in the object.
func NewGenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any](
schema *SchemaBuilder,
prefix Prefix,
name string,
referencingKeyCodec codec.KeyCodec[ReferencingKey],
referencedKeyCodec codec.KeyCodec[ReferencedKey],
getRefsFunc func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error),
) *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] {
return &GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{
getRefs: getRefsFunc,
refs: NewKeySet(schema, prefix, name, PairKeyCodec(referencingKeyCodec, referencedKeyCodec)),
}
}
// Iterate allows to iterate over the index. It returns a KeySetIterator of Pair[ReferencingKey, ReferencedKey].
// K1 of the Pair is the key (referencing) pointing to K2 (referenced).
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate(
ctx context.Context,
ranger Ranger[Pair[ReferencingKey, ReferencedKey]],
) (KeySetIterator[Pair[ReferencingKey, ReferencedKey]], error) {
return i.refs.Iterate(ctx, ranger)
}
// Has reports if there is a relationship in the index between the referencing and the referenced key.
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Has(
ctx context.Context,
referencing ReferencingKey,
referenced ReferencedKey,
) (bool, error) {
return i.refs.Has(ctx, Join(referencing, referenced))
}
// Reference implements the Index interface.
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference(
ctx context.Context,
pk PrimaryKey,
value Value,
oldValue *Value,
) error {
if oldValue != nil {
err := i.Unreference(ctx, pk, *oldValue)
if err != nil {
return err
}
}
refKeys, err := i.getRefs(pk, value)
if err != nil {
return err
}
for _, ref := range refKeys {
err := i.refs.Set(ctx, Join(ref.Referring, ref.Referred))
if err != nil {
return err
}
}
return nil
}
// Unreference implements the Index interface.
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference(
ctx context.Context,
pk PrimaryKey,
value Value,
) error {
refs, err := i.getRefs(pk, value)
if err != nil {
return err
}
for _, ref := range refs {
err = i.refs.Remove(ctx, Join(ref.Referring, ref.Referred))
if err != nil {
return err
}
}
return nil
}
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) IterateRaw(
ctx context.Context,
start, end []byte,
order Order,
) (Iterator[Pair[ReferencingKey, ReferencedKey], NoValue], error) {
return i.refs.IterateRaw(ctx, start, end, order)
}
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Walk(
ctx context.Context,
ranger Ranger[Pair[ReferencingKey, ReferencedKey]],
walkFunc func(referencingKey ReferencingKey, referencedKey ReferencedKey) bool,
) error {
return i.refs.Walk(ctx, ranger, func(key Pair[ReferencingKey, ReferencedKey]) bool { return walkFunc(key.K1(), key.K2()) })
}
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) KeyCodec() codec.KeyCodec[Pair[ReferencingKey, ReferencedKey]] {
return i.refs.KeyCodec()
}
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) ValueCodec() codec.ValueCodec[NoValue] {
return i.refs.ValueCodec()
}

View File

@ -1,83 +0,0 @@
package collections
import (
"testing"
"github.com/stretchr/testify/require"
)
type coin struct {
denom string // this will be used as indexing field.
amount uint64
}
type balance struct {
coins []coin
}
func TestGenericMultiIndex(t *testing.T) {
// we are simulating a context in which we have the following mapping:
//
// address (represented as string) => balance (slice of coins).
//
// we want to create an index that creates a relationship between the coin
// denom, which is part of the balance structure, and the address. This means
// we know given a denom who are the addresses holding that denom.
// From GenericMultiIndex point of view, the denom field of the array becomes
// the referencing key which points to the address (string), which is the key
// being referenced.
sk, ctx := deps()
sb := NewSchemaBuilder(sk)
mi := NewGenericMultiIndex(
sb, NewPrefix("denoms"), "denom_to_owner", StringKey, StringKey,
func(pk string, value balance) ([]IndexReference[string, string], error) {
// the referencing keys are all the denoms.
refs := make([]IndexReference[string, string], len(value.coins))
// the index reference being created, generates a relationship
// between denom (the key that references) and pk (address, the key
// that is being referenced).
for i, coin := range value.coins {
refs[i] = NewIndexReference(coin.denom, pk)
}
return refs, nil
},
)
// let's create the relationships
err := mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{
{"atom", 1000}, {"osmo", 5000},
}}, nil)
require.NoError(t, err)
// we must find relations between cosmosaddr1 and the denom atom and osmo
iter, err := mi.Iterate(ctx, nil)
require.NoError(t, err)
keys, err := iter.Keys()
require.NoError(t, err)
require.Len(t, keys, 2)
require.Equal(t, keys[0].K1(), "atom") // assert relationship with atom created
require.Equal(t, keys[1].K1(), "osmo") // assert relationship with osmo created
// if we update the reference to remove osmo as balance then we must not find it anymore
err = mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}}, // this is the update which does not have osmo
&balance{coins: []coin{{"atom", 1000}, {"osmo", 5000}}}, // this is the previous record
)
require.NoError(t, err)
exists, err := mi.Has(ctx, "osmo", "cosmosAddr1") // osmo must not exist anymore
require.NoError(t, err)
require.False(t, exists)
exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom still exists
require.NoError(t, err)
require.True(t, exists)
// if we unreference then no relationship is maintained anymore
err = mi.Unreference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}})
require.NoError(t, err)
exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom is not part of the index anymore because cosmosAddr1 was removed.
require.NoError(t, err)
require.False(t, exists)
}

View File

@ -1,122 +0,0 @@
package collections
import (
"context"
"fmt"
"cosmossdk.io/collections/codec"
)
// GenericUniqueIndex defines a generic index which enforces uniqueness constraints
// between ReferencingKey and ReferencedKey, meaning that one referencing key maps
// only one referenced key. The same referenced key can be mapped by multiple referencing keys.
//
// The referencing key can be anything, usually it is either a part of the primary
// key when we deal with multipart keys, or a field of Value.
//
// The referenced key usually is the primary key, or it can be a part
// of the primary key in the context of multipart keys.
//
// The referencing and referenced keys are mapped together using a Map.
//
// Unless you're trying to build your generic unique index, you should be using the indexes package.
type GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct {
refs Map[ReferencingKey, ReferencedKey]
getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error)
}
// NewGenericUniqueIndex instantiates a GenericUniqueIndex. Works in the same way as NewGenericMultiIndex.
func NewGenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any](
schema *SchemaBuilder,
prefix Prefix,
name string,
referencingKeyCodec codec.KeyCodec[ReferencingKey],
referencedKeyCodec codec.KeyCodec[ReferencedKey],
getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error),
) *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] {
return &GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{
refs: NewMap[ReferencingKey, ReferencedKey](schema, prefix, name, referencingKeyCodec, codec.KeyToValueCodec(referencedKeyCodec)),
getRefs: getRefs,
}
}
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate(
ctx context.Context,
ranger Ranger[ReferencingKey],
) (Iterator[ReferencingKey, ReferencedKey], error) {
return i.refs.Iterate(ctx, ranger)
}
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Get(ctx context.Context, ref ReferencingKey) (ReferencedKey, error) {
return i.refs.Get(ctx, ref)
}
// Reference implements Index.
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference(
ctx context.Context,
pk PrimaryKey,
newValue Value,
oldValue *Value,
) error {
if oldValue != nil {
err := i.Unreference(ctx, pk, *oldValue)
if err != nil {
return err
}
}
refs, err := i.getRefs(pk, newValue)
if err != nil {
return err
}
for _, ref := range refs {
has, err := i.refs.Has(ctx, ref.Referring)
if err != nil {
return err
}
if has {
return fmt.Errorf("%w: index uniqueness constrain violation: %s", ErrConflict, i.refs.kc.Stringify(ref.Referring))
}
err = i.refs.Set(ctx, ref.Referring, ref.Referred)
if err != nil {
return err
}
}
return nil
}
// Unreference implements Index.
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference(
ctx context.Context,
pk PrimaryKey,
value Value,
) error {
refs, err := i.getRefs(pk, value)
if err != nil {
return err
}
for _, ref := range refs {
err = i.refs.Remove(ctx, ref.Referring)
if err != nil {
return err
}
}
return nil
}
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) IterateRaw(
ctx context.Context,
start, end []byte,
order Order,
) (Iterator[ReferencingKey, ReferencedKey], error) {
return i.refs.IterateRaw(ctx, start, end, order)
}
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Walk(
ctx context.Context,
ranger Ranger[ReferencingKey],
walkFunc func(referencingKey ReferencingKey, referencedKey ReferencedKey) bool,
) error {
return i.refs.Walk(ctx, ranger, func(k ReferencingKey, v ReferencedKey) bool { return walkFunc(k, v) })
}

View File

@ -1,72 +0,0 @@
package collections
import (
"testing"
"github.com/stretchr/testify/require"
)
type nftBalance struct {
nftIDs []uint64
}
func TestGenericUniqueIndex(t *testing.T) {
// we create the same testing context as with GenericMultiIndex. We have a mapping:
// Address => NFT balance.
// An NFT balance is represented as a slice of IDs, those IDs are unique, meaning that
// they can be held only by one address.
sk, ctx := deps()
sb := NewSchemaBuilder(sk)
ui := NewGenericUniqueIndex(
sb, NewPrefix("nft_to_owner_index"), "ntf_to_owner_index", Uint64Key, StringKey,
func(pk string, value nftBalance) ([]IndexReference[uint64, string], error) {
// the referencing keys are all the NFT unique ids.
refs := make([]IndexReference[uint64, string], len(value.nftIDs))
// for each NFT contained in the balance we create an index reference
// between the NFT unique ID and the owner of the balance.
for i, id := range value.nftIDs {
refs[i] = NewIndexReference(id, pk)
}
return refs, nil
},
)
// let's create the relationships
err := ui.Reference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{0, 1}}, nil)
require.NoError(t, err)
// assert relations were created
iter, err := ui.Iterate(ctx, nil)
require.NoError(t, err)
defer iter.Close()
kv, err := iter.KeyValues()
require.NoError(t, err)
require.Len(t, kv, 2)
require.Equal(t, kv[0].Key, uint64(0))
require.Equal(t, kv[0].Value, "cosmosAddr1")
require.Equal(t, kv[1].Key, uint64(1))
require.Equal(t, kv[1].Value, "cosmosAddr1")
// assert only one address can own a unique NFT
err = ui.Reference(ctx, "cosmosAddr2", nftBalance{nftIDs: []uint64{0}}, nil) // nft with ID 0 is already owned by cosmosAddr1
require.ErrorIs(t, err, ErrConflict)
// during modifications references are updated, we update the index in
// such a way that cosmosAddr1 loses ownership of nft with id 0.
err = ui.Reference(ctx, "cosmosAddr1",
nftBalance{nftIDs: []uint64{1}}, // this is the update nft balance, which contains only id 1
&nftBalance{nftIDs: []uint64{0, 1}}, // this is the old nft balance, which contains both 0 and 1
)
require.NoError(t, err)
// the updated balance does not contain nft with id 0
_, err = ui.Get(ctx, 0)
require.ErrorIs(t, err, ErrNotFound)
// unreferencing clears all the indexes
err = ui.Unreference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{1}})
require.NoError(t, err)
_, err = ui.Get(ctx, 1)
require.ErrorIs(t, err, ErrNotFound)
}