diff --git a/collections/CHANGELOG.md b/collections/CHANGELOG.md index 736c20657d..7958edea26 100644 --- a/collections/CHANGELOG.md +++ b/collections/CHANGELOG.md @@ -35,4 +35,5 @@ Ref: https://keepachangelog.com/en/1.0.0/ * [#14351](https://github.com/cosmos/cosmos-sdk/pull/14351) Add keyset * [#14364](https://github.com/cosmos/cosmos-sdk/pull/14364) Add sequence * [#14468](https://github.com/cosmos/cosmos-sdk/pull/14468) Add Map.IterateRaw API. -* [#14310](https://github.com/cosmos/cosmos-sdk/pull/14310) Add Pair keys \ No newline at end of file +* [#14310](https://github.com/cosmos/cosmos-sdk/pull/14310) Add Pair keys +* [#14397](https://github.com/cosmos/cosmos-sdk/pull/14397) Add IndexedMap \ No newline at end of file diff --git a/collections/collections.go b/collections/collections.go index 84889a493c..e175ffc497 100644 --- a/collections/collections.go +++ b/collections/collections.go @@ -10,6 +10,8 @@ var ( ErrNotFound = errors.New("collections: not found") // ErrEncoding is returned when something fails during key or value encoding/decoding. ErrEncoding = errors.New("collections: encoding error") + // ErrConflict is returned when there are conflicts, for example in UniqueIndex. + ErrConflict = errors.New("collections: conflict") ) // collection is the interface that all collections support. It will eventually diff --git a/collections/collections_test.go b/collections/collections_test.go index 51b476b20f..a85e0ecd2c 100644 --- a/collections/collections_test.go +++ b/collections/collections_test.go @@ -29,7 +29,6 @@ func (t testStore) Has(key []byte) (bool, error) { func (t testStore) Set(key, value []byte) error { return t.db.Set(key, value) - } func (t testStore) Delete(key []byte) error { diff --git a/collections/colltest/codec.go b/collections/colltest/codec.go index 33bcec2469..78ab86e060 100644 --- a/collections/colltest/codec.go +++ b/collections/colltest/codec.go @@ -1,9 +1,14 @@ package colltest import ( - "cosmossdk.io/collections" - "github.com/stretchr/testify/require" + "encoding/json" + "fmt" + "reflect" "testing" + + "cosmossdk.io/collections" + + "github.com/stretchr/testify/require" ) // TestKeyCodec asserts the correct behaviour of a KeyCodec over the type T. @@ -21,6 +26,7 @@ func TestKeyCodec[T any](t *testing.T, keyCodec collections.KeyCodec[T], key T) pairKey := collections.Join(key, "TEST") buffer = make([]byte, pairCodec.Size(pairKey)) written, err = pairCodec.Encode(buffer, pairKey) + require.Equal(t, len(buffer), written, "the pair buffer should have been fully written") require.NoError(t, err) read, decodedPairKey, err := pairCodec.Decode(buffer) require.NoError(t, err) @@ -53,3 +59,101 @@ func TestValueCodec[T any](t *testing.T, encoder collections.ValueCodec[T], valu _ = encoder.Stringify(value) } + +// MockValueCodec returns a mock of collections.ValueCodec for type T, it +// can be used for collections Values testing. It also supports interfaces. +// For the interfaces cases, in order for an interface to be decoded it must +// have been encoded first. Not concurrency safe. +// EG: +// Let's say the value is interface Animal +// if I want to decode Dog which implements Animal, then I need to first encode +// it in order to make the type known by the MockValueCodec. +func MockValueCodec[T any]() collections.ValueCodec[T] { + typ := reflect.ValueOf(new(T)).Elem().Type() + isInterface := false + if typ.Kind() == reflect.Interface { + isInterface = true + } + return &mockValueCodec[T]{ + isInterface: isInterface, + seenTypes: map[string]reflect.Type{}, + valueType: fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name()), + } +} + +type mockValueJSON struct { + TypeName string `json:"type_name"` + Value json.RawMessage `json:"value"` +} + +type mockValueCodec[T any] struct { + isInterface bool + seenTypes map[string]reflect.Type + valueType string +} + +func (m mockValueCodec[T]) Encode(value T) ([]byte, error) { + typeName := m.getTypeName(value) + valueBytes, err := json.Marshal(value) + if err != nil { + return nil, err + } + + return json.Marshal(mockValueJSON{ + TypeName: typeName, + Value: valueBytes, + }) +} + +func (m mockValueCodec[T]) Decode(b []byte) (t T, err error) { + wrappedValue := mockValueJSON{} + err = json.Unmarshal(b, &wrappedValue) + if err != nil { + return + } + if !m.isInterface { + err = json.Unmarshal(wrappedValue.Value, &t) + return t, err + } + + typ, exists := m.seenTypes[wrappedValue.TypeName] + if !exists { + return t, fmt.Errorf("uknown type %s, you're dealing with interfaces... in order to make the interface types known for the MockValueCodec, you need to first encode them", wrappedValue.TypeName) + } + + newT := reflect.New(typ).Interface() + err = json.Unmarshal(wrappedValue.Value, newT) + if err != nil { + return t, err + } + + iface := new(T) + reflect.ValueOf(iface).Elem().Set(reflect.ValueOf(newT).Elem()) + return *iface, nil +} + +func (m mockValueCodec[T]) EncodeJSON(value T) ([]byte, error) { + return m.Encode(value) +} + +func (m mockValueCodec[T]) DecodeJSON(b []byte) (T, error) { + return m.Decode(b) +} + +func (m mockValueCodec[T]) Stringify(value T) string { + return fmt.Sprintf("%#v", value) +} + +func (m mockValueCodec[T]) ValueType() string { + return m.valueType +} + +func (m mockValueCodec[T]) getTypeName(value T) string { + if !m.isInterface { + return m.valueType + } + typ := reflect.TypeOf(value) + name := fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name()) + m.seenTypes[name] = typ + return name +} diff --git a/collections/colltest/codec_test.go b/collections/colltest/codec_test.go new file mode 100644 index 0000000000..bcfeed57bb --- /dev/null +++ b/collections/colltest/codec_test.go @@ -0,0 +1,48 @@ +package colltest + +import "testing" + +type animal interface { + name() string +} + +type dog struct { + Name string `json:"name"` + BarksLoudly bool `json:"barks_loudly"` +} + +type cat struct { + Name string `json:"name"` + Scratches bool `json:"scratches"` +} + +func (d *cat) name() string { return d.Name } + +func (d dog) name() string { return d.Name } + +func TestMockValueCodec(t *testing.T) { + t.Run("primitive type", func(t *testing.T) { + x := MockValueCodec[string]() + TestValueCodec(t, x, "hello") + }) + + t.Run("struct type", func(t *testing.T) { + x := MockValueCodec[dog]() + TestValueCodec(t, x, dog{ + Name: "kernel", + BarksLoudly: true, + }) + }) + + t.Run("interface type", func(t *testing.T) { + x := MockValueCodec[animal]() + TestValueCodec[animal](t, x, dog{ + Name: "kernel", + BarksLoudly: true, + }) + TestValueCodec[animal](t, x, &cat{ + Name: "echo", + Scratches: true, + }) + }) +} diff --git a/collections/colltest/store.go b/collections/colltest/store.go new file mode 100644 index 0000000000..a02b1f421f --- /dev/null +++ b/collections/colltest/store.go @@ -0,0 +1,49 @@ +package colltest + +import ( + "context" + + "cosmossdk.io/core/store" + db "github.com/cosmos/cosmos-db" +) + +// MockStore returns a mock store.KVStoreService and a mock context.Context. +// They can be used to test collections. +func MockStore() (store.KVStoreService, context.Context) { + kv := db.NewMemDB() + return &testStore{kv}, context.Background() +} + +type testStore struct { + db db.DB +} + +func (t testStore) OpenKVStore(ctx context.Context) store.KVStore { + return t +} + +func (t testStore) Get(key []byte) ([]byte, error) { + return t.db.Get(key) +} + +func (t testStore) Has(key []byte) (bool, error) { + return t.db.Has(key) +} + +func (t testStore) Set(key, value []byte) error { + return t.db.Set(key, value) +} + +func (t testStore) Delete(key []byte) error { + return t.db.Delete(key) +} + +func (t testStore) Iterator(start, end []byte) (store.Iterator, error) { + return t.db.Iterator(start, end) +} + +func (t testStore) ReverseIterator(start, end []byte) (store.Iterator, error) { + return t.db.ReverseIterator(start, end) +} + +var _ store.KVStore = testStore{} diff --git a/collections/correctness_test.go b/collections/correctness_test.go index 2f059e3ea7..cd999869c0 100644 --- a/collections/correctness_test.go +++ b/collections/correctness_test.go @@ -1,9 +1,10 @@ package collections_test import ( + "testing" + "cosmossdk.io/collections" "cosmossdk.io/collections/colltest" - "testing" ) func TestKeyCorrectness(t *testing.T) { diff --git a/collections/genesis.go b/collections/genesis.go index e03a766b29..454ba496e1 100644 --- a/collections/genesis.go +++ b/collections/genesis.go @@ -20,13 +20,13 @@ type jsonMapEntry struct { } func (m Map[K, V]) validateGenesis(reader io.Reader) error { - return m.doDecodeJson(reader, func(key K, value V) error { + return m.doDecodeJSON(reader, func(key K, value V) error { return nil }) } func (m Map[K, V]) importGenesis(ctx context.Context, reader io.Reader) error { - return m.doDecodeJson(reader, func(key K, value V) error { + return m.doDecodeJSON(reader, func(key K, value V) error { return m.Set(ctx, key, value) }) } @@ -95,7 +95,7 @@ func (m Map[K, V]) exportGenesis(ctx context.Context, writer io.Writer) error { return err } -func (m Map[K, V]) doDecodeJson(reader io.Reader, onEntry func(key K, value V) error) error { +func (m Map[K, V]) doDecodeJSON(reader io.Reader, onEntry func(key K, value V) error) error { decoder := json.NewDecoder(reader) token, err := decoder.Token() if err != nil { @@ -107,14 +107,14 @@ func (m Map[K, V]) doDecodeJson(reader io.Reader, onEntry func(key K, value V) e } for decoder.More() { - var rawJson json.RawMessage - err := decoder.Decode(&rawJson) + var rawJSON json.RawMessage + err := decoder.Decode(&rawJSON) if err != nil { return err } var mapEntry jsonMapEntry - err = json.Unmarshal(rawJson, &mapEntry) + err = json.Unmarshal(rawJSON, &mapEntry) if err != nil { return err } diff --git a/collections/indexed_map.go b/collections/indexed_map.go new file mode 100644 index 0000000000..0a164fdd93 --- /dev/null +++ b/collections/indexed_map.go @@ -0,0 +1,135 @@ +package collections + +import ( + "context" + "errors" + "fmt" +) + +// Indexes represents a type which groups multiple Index +// of one Value saved with the provided PrimaryKey. +// Indexes is just meant to be a struct containing all +// the indexes to maintain relationship for. +type Indexes[PrimaryKey, Value any] interface { + // IndexesList is implemented by the Indexes type + // and returns all the grouped Index of Value. + IndexesList() []Index[PrimaryKey, Value] +} + +// Index represents an index of the Value indexed using the type PrimaryKey. +type Index[PrimaryKey, Value any] interface { + // Reference creates a reference between the provided primary key and value. + // If oldValue is not nil then the Index must update the references + // of the primary key associated with the new value and remove the + // old invalid references. + Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error + // Unreference removes the reference between the primary key and value. + Unreference(ctx context.Context, pk PrimaryKey, value Value) error +} + +// IndexedMap works like a Map but creates references between fields of Value and its PrimaryKey. +// These relationships are expressed and maintained using the Indexes type. +// Internally IndexedMap can be seen as a partitioned collection, one partition +// is a Map[PrimaryKey, Value], that maintains the object, the second +// are the Indexes. +type IndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]] struct { + Indexes Idx + m Map[PrimaryKey, Value] +} + +// NewIndexedMap instantiates a new IndexedMap. Accepts a SchemaBuilder, a Prefix, +// a humanised name that defines the name of the collection, the primary key codec +// which is basically what IndexedMap uses to encode the primary key to bytes, +// the value codec which is what the IndexedMap uses to encode the value. +// Then it expects the initialised indexes. +func NewIndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]]( + schema *SchemaBuilder, + prefix Prefix, + name string, + pkCodec KeyCodec[PrimaryKey], + valueCodec ValueCodec[Value], + indexes Idx, +) *IndexedMap[PrimaryKey, Value, Idx] { + return &IndexedMap[PrimaryKey, Value, Idx]{ + Indexes: indexes, + m: NewMap(schema, prefix, name, pkCodec, valueCodec), + } +} + +// Get gets the object given its primary key. +func (m *IndexedMap[PrimaryKey, Value, Idx]) Get(ctx context.Context, pk PrimaryKey) (Value, error) { + return m.m.Get(ctx, pk) +} + +// Iterate allows to iterate over the objects given a Ranger of the primary key. +func (m *IndexedMap[PrimaryKey, Value, Idx]) Iterate(ctx context.Context, ranger Ranger[PrimaryKey]) (Iterator[PrimaryKey, Value], error) { + return m.m.Iterate(ctx, ranger) +} + +// Has reports if exists a value with the provided primary key. +func (m *IndexedMap[PrimaryKey, Value, Idx]) Has(ctx context.Context, pk PrimaryKey) (bool, error) { + return m.m.Has(ctx, pk) +} + +// Set maps the value using the primary key. It will also iterate every index and instruct them to +// add or update the indexes. +func (m *IndexedMap[PrimaryKey, Value, Idx]) Set(ctx context.Context, pk PrimaryKey, value Value) error { + // we need to see if there was a previous instance of the value + oldValue, err := m.m.Get(ctx, pk) + switch { + // update indexes + case err == nil: + err = m.ref(ctx, pk, value, &oldValue) + if err != nil { + return fmt.Errorf("collections: indexing error: %w", err) + } + // create new indexes + case errors.Is(err, ErrNotFound): + err = m.ref(ctx, pk, value, nil) + if err != nil { + return fmt.Errorf("collections: indexing error: %w", err) + } + // cannot move forward error + default: + return err + } + + return m.m.Set(ctx, pk, value) +} + +// Remove removes the value associated with the primary key from the map. Then +// it iterates over all the indexes and instructs them to remove all the references +// associated with the removed value. +func (m *IndexedMap[PrimaryKey, Value, Idx]) Remove(ctx context.Context, pk PrimaryKey) error { + oldValue, err := m.m.Get(ctx, pk) + if err != nil { + // TODO retain Map behaviour? which does not error in case we remove a non-existing object + return err + } + + err = m.unref(ctx, pk, oldValue) + if err != nil { + return err + } + return m.m.Remove(ctx, pk) +} + +func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk PrimaryKey, value Value, oldValue *Value) error { + for _, index := range m.Indexes.IndexesList() { + err := index.Reference(ctx, pk, value, oldValue) + if err != nil { + return err + } + } + return nil +} + +func (m *IndexedMap[PrimaryKey, Value, Idx]) unref(ctx context.Context, pk PrimaryKey, value Value) error { + for _, index := range m.Indexes.IndexesList() { + err := index.Unreference(ctx, pk, value) + if err != nil { + return err + } + } + return nil +} diff --git a/collections/indexed_map_test.go b/collections/indexed_map_test.go new file mode 100644 index 0000000000..707537883b --- /dev/null +++ b/collections/indexed_map_test.go @@ -0,0 +1,104 @@ +package collections_test + +import ( + "testing" + + "cosmossdk.io/collections" + "cosmossdk.io/collections/colltest" + "github.com/stretchr/testify/require" +) + +type company struct { + City string + Vat uint64 +} + +type companyIndexes struct { + // City is an index of the company indexed map. It indexes a company + // given its city. The index is multi, meaning that there can be multiple + // companies from the same city. + City *collections.GenericMultiIndex[string, string, string, company] + // Vat is an index of the company indexed map. It indexes a company + // given its VAT number. The index is unique, meaning that there can be + // only one VAT number for a company. + Vat *collections.GenericUniqueIndex[uint64, string, string, company] +} + +func (c companyIndexes) IndexesList() []collections.Index[string, company] { + return []collections.Index[string, company]{c.City, c.Vat} +} + +func newTestIndexedMap(schema *collections.SchemaBuilder) *collections.IndexedMap[string, company, companyIndexes] { + return collections.NewIndexedMap(schema, collections.NewPrefix(0), "companies", collections.StringKey, colltest.MockValueCodec[company](), + companyIndexes{ + City: collections.NewGenericMultiIndex(schema, collections.NewPrefix(1), "companies_by_city", collections.StringKey, collections.StringKey, func(pk string, value company) ([]collections.IndexReference[string, string], error) { + return []collections.IndexReference[string, string]{collections.NewIndexReference(value.City, pk)}, nil + }), + Vat: collections.NewGenericUniqueIndex(schema, collections.NewPrefix(2), "companies_by_vat", collections.Uint64Key, collections.StringKey, func(pk string, v company) ([]collections.IndexReference[uint64, string], error) { + return []collections.IndexReference[uint64, string]{collections.NewIndexReference(v.Vat, pk)}, nil + }), + }, + ) +} + +func TestIndexedMap(t *testing.T) { + sk, ctx := colltest.MockStore() + schema := collections.NewSchemaBuilder(sk) + + im := newTestIndexedMap(schema) + + // test insertion + err := im.Set(ctx, "1", company{ + City: "milan", + Vat: 0, + }) + require.NoError(t, err) + + err = im.Set(ctx, "2", company{ + City: "milan", + Vat: 1, + }) + require.NoError(t, err) + + err = im.Set(ctx, "3", company{ + City: "milan", + Vat: 4, + }) + require.NoError(t, err) + + pk, err := im.Indexes.Vat.Get(ctx, 1) + require.NoError(t, err) + require.Equal(t, "2", pk) + + // test a set which updates the indexes + err = im.Set(ctx, "2", company{ + City: "milan", + Vat: 2, + }) + require.NoError(t, err) + + pk, err = im.Indexes.Vat.Get(ctx, 2) + require.NoError(t, err) + require.Equal(t, "2", pk) + + _, err = im.Indexes.Vat.Get(ctx, 1) + require.ErrorIs(t, err, collections.ErrNotFound) + + // test removal + err = im.Remove(ctx, "2") + require.NoError(t, err) + _, err = im.Indexes.Vat.Get(ctx, 2) + require.ErrorIs(t, err, collections.ErrNotFound) + + // test iteration + iter, err := im.Iterate(ctx, nil) + require.NoError(t, err) + keys, err := iter.Keys() + require.NoError(t, err) + require.Equal(t, []string{"1", "3"}, keys) + + // test get + v, err := im.Get(ctx, "3") + require.NoError(t, err) + require.Equal(t, company{"milan", 4}, v) +} diff --git a/collections/indexes_generic_multi.go b/collections/indexes_generic_multi.go new file mode 100644 index 0000000000..a1a609ff78 --- /dev/null +++ b/collections/indexes_generic_multi.go @@ -0,0 +1,129 @@ +package collections + +import "context" + +func NewIndexReference[ReferencingKey, ReferencedKey any](referencing ReferencingKey, referenced ReferencedKey) IndexReference[ReferencingKey, ReferencedKey] { + return IndexReference[ReferencingKey, ReferencedKey]{ + Referring: referencing, + Referred: referenced, + } +} + +// IndexReference defines a generic index reference. +type IndexReference[ReferencingKey, ReferencedKey any] struct { + // Referring is the key that refers, points to the Referred key. + Referring ReferencingKey + // Referred is the key that is being pointed to by the Referring key. + Referred ReferencedKey +} + +// GenericMultiIndex defines a generic Index type that given a primary key +// and the value associated with that primary key returns one or multiple IndexReference. +// +// The referencing key can be anything, usually it is either a part of the primary +// key when we deal with multipart keys, or a field of Value. +// +// The referenced key usually is the primary key, or it can be a part +// of the primary key in the context of multipart keys. +// +// The Referencing and Referenced keys are joined and saved as a Pair in a KeySet +// where the key is Pair[ReferencingKey, ReferencedKey]. +// So if we wanted to get all the keys referenced by a generic (concrete) ReferencingKey +// we would just need to iterate over all the keys starting with bytes(ReferencingKey). +// +// Unless you're trying to build your generic multi index, you should be using the indexes package. +type GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct { + refs KeySet[Pair[ReferencingKey, ReferencedKey]] + getRefs func(pk PrimaryKey, v Value) ([]IndexReference[ReferencingKey, ReferencedKey], error) +} + +// NewGenericMultiIndex instantiates a GenericMultiIndex, given +// schema, Prefix, humanised name, the key codec used to encode the referencing key +// to bytes, the key codec used to encode the referenced key to bytes and a function +// which given the primary key and a value of an object being saved or removed in IndexedMap +// returns all the possible IndexReference of that object. +// +// The IndexReference is usually just one. But in certain cases can be multiple, +// for example when the Value has an array field, and we want to create a relationship +// between the object and all the elements of the array contained in the object. +func NewGenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any]( + schema *SchemaBuilder, + prefix Prefix, + name string, + referencingKeyCodec KeyCodec[ReferencingKey], + referencedKeyCodec KeyCodec[ReferencedKey], + getRefsFunc func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error), +) *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] { + return &GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{ + getRefs: getRefsFunc, + refs: NewKeySet(schema, prefix, name, PairKeyCodec(referencingKeyCodec, referencedKeyCodec)), + } +} + +// Iterate allows to iterate over the index. It returns a KeySetIterator of Pair[ReferencingKey, ReferencedKey]. +// K1 of the Pair is the key (referencing) pointing to K2 (referenced). +func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate( + ctx context.Context, + ranger Ranger[Pair[ReferencingKey, ReferencedKey]], +) (KeySetIterator[Pair[ReferencingKey, ReferencedKey]], error) { + return i.refs.Iterate(ctx, ranger) +} + +// Has reports if there is a relationship in the index between the referencing and the referenced key. +func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Has( + ctx context.Context, + referencing ReferencingKey, + referenced ReferencedKey, +) (bool, error) { + return i.refs.Has(ctx, Join(referencing, referenced)) +} + +// Reference implements the Index interface. +func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference( + ctx context.Context, + pk PrimaryKey, + value Value, + oldValue *Value, +) error { + if oldValue != nil { + err := i.Unreference(ctx, pk, *oldValue) + if err != nil { + return err + } + } + + refKeys, err := i.getRefs(pk, value) + if err != nil { + return err + } + + for _, ref := range refKeys { + err := i.refs.Set(ctx, Join(ref.Referring, ref.Referred)) + if err != nil { + return err + } + } + + return nil +} + +// Unreference implements the Index interface. +func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference( + ctx context.Context, + pk PrimaryKey, + value Value, +) error { + refs, err := i.getRefs(pk, value) + if err != nil { + return err + } + + for _, ref := range refs { + err = i.refs.Remove(ctx, Join(ref.Referring, ref.Referred)) + if err != nil { + return err + } + } + + return nil +} diff --git a/collections/indexes_generic_multi_test.go b/collections/indexes_generic_multi_test.go new file mode 100644 index 0000000000..6634dfcaec --- /dev/null +++ b/collections/indexes_generic_multi_test.go @@ -0,0 +1,83 @@ +package collections + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type coin struct { + denom string // this will be used as indexing field. + amount uint64 +} + +type balance struct { + coins []coin +} + +func TestGenericMultiIndex(t *testing.T) { + // we are simulating a context in which we have the following mapping: + // + // address (represented as string) => balance (slice of coins). + // + // we want to create an index that creates a relationship between the coin + // denom, which is part of the balance structure, and the address. This means + // we know given a denom who are the addresses holding that denom. + // From GenericMultiIndex point of view, the denom field of the array becomes + // the referencing key which points to the address (string), which is the key + // being referenced. + sk, ctx := deps() + sb := NewSchemaBuilder(sk) + mi := NewGenericMultiIndex( + sb, NewPrefix("denoms"), "denom_to_owner", StringKey, StringKey, + func(pk string, value balance) ([]IndexReference[string, string], error) { + // the referencing keys are all the denoms. + refs := make([]IndexReference[string, string], len(value.coins)) + // the index reference being created, generates a relationship + // between denom (the key that references) and pk (address, the key + // that is being referenced). + for i, coin := range value.coins { + refs[i] = NewIndexReference(coin.denom, pk) + } + return refs, nil + }, + ) + + // let's create the relationships + err := mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{ + {"atom", 1000}, {"osmo", 5000}, + }}, nil) + require.NoError(t, err) + + // we must find relations between cosmosaddr1 and the denom atom and osmo + iter, err := mi.Iterate(ctx, nil) + require.NoError(t, err) + + keys, err := iter.Keys() + require.NoError(t, err) + require.Len(t, keys, 2) + require.Equal(t, keys[0].K1(), "atom") // assert relationship with atom created + require.Equal(t, keys[1].K1(), "osmo") // assert relationship with osmo created + + // if we update the reference to remove osmo as balance then we must not find it anymore + err = mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}}, // this is the update which does not have osmo + &balance{coins: []coin{{"atom", 1000}, {"osmo", 5000}}}, // this is the previous record + ) + require.NoError(t, err) + + exists, err := mi.Has(ctx, "osmo", "cosmosAddr1") // osmo must not exist anymore + require.NoError(t, err) + require.False(t, exists) + + exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom still exists + require.NoError(t, err) + require.True(t, exists) + + // if we unreference then no relationship is maintained anymore + err = mi.Unreference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}}) + require.NoError(t, err) + + exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom is not part of the index anymore because cosmosAddr1 was removed. + require.NoError(t, err) + require.False(t, exists) +} diff --git a/collections/indexes_generic_unique.go b/collections/indexes_generic_unique.go new file mode 100644 index 0000000000..8855b9ce32 --- /dev/null +++ b/collections/indexes_generic_unique.go @@ -0,0 +1,145 @@ +package collections + +import ( + "context" + "fmt" +) + +// GenericUniqueIndex defines a generic index which enforces uniqueness constraints +// between ReferencingKey and ReferencedKey, meaning that one referencing key maps +// only one referenced key. The same referenced key can be mapped by multiple referencing keys. +// +// The referencing key can be anything, usually it is either a part of the primary +// key when we deal with multipart keys, or a field of Value. +// +// The referenced key usually is the primary key, or it can be a part +// of the primary key in the context of multipart keys. +// +// The referencing and referenced keys are mapped together using a Map. +// +// Unless you're trying to build your generic unique index, you should be using the indexes package. +type GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct { + refs Map[ReferencingKey, ReferencedKey] + getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error) +} + +// NewGenericUniqueIndex instantiates a GenericUniqueIndex. Works in the same way as NewGenericMultiIndex. +func NewGenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any]( + schema *SchemaBuilder, + prefix Prefix, + name string, + referencingKeyCodec KeyCodec[ReferencingKey], + referencedKeyCodec KeyCodec[ReferencedKey], + getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error), +) *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] { + return &GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{ + refs: NewMap[ReferencingKey, ReferencedKey](schema, prefix, name, referencingKeyCodec, keyToValueCodec[ReferencedKey]{kc: referencedKeyCodec}), + getRefs: getRefs, + } +} + +func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate( + ctx context.Context, + ranger Ranger[ReferencingKey], +) (Iterator[ReferencingKey, ReferencedKey], error) { + return i.refs.Iterate(ctx, ranger) +} + +func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Get(ctx context.Context, ref ReferencingKey) (ReferencedKey, error) { + return i.refs.Get(ctx, ref) +} + +// Reference implements Index. +func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference( + ctx context.Context, + pk PrimaryKey, + newValue Value, + oldValue *Value, +) error { + if oldValue != nil { + err := i.Unreference(ctx, pk, *oldValue) + if err != nil { + return err + } + } + refs, err := i.getRefs(pk, newValue) + if err != nil { + return err + } + for _, ref := range refs { + has, err := i.refs.Has(ctx, ref.Referring) + if err != nil { + return err + } + if has { + return fmt.Errorf("%w: index uniqueness constrain violation: %s", ErrConflict, i.refs.kc.Stringify(ref.Referring)) + } + err = i.refs.Set(ctx, ref.Referring, ref.Referred) + if err != nil { + return err + } + } + return nil +} + +// Unreference implements Index. +func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference( + ctx context.Context, + pk PrimaryKey, + value Value, +) error { + refs, err := i.getRefs(pk, value) + if err != nil { + return err + } + + for _, ref := range refs { + err = i.refs.Remove(ctx, ref.Referring) + if err != nil { + return err + } + } + + return nil +} + +// keyToValueCodec is a ValueCodec that wraps a KeyCodec to make it behave like a ValueCodec. +type keyToValueCodec[K any] struct { + kc KeyCodec[K] +} + +func (k keyToValueCodec[K]) EncodeJSON(value K) ([]byte, error) { + return k.kc.EncodeJSON(value) +} + +func (k keyToValueCodec[K]) DecodeJSON(b []byte) (K, error) { + return k.kc.DecodeJSON(b) +} + +func (k keyToValueCodec[K]) Encode(value K) ([]byte, error) { + buf := make([]byte, k.kc.Size(value)) + _, err := k.kc.Encode(buf, value) + return buf, err +} + +func (k keyToValueCodec[K]) Decode(b []byte) (K, error) { + r, key, err := k.kc.Decode(b) + if err != nil { + var key K + return key, err + } + + if r != len(b) { + var key K + return key, fmt.Errorf("%w: was supposed to fully consume the key '%x', consumed %d out of %d", ErrEncoding, b, r, len(b)) + } + return key, nil +} + +func (k keyToValueCodec[K]) Stringify(value K) string { + return k.kc.Stringify(value) +} + +func (k keyToValueCodec[K]) ValueType() string { + return k.kc.KeyType() +} diff --git a/collections/indexes_generic_unique_test.go b/collections/indexes_generic_unique_test.go new file mode 100644 index 0000000000..ae1e76ccec --- /dev/null +++ b/collections/indexes_generic_unique_test.go @@ -0,0 +1,72 @@ +package collections + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type nftBalance struct { + nftIDs []uint64 +} + +func TestGenericUniqueIndex(t *testing.T) { + // we create the same testing context as with GenericMultiIndex. We have a mapping: + // Address => NFT balance. + // An NFT balance is represented as a slice of IDs, those IDs are unique, meaning that + // they can be held only by one address. + sk, ctx := deps() + sb := NewSchemaBuilder(sk) + ui := NewGenericUniqueIndex( + sb, NewPrefix("nft_to_owner_index"), "ntf_to_owner_index", Uint64Key, StringKey, + func(pk string, value nftBalance) ([]IndexReference[uint64, string], error) { + // the referencing keys are all the NFT unique ids. + refs := make([]IndexReference[uint64, string], len(value.nftIDs)) + // for each NFT contained in the balance we create an index reference + // between the NFT unique ID and the owner of the balance. + for i, id := range value.nftIDs { + refs[i] = NewIndexReference(id, pk) + } + return refs, nil + }, + ) + + // let's create the relationships + err := ui.Reference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{0, 1}}, nil) + require.NoError(t, err) + + // assert relations were created + iter, err := ui.Iterate(ctx, nil) + require.NoError(t, err) + defer iter.Close() + + kv, err := iter.KeyValues() + require.NoError(t, err) + require.Len(t, kv, 2) + require.Equal(t, kv[0].Key, uint64(0)) + require.Equal(t, kv[0].Value, "cosmosAddr1") + require.Equal(t, kv[1].Key, uint64(1)) + require.Equal(t, kv[1].Value, "cosmosAddr1") + + // assert only one address can own a unique NFT + err = ui.Reference(ctx, "cosmosAddr2", nftBalance{nftIDs: []uint64{0}}, nil) // nft with ID 0 is already owned by cosmosAddr1 + require.ErrorIs(t, err, ErrConflict) + + // during modifications references are updated, we update the index in + // such a way that cosmosAddr1 loses ownership of nft with id 0. + err = ui.Reference(ctx, "cosmosAddr1", + nftBalance{nftIDs: []uint64{1}}, // this is the update nft balance, which contains only id 1 + &nftBalance{nftIDs: []uint64{0, 1}}, // this is the old nft balance, which contains both 0 and 1 + ) + require.NoError(t, err) + + // the updated balance does not contain nft with id 0 + _, err = ui.Get(ctx, 0) + require.ErrorIs(t, err, ErrNotFound) + + // unreferencing clears all the indexes + err = ui.Unreference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{1}}) + require.NoError(t, err) + _, err = ui.Get(ctx, 1) + require.ErrorIs(t, err, ErrNotFound) +} diff --git a/collections/item.go b/collections/item.go index 144eb264f8..93afc695b3 100644 --- a/collections/item.go +++ b/collections/item.go @@ -23,14 +23,6 @@ func NewItem[V any]( return item } -func (i Item[V]) getName() string { - return i.name -} - -func (i Item[V]) getPrefix() []byte { - return i.prefix -} - // Get gets the item, if it is not set it returns an ErrNotFound error. // If value decoding fails then an ErrEncoding is returned. func (i Item[V]) Get(ctx context.Context) (V, error) { diff --git a/collections/iter.go b/collections/iter.go index 4d3b1ef10a..b2103301ff 100644 --- a/collections/iter.go +++ b/collections/iter.go @@ -121,7 +121,6 @@ func (r *Range[K]) Descending() *Range[K] { // test sentinel error var ( - errRange = errors.New("collections: range error") errOrder = errors.New("collections: invalid order") ) @@ -161,6 +160,7 @@ func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K]) } else { endBytes = nextBytesPrefixKey(m.prefix) } + return newIterator(ctx, startBytes, endBytes, order, m) } diff --git a/collections/keys_test.go b/collections/keys_test.go index f3a94c90bc..f84a154a6c 100644 --- a/collections/keys_test.go +++ b/collections/keys_test.go @@ -14,5 +14,4 @@ func TestUint64Key(t *testing.T) { } func TestStringKey(t *testing.T) { - } diff --git a/collections/map.go b/collections/map.go index 29eb3c4936..3a94916698 100644 --- a/collections/map.go +++ b/collections/map.go @@ -70,9 +70,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error { // Get returns the value associated with the provided key, // errors with ErrNotFound if the key does not exist, or // with ErrEncoding if the key or value decoding fails. -func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) { - var v V - +func (m Map[K, V]) Get(ctx context.Context, key K) (v V, err error) { bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key) if err != nil { return v, err @@ -80,13 +78,12 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) { kvStore := m.sa(ctx) valueBytes, err := kvStore.Get(bytesKey) - if valueBytes == nil { - return v, fmt.Errorf("%w: key '%s' of type %s", ErrNotFound, m.kc.Stringify(key), m.vc.ValueType()) - } - if err != nil { return v, err } + if valueBytes == nil { + return v, fmt.Errorf("%w: key '%s' of type %s", ErrNotFound, m.kc.Stringify(key), m.vc.ValueType()) + } v, err = m.vc.Decode(valueBytes) if err != nil { @@ -115,8 +112,7 @@ func (m Map[K, V]) Remove(ctx context.Context, key K) error { return err } kvStore := m.sa(ctx) - kvStore.Delete(bytesKey) - return nil + return kvStore.Delete(bytesKey) } // Iterate provides an Iterator over K and V. It accepts a Ranger interface. diff --git a/collections/pair_test.go b/collections/pair_test.go index 0f778cf2ee..141c04b661 100644 --- a/collections/pair_test.go +++ b/collections/pair_test.go @@ -1,8 +1,9 @@ package collections import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestPair(t *testing.T) { @@ -63,5 +64,6 @@ func TestPairRange(t *testing.T) { iter, err = m.Iterate(ctx, NewPrefixedPairRange[string, uint64]("A").Descending().StartExclusive(0).EndInclusive(2)) require.NoError(t, err) keys, err = iter.Keys() + require.NoError(t, err) require.Equal(t, []Pair[string, uint64]{Join("A", uint64(2)), Join("A", uint64(1))}, keys) } diff --git a/collections/schema.go b/collections/schema.go index e52ccb2439..dedb55a257 100644 --- a/collections/schema.go +++ b/collections/schema.go @@ -151,26 +151,6 @@ func NewSchemaFromAccessor(accessor func(context.Context) store.KVStore) Schema } } -func (s Schema) addCollection(collection collection) { - prefix := collection.getPrefix() - name := collection.getName() - - if _, ok := s.collectionsByPrefix[string(prefix)]; ok { - panic(fmt.Errorf("prefix %v already taken within schema", prefix)) - } - - if _, ok := s.collectionsByName[name]; ok { - panic(fmt.Errorf("name %s already taken within schema", name)) - } - - if !nameRegex.MatchString(name) { - panic(fmt.Errorf("name must match regex %s, got %s", NameRegex, name)) - } - - s.collectionsByPrefix[string(prefix)] = collection - s.collectionsByName[name] = collection -} - // DefaultGenesis implements the appmodule.HasGenesis.DefaultGenesis method. func (s Schema) DefaultGenesis(target appmodule.GenesisTarget) error { for _, name := range s.collectionsOrdered { diff --git a/collections/values_test.go b/collections/values_test.go index 8e174b32ee..5d2830ef1b 100644 --- a/collections/values_test.go +++ b/collections/values_test.go @@ -7,7 +7,6 @@ import ( ) func TestUint64Value(t *testing.T) { - t.Run("invalid size", func(t *testing.T) { _, err := Uint64Value.Decode([]byte{0x1, 0x2}) require.ErrorIs(t, err, ErrEncoding)