From 182f3fd5512922f84962d5e6493875eca9835a58 Mon Sep 17 00:00:00 2001 From: testinginprod <98415576+testinginprod@users.noreply.github.com> Date: Thu, 30 Mar 2023 22:11:16 +0200 Subject: [PATCH] refactor(collections): Indexes perf improvements and simplification (#15552) Co-authored-by: testinginprod --- collections/indexed_map.go | 74 ++++----- collections/indexed_map_test.go | 21 +-- collections/indexes/helpers_test.go | 4 +- collections/indexes/multi.go | 74 ++++++--- collections/indexes/multi_pair.go | 132 --------------- collections/indexes/multi_test.go | 6 +- collections/indexes/reverse_pair.go | 131 +++++++++++++++ ...ulti_pair_test.go => reverse_pair_test.go} | 6 +- collections/indexes/unique.go | 74 ++++++--- collections/indexes/unique_test.go | 8 +- collections/indexes_generic_multi.go | 157 ------------------ collections/indexes_generic_multi_test.go | 83 --------- collections/indexes_generic_unique.go | 122 -------------- collections/indexes_generic_unique_test.go | 72 -------- 14 files changed, 296 insertions(+), 668 deletions(-) delete mode 100644 collections/indexes/multi_pair.go create mode 100644 collections/indexes/reverse_pair.go rename collections/indexes/{multi_pair_test.go => reverse_pair_test.go} (89%) delete mode 100644 collections/indexes_generic_multi.go delete mode 100644 collections/indexes_generic_multi_test.go delete mode 100644 collections/indexes_generic_unique.go delete mode 100644 collections/indexes_generic_unique_test.go diff --git a/collections/indexed_map.go b/collections/indexed_map.go index 3736e824cc..c761658ea4 100644 --- a/collections/indexed_map.go +++ b/collections/indexed_map.go @@ -2,8 +2,6 @@ package collections import ( "context" - "errors" - "fmt" "cosmossdk.io/collections/codec" ) @@ -21,12 +19,12 @@ type Indexes[PrimaryKey, Value any] interface { // Index represents an index of the Value indexed using the type PrimaryKey. type Index[PrimaryKey, Value any] interface { // Reference creates a reference between the provided primary key and value. - // If oldValue is not nil then the Index must update the references - // of the primary key associated with the new value and remove the - // old invalid references. - Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error + // It provides a lazyOldValue function that if called will attempt to fetch + // the previous old value, returns ErrNotFound if no value existed. + Reference(ctx context.Context, pk PrimaryKey, newValue Value, lazyOldValue func() (Value, error)) error // Unreference removes the reference between the primary key and value. - Unreference(ctx context.Context, pk PrimaryKey, value Value) error + // If error is ErrNotFound then it means that the value did not exist before. + Unreference(ctx context.Context, pk PrimaryKey, lazyOldValue func() (Value, error)) error } // IndexedMap works like a Map but creates references between fields of Value and its PrimaryKey. @@ -40,10 +38,10 @@ type IndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]] struct { } // NewIndexedMap instantiates a new IndexedMap. Accepts a SchemaBuilder, a Prefix, -// a humanized name that defines the name of the collection, the primary key codec +// a humanised name that defines the name of the collection, the primary key codec // which is basically what IndexedMap uses to encode the primary key to bytes, // the value codec which is what the IndexedMap uses to encode the value. -// Then it expects the initialized indexes. +// Then it expects the initialised indexes. func NewIndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]]( schema *SchemaBuilder, prefix Prefix, @@ -76,26 +74,10 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) Has(ctx context.Context, pk Primary // Set maps the value using the primary key. It will also iterate every index and instruct them to // add or update the indexes. func (m *IndexedMap[PrimaryKey, Value, Idx]) Set(ctx context.Context, pk PrimaryKey, value Value) error { - // we need to see if there was a previous instance of the value - oldValue, err := m.m.Get(ctx, pk) - switch { - // update indexes - case err == nil: - err = m.ref(ctx, pk, value, &oldValue) - if err != nil { - return fmt.Errorf("collections: indexing error: %w", err) - } - // create new indexes - case errors.Is(err, ErrNotFound): - err = m.ref(ctx, pk, value, nil) - if err != nil { - return fmt.Errorf("collections: indexing error: %w", err) - } - // cannot move forward error - default: + err := m.ref(ctx, pk, value) + if err != nil { return err } - return m.m.Set(ctx, pk, value) } @@ -103,13 +85,7 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) Set(ctx context.Context, pk Primary // it iterates over all the indexes and instructs them to remove all the references // associated with the removed value. func (m *IndexedMap[PrimaryKey, Value, Idx]) Remove(ctx context.Context, pk PrimaryKey) error { - oldValue, err := m.m.Get(ctx, pk) - if err != nil { - // TODO retain Map behavior? which does not error in case we remove a non-existing object - return err - } - - err = m.unref(ctx, pk, oldValue) + err := m.unref(ctx, pk) if err != nil { return err } @@ -134,9 +110,9 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) ValueCodec() codec.ValueCodec[Value return m.m.ValueCodec() } -func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk PrimaryKey, value Value, oldValue *Value) error { +func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk PrimaryKey, value Value) error { for _, index := range m.Indexes.IndexesList() { - err := index.Reference(ctx, pk, value, oldValue) + err := index.Reference(ctx, pk, value, cachedGet[PrimaryKey, Value](m, ctx, pk)) if err != nil { return err } @@ -144,12 +120,34 @@ func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk Primary return nil } -func (m *IndexedMap[PrimaryKey, Value, Idx]) unref(ctx context.Context, pk PrimaryKey, value Value) error { +func (m *IndexedMap[PrimaryKey, Value, Idx]) unref(ctx context.Context, pk PrimaryKey) error { for _, index := range m.Indexes.IndexesList() { - err := index.Unreference(ctx, pk, value) + err := index.Unreference(ctx, pk, cachedGet[PrimaryKey, Value](m, ctx, pk)) if err != nil { return err } } return nil } + +// cachedGet returns a function that gets the value V, given the key K but +// returns always the same result on multiple calls. +func cachedGet[K, V any, M interface { + Get(ctx context.Context, key K) (V, error) +}](m M, ctx context.Context, key K, +) func() (V, error) { + var ( + value V + err error + calledOnce bool + ) + + return func() (V, error) { + if calledOnce { + return value, err + } + value, err = m.Get(ctx, key) + calledOnce = true + return value, err + } +} diff --git a/collections/indexed_map_test.go b/collections/indexed_map_test.go index 707537883b..9af4dd74b7 100644 --- a/collections/indexed_map_test.go +++ b/collections/indexed_map_test.go @@ -5,6 +5,7 @@ import ( "cosmossdk.io/collections" "cosmossdk.io/collections/colltest" + "cosmossdk.io/collections/indexes" "github.com/stretchr/testify/require" ) @@ -17,11 +18,11 @@ type companyIndexes struct { // City is an index of the company indexed map. It indexes a company // given its city. The index is multi, meaning that there can be multiple // companies from the same city. - City *collections.GenericMultiIndex[string, string, string, company] + City *indexes.Multi[string, string, company] // Vat is an index of the company indexed map. It indexes a company // given its VAT number. The index is unique, meaning that there can be // only one VAT number for a company. - Vat *collections.GenericUniqueIndex[uint64, string, string, company] + Vat *indexes.Unique[uint64, string, company] } func (c companyIndexes) IndexesList() []collections.Index[string, company] { @@ -31,11 +32,11 @@ func (c companyIndexes) IndexesList() []collections.Index[string, company] { func newTestIndexedMap(schema *collections.SchemaBuilder) *collections.IndexedMap[string, company, companyIndexes] { return collections.NewIndexedMap(schema, collections.NewPrefix(0), "companies", collections.StringKey, colltest.MockValueCodec[company](), companyIndexes{ - City: collections.NewGenericMultiIndex(schema, collections.NewPrefix(1), "companies_by_city", collections.StringKey, collections.StringKey, func(pk string, value company) ([]collections.IndexReference[string, string], error) { - return []collections.IndexReference[string, string]{collections.NewIndexReference(value.City, pk)}, nil + City: indexes.NewMulti(schema, collections.NewPrefix(1), "companies_by_city", collections.StringKey, collections.StringKey, func(pk string, value company) (string, error) { + return value.City, nil }), - Vat: collections.NewGenericUniqueIndex(schema, collections.NewPrefix(2), "companies_by_vat", collections.Uint64Key, collections.StringKey, func(pk string, v company) ([]collections.IndexReference[uint64, string], error) { - return []collections.IndexReference[uint64, string]{collections.NewIndexReference(v.Vat, pk)}, nil + Vat: indexes.NewUnique(schema, collections.NewPrefix(2), "companies_by_vat", collections.Uint64Key, collections.StringKey, func(pk string, value company) (uint64, error) { + return value.Vat, nil }), }, ) @@ -66,7 +67,7 @@ func TestIndexedMap(t *testing.T) { }) require.NoError(t, err) - pk, err := im.Indexes.Vat.Get(ctx, 1) + pk, err := im.Indexes.Vat.MatchExact(ctx, 1) require.NoError(t, err) require.Equal(t, "2", pk) @@ -77,17 +78,17 @@ func TestIndexedMap(t *testing.T) { }) require.NoError(t, err) - pk, err = im.Indexes.Vat.Get(ctx, 2) + pk, err = im.Indexes.Vat.MatchExact(ctx, 2) require.NoError(t, err) require.Equal(t, "2", pk) - _, err = im.Indexes.Vat.Get(ctx, 1) + _, err = im.Indexes.Vat.MatchExact(ctx, 1) require.ErrorIs(t, err, collections.ErrNotFound) // test removal err = im.Remove(ctx, "2") require.NoError(t, err) - _, err = im.Indexes.Vat.Get(ctx, 2) + _, err = im.Indexes.Vat.MatchExact(ctx, 2) require.ErrorIs(t, err, collections.ErrNotFound) // test iteration diff --git a/collections/indexes/helpers_test.go b/collections/indexes/helpers_test.go index 00052bd5e1..469aa38dba 100644 --- a/collections/indexes/helpers_test.go +++ b/collections/indexes/helpers_test.go @@ -8,7 +8,7 @@ import ( ) func TestHelpers(t *testing.T) { - // uses MultiPair scenario. + // uses ReversePair scenario. // We store balances as: // Key: Pair[Address=string, Denom=string] => Value: Amount=uint64 @@ -22,7 +22,7 @@ func TestHelpers(t *testing.T) { keyCodec, collections.Uint64Value, balanceIndex{ - Denom: NewMultiPair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec), + Denom: NewReversePair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec), }, ) diff --git a/collections/indexes/multi.go b/collections/indexes/multi.go index 8f2452dcd8..410aae00fd 100644 --- a/collections/indexes/multi.go +++ b/collections/indexes/multi.go @@ -2,6 +2,7 @@ package indexes import ( "context" + "errors" "cosmossdk.io/collections" "cosmossdk.io/collections/codec" @@ -10,7 +11,10 @@ import ( // Multi defines the most common index. It can be used to create a reference between // a field of value and its primary key. Multiple primary keys can be mapped to the same // reference key as the index does not enforce uniqueness constraints. -type Multi[ReferenceKey, PrimaryKey, Value any] collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value] +type Multi[ReferenceKey, PrimaryKey, Value any] struct { + getRefKey func(pk PrimaryKey, value Value) (ReferenceKey, error) + refKeys collections.KeySet[collections.Pair[ReferenceKey, PrimaryKey]] +} // NewMulti instantiates a new Multi instance given a schema, // a Prefix, the humanized name for the index, the reference key key codec @@ -24,32 +28,54 @@ func NewMulti[ReferenceKey, PrimaryKey, Value any]( pkCodec codec.KeyCodec[PrimaryKey], getRefKeyFunc func(pk PrimaryKey, value Value) (ReferenceKey, error), ) *Multi[ReferenceKey, PrimaryKey, Value] { - i := collections.NewGenericMultiIndex( - schema, prefix, name, refCodec, pkCodec, - func(pk PrimaryKey, value Value) ([]collections.IndexReference[ReferenceKey, PrimaryKey], error) { - ref, err := getRefKeyFunc(pk, value) - if err != nil { - return nil, err - } - return []collections.IndexReference[ReferenceKey, PrimaryKey]{ - collections.NewIndexReference(ref, pk), - }, nil - }, - ) - - return (*Multi[ReferenceKey, PrimaryKey, Value])(i) + return &Multi[ReferenceKey, PrimaryKey, Value]{ + getRefKey: getRefKeyFunc, + refKeys: collections.NewKeySet(schema, prefix, name, collections.PairKeyCodec(refCodec, pkCodec)), + } } -func (m *Multi[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error { - return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Reference(ctx, pk, newValue, oldValue) +func (m *Multi[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, lazyOldValue func() (Value, error)) error { + oldValue, err := lazyOldValue() + switch { + // if no error it means the value existed, and we need to remove the old indexes + case err == nil: + err = m.unreference(ctx, pk, oldValue) + if err != nil { + return err + } + // if error is ErrNotFound, it means that the object does not exist, so we're creating indexes for the first time. + // we do nothing. + case errors.Is(err, collections.ErrNotFound): + // default case means that there was some other error + default: + return err + } + // create new indexes + refKey, err := m.getRefKey(pk, newValue) + if err != nil { + return err + } + return m.refKeys.Set(ctx, collections.Join(refKey, pk)) } -func (m *Multi[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, value Value) error { - return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Unreference(ctx, pk, value) +func (m *Multi[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, getValue func() (Value, error)) error { + value, err := getValue() + if err != nil { + return err + } + return m.unreference(ctx, pk, value) +} + +func (m *Multi[ReferenceKey, PrimaryKey, Value]) unreference(ctx context.Context, pk PrimaryKey, value Value) error { + refKey, err := m.getRefKey(pk, value) + if err != nil { + return err + } + return m.refKeys.Remove(ctx, collections.Join(refKey, pk)) } func (m *Multi[ReferenceKey, PrimaryKey, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[ReferenceKey, PrimaryKey]]) (MultiIterator[ReferenceKey, PrimaryKey], error) { - iter, err := (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Iterate(ctx, ranger) + iter, err := m.refKeys.Iterate(ctx, ranger) return (MultiIterator[ReferenceKey, PrimaryKey])(iter), err } @@ -58,7 +84,9 @@ func (m *Multi[ReferenceKey, PrimaryKey, Value]) Walk( ranger collections.Ranger[collections.Pair[ReferenceKey, PrimaryKey]], walkFunc func(indexingKey ReferenceKey, indexedKey PrimaryKey) bool, ) error { - return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Walk(ctx, ranger, walkFunc) + return m.refKeys.Walk(ctx, ranger, func(key collections.Pair[ReferenceKey, PrimaryKey]) bool { + return walkFunc(key.K1(), key.K2()) + }) } // MatchExact returns a MultiIterator containing all the primary keys referenced by the provided reference key. @@ -66,8 +94,8 @@ func (m *Multi[ReferenceKey, PrimaryKey, Value]) MatchExact(ctx context.Context, return m.Iterate(ctx, collections.NewPrefixedPairRange[ReferenceKey, PrimaryKey](refKey)) } -func (i *MultiPair[K1, K2, Value]) KeyCodec() codec.KeyCodec[collections.Pair[K2, K1]] { - return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).KeyCodec() +func (m *Multi[K1, K2, Value]) KeyCodec() codec.KeyCodec[collections.Pair[K1, K2]] { + return m.refKeys.KeyCodec() } // MultiIterator is just a KeySetIterator with key as Pair[ReferenceKey, PrimaryKey]. diff --git a/collections/indexes/multi_pair.go b/collections/indexes/multi_pair.go deleted file mode 100644 index f22a330e46..0000000000 --- a/collections/indexes/multi_pair.go +++ /dev/null @@ -1,132 +0,0 @@ -package indexes - -import ( - "context" - - "cosmossdk.io/collections" - "cosmossdk.io/collections/codec" -) - -// MultiPair is an index that is used with collections.Pair keys. It indexes objects by their second part of the key. -// When the value is being indexed by collections.IndexedMap then MultiPair will create a relationship between -// the second part of the primary key and the first part. -type MultiPair[K1, K2, Value any] collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value] - -// TODO(tip): this is an interface to cast a collections.KeyCodec -// to a pair codec. currently we return it as a KeyCodec[Pair[K1, K2]] -// to improve dev experience with type inference, which means we cannot -// get the concrete implementation which exposes KeyCodec1 and KeyCodec2. -type pairKeyCodec[K1, K2 any] interface { - KeyCodec1() codec.KeyCodec[K1] - KeyCodec2() codec.KeyCodec[K2] -} - -// NewMultiPair instantiates a new MultiPair index. -// NOTE: when using this function you will need to type hint: doing NewMultiPair[Value]() -// Example: if the value of the indexed map is string, you need to do NewMultiPair[string](...) -func NewMultiPair[Value, K1, K2 any]( - sb *collections.SchemaBuilder, - prefix collections.Prefix, - name string, - pairCodec codec.KeyCodec[collections.Pair[K1, K2]], -) *MultiPair[K1, K2, Value] { - pkc := pairCodec.(pairKeyCodec[K1, K2]) - mi := collections.NewGenericMultiIndex( - sb, - prefix, - name, - pkc.KeyCodec2(), - pkc.KeyCodec1(), - func(pk collections.Pair[K1, K2], _ Value) ([]collections.IndexReference[K2, K1], error) { - return []collections.IndexReference[K2, K1]{ - collections.NewIndexReference(pk.K2(), pk.K1()), - }, nil - }, - ) - - return (*MultiPair[K1, K2, Value])(mi) -} - -// Iterate exposes the raw iterator API. -func (i *MultiPair[K1, K2, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[K2, K1]]) (iter MultiPairIterator[K2, K1], err error) { - sIter, err := (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Iterate(ctx, ranger) - if err != nil { - return iter, err - } - return (MultiPairIterator[K2, K1])(sIter), nil -} - -// MatchExact will return an iterator containing only the primary keys starting with the provided second part of the multipart pair key. -func (i *MultiPair[K1, K2, Value]) MatchExact(ctx context.Context, key K2) (MultiPairIterator[K2, K1], error) { - return i.Iterate(ctx, collections.NewPrefixedPairRange[K2, K1](key)) -} - -// Reference implements collections.Index -func (i *MultiPair[K1, K2, Value]) Reference(ctx context.Context, pk collections.Pair[K1, K2], value Value, oldValue *Value) error { - return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Reference(ctx, pk, value, oldValue) -} - -// Unreference implements collections.Index -func (i *MultiPair[K1, K2, Value]) Unreference(ctx context.Context, pk collections.Pair[K1, K2], value Value) error { - return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Unreference(ctx, pk, value) -} - -func (i *MultiPair[K1, K2, Value]) Walk( - ctx context.Context, - ranger collections.Ranger[collections.Pair[K2, K1]], - walkFunc func(indexingKey K2, indexedKey K1) bool, -) error { - return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Walk(ctx, ranger, walkFunc) -} - -func (i *MultiPair[K1, K2, Value]) IterateRaw( - ctx context.Context, start, end []byte, order collections.Order, -) ( - iter collections.Iterator[collections.Pair[K2, K1], collections.NoValue], err error, -) { - return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).IterateRaw(ctx, start, end, order) -} - -// MultiPairIterator is a helper type around a collections.KeySetIterator when used to work -// with MultiPair indexes iterations. -type MultiPairIterator[K2, K1 any] collections.KeySetIterator[collections.Pair[K2, K1]] - -// PrimaryKey returns the primary key from the index. The index is composed like a reverse -// pair key. So we just fetch the pair key from the index and return the reverse. -func (m MultiPairIterator[K2, K1]) PrimaryKey() (pair collections.Pair[K1, K2], err error) { - reversePair, err := m.FullKey() - if err != nil { - return pair, err - } - pair = collections.Join(reversePair.K2(), reversePair.K1()) - return pair, nil -} - -// PrimaryKeys returns all the primary keys contained in the iterator. -func (m MultiPairIterator[K2, K1]) PrimaryKeys() (pairs []collections.Pair[K1, K2], err error) { - defer m.Close() - for ; m.Valid(); m.Next() { - pair, err := m.PrimaryKey() - if err != nil { - return nil, err - } - pairs = append(pairs, pair) - } - return pairs, err -} - -func (m MultiPairIterator[K2, K1]) FullKey() (p collections.Pair[K2, K1], err error) { - return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Key() -} - -func (m MultiPairIterator[K2, K1]) Next() { - (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Next() -} - -func (m MultiPairIterator[K2, K1]) Valid() bool { - return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Valid() -} - -func (m MultiPairIterator[K2, K1]) Close() error { - return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Close() -} diff --git a/collections/indexes/multi_test.go b/collections/indexes/multi_test.go index 98518cd1be..ed11195bf0 100644 --- a/collections/indexes/multi_test.go +++ b/collections/indexes/multi_test.go @@ -16,8 +16,8 @@ func TestMultiIndex(t *testing.T) { }) // we crete two reference keys for primary key 1 and 2 associated with "milan" - require.NoError(t, mi.Reference(ctx, 1, company{City: "milan"}, nil)) - require.NoError(t, mi.Reference(ctx, 2, company{City: "milan"}, nil)) + require.NoError(t, mi.Reference(ctx, 1, company{City: "milan"}, func() (company, error) { return company{}, collections.ErrNotFound })) + require.NoError(t, mi.Reference(ctx, 2, company{City: "milan"}, func() (company, error) { return company{}, collections.ErrNotFound })) iter, err := mi.MatchExact(ctx, "milan") require.NoError(t, err) @@ -26,7 +26,7 @@ func TestMultiIndex(t *testing.T) { require.Equal(t, []uint64{1, 2}, pks) // replace - require.NoError(t, mi.Reference(ctx, 1, company{City: "new york"}, &company{City: "milan"})) + require.NoError(t, mi.Reference(ctx, 1, company{City: "new york"}, func() (company, error) { return company{City: "milan"}, nil })) // assert after replace only company with id 2 is referenced by milan iter, err = mi.MatchExact(ctx, "milan") diff --git a/collections/indexes/reverse_pair.go b/collections/indexes/reverse_pair.go new file mode 100644 index 0000000000..cc846e92d9 --- /dev/null +++ b/collections/indexes/reverse_pair.go @@ -0,0 +1,131 @@ +package indexes + +import ( + "context" + + "cosmossdk.io/collections" + "cosmossdk.io/collections/codec" +) + +// ReversePair is an index that is used with collections.Pair keys. It indexes objects by their second part of the key. +// When the value is being indexed by collections.IndexedMap then ReversePair will create a relationship between +// the second part of the primary key and the first part. +type ReversePair[K1, K2, Value any] struct { + refKeys collections.KeySet[collections.Pair[K2, K1]] // refKeys has the relationships between Join(K2, K1) +} + +// TODO(tip): this is an interface to cast a collections.KeyCodec +// to a pair codec. currently we return it as a KeyCodec[Pair[K1, K2]] +// to improve dev experience with type inference, which means we cannot +// get the concrete implementation which exposes KeyCodec1 and KeyCodec2. +type pairKeyCodec[K1, K2 any] interface { + KeyCodec1() codec.KeyCodec[K1] + KeyCodec2() codec.KeyCodec[K2] +} + +// NewReversePair instantiates a new ReversePair index. +// NOTE: when using this function you will need to type hint: doing NewReversePair[Value]() +// Example: if the value of the indexed map is string, you need to do NewReversePair[string](...) +func NewReversePair[Value any, K1, K2 any]( + sb *collections.SchemaBuilder, + prefix collections.Prefix, + name string, + pairCodec codec.KeyCodec[collections.Pair[K1, K2]], +) *ReversePair[K1, K2, Value] { + pkc := pairCodec.(pairKeyCodec[K1, K2]) + mi := &ReversePair[K1, K2, Value]{ + refKeys: collections.NewKeySet(sb, prefix, name, collections.PairKeyCodec(pkc.KeyCodec2(), pkc.KeyCodec1())), + } + + return mi +} + +// Iterate exposes the raw iterator API. +func (i *ReversePair[K1, K2, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[K2, K1]]) (iter ReversePairIterator[K2, K1], err error) { + sIter, err := i.refKeys.Iterate(ctx, ranger) + if err != nil { + return + } + return (ReversePairIterator[K2, K1])(sIter), nil +} + +// MatchExact will return an iterator containing only the primary keys starting with the provided second part of the multipart pair key. +func (i *ReversePair[K1, K2, Value]) MatchExact(ctx context.Context, key K2) (ReversePairIterator[K2, K1], error) { + return i.Iterate(ctx, collections.NewPrefixedPairRange[K2, K1](key)) +} + +// Reference implements collections.Index +func (i *ReversePair[K1, K2, Value]) Reference(ctx context.Context, pk collections.Pair[K1, K2], _ Value, _ func() (Value, error)) error { + return i.refKeys.Set(ctx, collections.Join(pk.K2(), pk.K1())) +} + +// Unreference implements collections.Index +func (i *ReversePair[K1, K2, Value]) Unreference(ctx context.Context, pk collections.Pair[K1, K2], _ func() (Value, error)) error { + return i.refKeys.Remove(ctx, collections.Join(pk.K2(), pk.K1())) +} + +func (i *ReversePair[K1, K2, Value]) Walk( + ctx context.Context, + ranger collections.Ranger[collections.Pair[K2, K1]], + walkFunc func(indexingKey K2, indexedKey K1) bool, +) error { + return i.refKeys.Walk(ctx, ranger, func(key collections.Pair[K2, K1]) bool { + return walkFunc(key.K1(), key.K2()) + }) +} + +func (i *ReversePair[K1, K2, Value]) IterateRaw( + ctx context.Context, start, end []byte, order collections.Order, +) ( + iter collections.Iterator[collections.Pair[K2, K1], collections.NoValue], err error, +) { + return i.refKeys.IterateRaw(ctx, start, end, order) +} + +func (i *ReversePair[K1, K2, Value]) KeyCodec() codec.KeyCodec[collections.Pair[K2, K1]] { + return i.refKeys.KeyCodec() +} + +// ReversePairIterator is a helper type around a collections.KeySetIterator when used to work +// with ReversePair indexes iterations. +type ReversePairIterator[K2, K1 any] collections.KeySetIterator[collections.Pair[K2, K1]] + +// PrimaryKey returns the primary key from the index. The index is composed like a reverse +// pair key. So we just fetch the pair key from the index and return the reverse. +func (m ReversePairIterator[K2, K1]) PrimaryKey() (pair collections.Pair[K1, K2], err error) { + reversePair, err := m.FullKey() + if err != nil { + return pair, err + } + pair = collections.Join(reversePair.K2(), reversePair.K1()) + return pair, nil +} + +// PrimaryKeys returns all the primary keys contained in the iterator. +func (m ReversePairIterator[K2, K1]) PrimaryKeys() (pairs []collections.Pair[K1, K2], err error) { + defer m.Close() + for ; m.Valid(); m.Next() { + pair, err := m.PrimaryKey() + if err != nil { + return nil, err + } + pairs = append(pairs, pair) + } + return pairs, err +} + +func (m ReversePairIterator[K2, K1]) FullKey() (p collections.Pair[K2, K1], err error) { + return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Key() +} + +func (m ReversePairIterator[K2, K1]) Next() { + (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Next() +} + +func (m ReversePairIterator[K2, K1]) Valid() bool { + return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Valid() +} + +func (m ReversePairIterator[K2, K1]) Close() error { + return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Close() +} diff --git a/collections/indexes/multi_pair_test.go b/collections/indexes/reverse_pair_test.go similarity index 89% rename from collections/indexes/multi_pair_test.go rename to collections/indexes/reverse_pair_test.go index 052a62ebea..55ee354f1f 100644 --- a/collections/indexes/multi_pair_test.go +++ b/collections/indexes/reverse_pair_test.go @@ -16,14 +16,14 @@ type ( // our balance index, allows us to efficiently create an index between the key that maps // balances which is a collections.Pair[Address, Denom] and the Denom. type balanceIndex struct { - Denom *MultiPair[Address, Denom, Amount] + Denom *ReversePair[Address, Denom, Amount] } func (b balanceIndex) IndexesList() []collections.Index[collections.Pair[Address, Denom], Amount] { return []collections.Index[collections.Pair[Address, Denom], Amount]{b.Denom} } -func TestMultiPair(t *testing.T) { +func TestReversePair(t *testing.T) { sk, ctx := deps() sb := collections.NewSchemaBuilder(sk) // we create an indexed map that maps balances, which are saved as @@ -37,7 +37,7 @@ func TestMultiPair(t *testing.T) { keyCodec, collections.Uint64Value, balanceIndex{ - Denom: NewMultiPair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec), + Denom: NewReversePair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec), }, ) diff --git a/collections/indexes/unique.go b/collections/indexes/unique.go index 7b5c9654cb..561ffad930 100644 --- a/collections/indexes/unique.go +++ b/collections/indexes/unique.go @@ -2,6 +2,8 @@ package indexes import ( "context" + "errors" + "fmt" "cosmossdk.io/collections" "cosmossdk.io/collections/codec" @@ -9,7 +11,10 @@ import ( // Unique identifies an index that imposes uniqueness constraints on the reference key. // It creates relationships between reference and primary key of the value. -type Unique[ReferenceKey, PrimaryKey, Value any] collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value] +type Unique[ReferenceKey, PrimaryKey, Value any] struct { + getRefKey func(PrimaryKey, Value) (ReferenceKey, error) + refKeys collections.Map[ReferenceKey, PrimaryKey] +} // NewUnique instantiates a new Unique index. func NewUnique[ReferenceKey, PrimaryKey, Value any]( @@ -20,34 +25,65 @@ func NewUnique[ReferenceKey, PrimaryKey, Value any]( pkCodec codec.KeyCodec[PrimaryKey], getRefKeyFunc func(pk PrimaryKey, v Value) (ReferenceKey, error), ) *Unique[ReferenceKey, PrimaryKey, Value] { - i := collections.NewGenericUniqueIndex(schema, prefix, name, refCodec, pkCodec, func(pk PrimaryKey, value Value) ([]collections.IndexReference[ReferenceKey, PrimaryKey], error) { - ref, err := getRefKeyFunc(pk, value) + return &Unique[ReferenceKey, PrimaryKey, Value]{ + getRefKey: getRefKeyFunc, + refKeys: collections.NewMap(schema, prefix, name, refCodec, codec.KeyToValueCodec(pkCodec)), + } +} + +func (i *Unique[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, lazyOldValue func() (Value, error)) error { + oldValue, err := lazyOldValue() + switch { + // if no error it means the value existed, and we need to remove the old indexes + case err == nil: + err = i.unreference(ctx, pk, oldValue) if err != nil { - return nil, err + return err } - - return []collections.IndexReference[ReferenceKey, PrimaryKey]{ - collections.NewIndexReference(ref, pk), - }, nil - }) - - return (*Unique[ReferenceKey, PrimaryKey, Value])(i) + // if error is ErrNotFound, it means that the object does not exist, so we're creating indexes for the first time. + // we do nothing. + case errors.Is(err, collections.ErrNotFound): + // default case means that there was some other error + default: + return err + } + // create new indexes, asserting no uniqueness constraint violation + refKey, err := i.getRefKey(pk, newValue) + if err != nil { + return err + } + has, err := i.refKeys.Has(ctx, refKey) + if err != nil { + return err + } + if has { + return fmt.Errorf("%w: index uniqueness constrain violation: %s", collections.ErrConflict, i.refKeys.KeyCodec().Stringify(refKey)) + } + return i.refKeys.Set(ctx, refKey, pk) } -func (i *Unique[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error { - return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Reference(ctx, pk, newValue, oldValue) +func (i *Unique[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, getValue func() (Value, error)) error { + value, err := getValue() + if err != nil { + return err + } + return i.unreference(ctx, pk, value) } -func (i *Unique[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, value Value) error { - return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Unreference(ctx, pk, value) +func (i *Unique[ReferenceKey, PrimaryKey, Value]) unreference(ctx context.Context, pk PrimaryKey, value Value) error { + refKey, err := i.getRefKey(pk, value) + if err != nil { + return err + } + return i.refKeys.Remove(ctx, refKey) } func (i *Unique[ReferenceKey, PrimaryKey, Value]) MatchExact(ctx context.Context, ref ReferenceKey) (PrimaryKey, error) { - return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Get(ctx, ref) + return i.refKeys.Get(ctx, ref) } func (i *Unique[ReferenceKey, PrimaryKey, Value]) Iterate(ctx context.Context, ranger collections.Ranger[ReferenceKey]) (UniqueIterator[ReferenceKey, PrimaryKey], error) { - iter, err := (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Iterate(ctx, ranger) + iter, err := i.refKeys.Iterate(ctx, ranger) return (UniqueIterator[ReferenceKey, PrimaryKey])(iter), err } @@ -56,11 +92,11 @@ func (i *Unique[ReferenceKey, PrimaryKey, Value]) Walk( ranger collections.Ranger[ReferenceKey], walkFunc func(indexingKey ReferenceKey, indexedKey PrimaryKey) bool, ) error { - return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Walk(ctx, ranger, walkFunc) + return i.refKeys.Walk(ctx, ranger, walkFunc) } func (i *Unique[ReferenceKey, PrimaryKey, Value]) IterateRaw(ctx context.Context, start, end []byte, order collections.Order) (u UniqueIterator[ReferenceKey, PrimaryKey], err error) { - iter, err := (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).IterateRaw(ctx, start, end, order) + iter, err := i.refKeys.IterateRaw(ctx, start, end, order) if err != nil { return } diff --git a/collections/indexes/unique_test.go b/collections/indexes/unique_test.go index 6a03e94215..f3a174fc39 100644 --- a/collections/indexes/unique_test.go +++ b/collections/indexes/unique_test.go @@ -15,15 +15,15 @@ func TestUniqueIndex(t *testing.T) { }) // map company with id 1 to vat 1_1 - err := ui.Reference(ctx, 1, company{Vat: 1_1}, nil) + err := ui.Reference(ctx, 1, company{Vat: 1_1}, func() (company, error) { return company{}, collections.ErrNotFound }) require.NoError(t, err) // map company with id 2 to vat 2_2 - err = ui.Reference(ctx, 2, company{Vat: 2_2}, nil) + err = ui.Reference(ctx, 2, company{Vat: 2_2}, func() (company, error) { return company{}, collections.ErrNotFound }) require.NoError(t, err) // mapping company 3 with vat 1_1 must yield to a ErrConflict - err = ui.Reference(ctx, 1, company{Vat: 1_1}, nil) + err = ui.Reference(ctx, 1, company{Vat: 1_1}, func() (company, error) { return company{}, collections.ErrNotFound }) require.ErrorIs(t, err, collections.ErrConflict) // assert references are correct @@ -36,7 +36,7 @@ func TestUniqueIndex(t *testing.T) { require.Equal(t, uint64(2), id) // on reference updates, the new referencing key is created and the old is removed - err = ui.Reference(ctx, 1, company{Vat: 1_2}, &company{Vat: 1_1}) + err = ui.Reference(ctx, 1, company{Vat: 1_2}, func() (company, error) { return company{Vat: 1_1}, nil }) require.NoError(t, err) id, err = ui.MatchExact(ctx, 1_2) // assert a new reference is created require.NoError(t, err) diff --git a/collections/indexes_generic_multi.go b/collections/indexes_generic_multi.go deleted file mode 100644 index 97def1725e..0000000000 --- a/collections/indexes_generic_multi.go +++ /dev/null @@ -1,157 +0,0 @@ -package collections - -import ( - "context" - - "cosmossdk.io/collections/codec" -) - -func NewIndexReference[ReferencingKey, ReferencedKey any](referencing ReferencingKey, referenced ReferencedKey) IndexReference[ReferencingKey, ReferencedKey] { - return IndexReference[ReferencingKey, ReferencedKey]{ - Referring: referencing, - Referred: referenced, - } -} - -// IndexReference defines a generic index reference. -type IndexReference[ReferencingKey, ReferencedKey any] struct { - // Referring is the key that refers, points to the Referred key. - Referring ReferencingKey - // Referred is the key that is being pointed to by the Referring key. - Referred ReferencedKey -} - -// GenericMultiIndex defines a generic Index type that given a primary key -// and the value associated with that primary key returns one or multiple IndexReference. -// -// The referencing key can be anything, usually it is either a part of the primary -// key when we deal with multipart keys, or a field of Value. -// -// The referenced key usually is the primary key, or it can be a part -// of the primary key in the context of multipart keys. -// -// The Referencing and Referenced keys are joined and saved as a Pair in a KeySet -// where the key is Pair[ReferencingKey, ReferencedKey]. -// So if we wanted to get all the keys referenced by a generic (concrete) ReferencingKey -// we would just need to iterate over all the keys starting with bytes(ReferencingKey). -// -// Unless you're trying to build your generic multi index, you should be using the indexes package. -type GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct { - refs KeySet[Pair[ReferencingKey, ReferencedKey]] - getRefs func(pk PrimaryKey, v Value) ([]IndexReference[ReferencingKey, ReferencedKey], error) -} - -// NewGenericMultiIndex instantiates a GenericMultiIndex, given -// schema, Prefix, humanized name, the key codec used to encode the referencing key -// to bytes, the key codec used to encode the referenced key to bytes and a function -// which given the primary key and a value of an object being saved or removed in IndexedMap -// returns all the possible IndexReference of that object. -// -// The IndexReference is usually just one. But in certain cases can be multiple, -// for example when the Value has an array field, and we want to create a relationship -// between the object and all the elements of the array contained in the object. -func NewGenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any]( - schema *SchemaBuilder, - prefix Prefix, - name string, - referencingKeyCodec codec.KeyCodec[ReferencingKey], - referencedKeyCodec codec.KeyCodec[ReferencedKey], - getRefsFunc func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error), -) *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] { - return &GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{ - getRefs: getRefsFunc, - refs: NewKeySet(schema, prefix, name, PairKeyCodec(referencingKeyCodec, referencedKeyCodec)), - } -} - -// Iterate allows to iterate over the index. It returns a KeySetIterator of Pair[ReferencingKey, ReferencedKey]. -// K1 of the Pair is the key (referencing) pointing to K2 (referenced). -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate( - ctx context.Context, - ranger Ranger[Pair[ReferencingKey, ReferencedKey]], -) (KeySetIterator[Pair[ReferencingKey, ReferencedKey]], error) { - return i.refs.Iterate(ctx, ranger) -} - -// Has reports if there is a relationship in the index between the referencing and the referenced key. -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Has( - ctx context.Context, - referencing ReferencingKey, - referenced ReferencedKey, -) (bool, error) { - return i.refs.Has(ctx, Join(referencing, referenced)) -} - -// Reference implements the Index interface. -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference( - ctx context.Context, - pk PrimaryKey, - value Value, - oldValue *Value, -) error { - if oldValue != nil { - err := i.Unreference(ctx, pk, *oldValue) - if err != nil { - return err - } - } - - refKeys, err := i.getRefs(pk, value) - if err != nil { - return err - } - - for _, ref := range refKeys { - err := i.refs.Set(ctx, Join(ref.Referring, ref.Referred)) - if err != nil { - return err - } - } - - return nil -} - -// Unreference implements the Index interface. -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference( - ctx context.Context, - pk PrimaryKey, - value Value, -) error { - refs, err := i.getRefs(pk, value) - if err != nil { - return err - } - - for _, ref := range refs { - err = i.refs.Remove(ctx, Join(ref.Referring, ref.Referred)) - if err != nil { - return err - } - } - - return nil -} - -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) IterateRaw( - ctx context.Context, - start, end []byte, - order Order, -) (Iterator[Pair[ReferencingKey, ReferencedKey], NoValue], error) { - return i.refs.IterateRaw(ctx, start, end, order) -} - -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Walk( - ctx context.Context, - ranger Ranger[Pair[ReferencingKey, ReferencedKey]], - walkFunc func(referencingKey ReferencingKey, referencedKey ReferencedKey) bool, -) error { - return i.refs.Walk(ctx, ranger, func(key Pair[ReferencingKey, ReferencedKey]) bool { return walkFunc(key.K1(), key.K2()) }) -} - -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) KeyCodec() codec.KeyCodec[Pair[ReferencingKey, ReferencedKey]] { - return i.refs.KeyCodec() -} - -func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) ValueCodec() codec.ValueCodec[NoValue] { - return i.refs.ValueCodec() -} diff --git a/collections/indexes_generic_multi_test.go b/collections/indexes_generic_multi_test.go deleted file mode 100644 index 6634dfcaec..0000000000 --- a/collections/indexes_generic_multi_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package collections - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -type coin struct { - denom string // this will be used as indexing field. - amount uint64 -} - -type balance struct { - coins []coin -} - -func TestGenericMultiIndex(t *testing.T) { - // we are simulating a context in which we have the following mapping: - // - // address (represented as string) => balance (slice of coins). - // - // we want to create an index that creates a relationship between the coin - // denom, which is part of the balance structure, and the address. This means - // we know given a denom who are the addresses holding that denom. - // From GenericMultiIndex point of view, the denom field of the array becomes - // the referencing key which points to the address (string), which is the key - // being referenced. - sk, ctx := deps() - sb := NewSchemaBuilder(sk) - mi := NewGenericMultiIndex( - sb, NewPrefix("denoms"), "denom_to_owner", StringKey, StringKey, - func(pk string, value balance) ([]IndexReference[string, string], error) { - // the referencing keys are all the denoms. - refs := make([]IndexReference[string, string], len(value.coins)) - // the index reference being created, generates a relationship - // between denom (the key that references) and pk (address, the key - // that is being referenced). - for i, coin := range value.coins { - refs[i] = NewIndexReference(coin.denom, pk) - } - return refs, nil - }, - ) - - // let's create the relationships - err := mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{ - {"atom", 1000}, {"osmo", 5000}, - }}, nil) - require.NoError(t, err) - - // we must find relations between cosmosaddr1 and the denom atom and osmo - iter, err := mi.Iterate(ctx, nil) - require.NoError(t, err) - - keys, err := iter.Keys() - require.NoError(t, err) - require.Len(t, keys, 2) - require.Equal(t, keys[0].K1(), "atom") // assert relationship with atom created - require.Equal(t, keys[1].K1(), "osmo") // assert relationship with osmo created - - // if we update the reference to remove osmo as balance then we must not find it anymore - err = mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}}, // this is the update which does not have osmo - &balance{coins: []coin{{"atom", 1000}, {"osmo", 5000}}}, // this is the previous record - ) - require.NoError(t, err) - - exists, err := mi.Has(ctx, "osmo", "cosmosAddr1") // osmo must not exist anymore - require.NoError(t, err) - require.False(t, exists) - - exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom still exists - require.NoError(t, err) - require.True(t, exists) - - // if we unreference then no relationship is maintained anymore - err = mi.Unreference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}}) - require.NoError(t, err) - - exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom is not part of the index anymore because cosmosAddr1 was removed. - require.NoError(t, err) - require.False(t, exists) -} diff --git a/collections/indexes_generic_unique.go b/collections/indexes_generic_unique.go deleted file mode 100644 index 1b2d7b4c58..0000000000 --- a/collections/indexes_generic_unique.go +++ /dev/null @@ -1,122 +0,0 @@ -package collections - -import ( - "context" - "fmt" - - "cosmossdk.io/collections/codec" -) - -// GenericUniqueIndex defines a generic index which enforces uniqueness constraints -// between ReferencingKey and ReferencedKey, meaning that one referencing key maps -// only one referenced key. The same referenced key can be mapped by multiple referencing keys. -// -// The referencing key can be anything, usually it is either a part of the primary -// key when we deal with multipart keys, or a field of Value. -// -// The referenced key usually is the primary key, or it can be a part -// of the primary key in the context of multipart keys. -// -// The referencing and referenced keys are mapped together using a Map. -// -// Unless you're trying to build your generic unique index, you should be using the indexes package. -type GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct { - refs Map[ReferencingKey, ReferencedKey] - getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error) -} - -// NewGenericUniqueIndex instantiates a GenericUniqueIndex. Works in the same way as NewGenericMultiIndex. -func NewGenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any]( - schema *SchemaBuilder, - prefix Prefix, - name string, - referencingKeyCodec codec.KeyCodec[ReferencingKey], - referencedKeyCodec codec.KeyCodec[ReferencedKey], - getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error), -) *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] { - return &GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{ - refs: NewMap[ReferencingKey, ReferencedKey](schema, prefix, name, referencingKeyCodec, codec.KeyToValueCodec(referencedKeyCodec)), - getRefs: getRefs, - } -} - -func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate( - ctx context.Context, - ranger Ranger[ReferencingKey], -) (Iterator[ReferencingKey, ReferencedKey], error) { - return i.refs.Iterate(ctx, ranger) -} - -func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Get(ctx context.Context, ref ReferencingKey) (ReferencedKey, error) { - return i.refs.Get(ctx, ref) -} - -// Reference implements Index. -func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference( - ctx context.Context, - pk PrimaryKey, - newValue Value, - oldValue *Value, -) error { - if oldValue != nil { - err := i.Unreference(ctx, pk, *oldValue) - if err != nil { - return err - } - } - refs, err := i.getRefs(pk, newValue) - if err != nil { - return err - } - for _, ref := range refs { - has, err := i.refs.Has(ctx, ref.Referring) - if err != nil { - return err - } - if has { - return fmt.Errorf("%w: index uniqueness constrain violation: %s", ErrConflict, i.refs.kc.Stringify(ref.Referring)) - } - err = i.refs.Set(ctx, ref.Referring, ref.Referred) - if err != nil { - return err - } - } - return nil -} - -// Unreference implements Index. -func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference( - ctx context.Context, - pk PrimaryKey, - value Value, -) error { - refs, err := i.getRefs(pk, value) - if err != nil { - return err - } - - for _, ref := range refs { - err = i.refs.Remove(ctx, ref.Referring) - if err != nil { - return err - } - } - - return nil -} - -func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) IterateRaw( - ctx context.Context, - start, end []byte, - order Order, -) (Iterator[ReferencingKey, ReferencedKey], error) { - return i.refs.IterateRaw(ctx, start, end, order) -} - -func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Walk( - ctx context.Context, - ranger Ranger[ReferencingKey], - walkFunc func(referencingKey ReferencingKey, referencedKey ReferencedKey) bool, -) error { - return i.refs.Walk(ctx, ranger, func(k ReferencingKey, v ReferencedKey) bool { return walkFunc(k, v) }) -} diff --git a/collections/indexes_generic_unique_test.go b/collections/indexes_generic_unique_test.go deleted file mode 100644 index ae1e76ccec..0000000000 --- a/collections/indexes_generic_unique_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package collections - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -type nftBalance struct { - nftIDs []uint64 -} - -func TestGenericUniqueIndex(t *testing.T) { - // we create the same testing context as with GenericMultiIndex. We have a mapping: - // Address => NFT balance. - // An NFT balance is represented as a slice of IDs, those IDs are unique, meaning that - // they can be held only by one address. - sk, ctx := deps() - sb := NewSchemaBuilder(sk) - ui := NewGenericUniqueIndex( - sb, NewPrefix("nft_to_owner_index"), "ntf_to_owner_index", Uint64Key, StringKey, - func(pk string, value nftBalance) ([]IndexReference[uint64, string], error) { - // the referencing keys are all the NFT unique ids. - refs := make([]IndexReference[uint64, string], len(value.nftIDs)) - // for each NFT contained in the balance we create an index reference - // between the NFT unique ID and the owner of the balance. - for i, id := range value.nftIDs { - refs[i] = NewIndexReference(id, pk) - } - return refs, nil - }, - ) - - // let's create the relationships - err := ui.Reference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{0, 1}}, nil) - require.NoError(t, err) - - // assert relations were created - iter, err := ui.Iterate(ctx, nil) - require.NoError(t, err) - defer iter.Close() - - kv, err := iter.KeyValues() - require.NoError(t, err) - require.Len(t, kv, 2) - require.Equal(t, kv[0].Key, uint64(0)) - require.Equal(t, kv[0].Value, "cosmosAddr1") - require.Equal(t, kv[1].Key, uint64(1)) - require.Equal(t, kv[1].Value, "cosmosAddr1") - - // assert only one address can own a unique NFT - err = ui.Reference(ctx, "cosmosAddr2", nftBalance{nftIDs: []uint64{0}}, nil) // nft with ID 0 is already owned by cosmosAddr1 - require.ErrorIs(t, err, ErrConflict) - - // during modifications references are updated, we update the index in - // such a way that cosmosAddr1 loses ownership of nft with id 0. - err = ui.Reference(ctx, "cosmosAddr1", - nftBalance{nftIDs: []uint64{1}}, // this is the update nft balance, which contains only id 1 - &nftBalance{nftIDs: []uint64{0, 1}}, // this is the old nft balance, which contains both 0 and 1 - ) - require.NoError(t, err) - - // the updated balance does not contain nft with id 0 - _, err = ui.Get(ctx, 0) - require.ErrorIs(t, err, ErrNotFound) - - // unreferencing clears all the indexes - err = ui.Unreference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{1}}) - require.NoError(t, err) - _, err = ui.Get(ctx, 1) - require.ErrorIs(t, err, ErrNotFound) -}