feat(collections): IndexedMap (#14397)

Co-authored-by: testinginprod <testinginprod@somewhere.idk>
Co-authored-by: Marko <marbar3778@yahoo.com>
Co-authored-by: Likhita Polavarapu <78951027+likhita-809@users.noreply.github.com>
This commit is contained in:
testinginprod 2023-01-27 13:49:27 +01:00 committed by GitHub
parent ed17f2d437
commit 519630ea64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 892 additions and 52 deletions

View File

@ -35,4 +35,5 @@ Ref: https://keepachangelog.com/en/1.0.0/
* [#14351](https://github.com/cosmos/cosmos-sdk/pull/14351) Add keyset
* [#14364](https://github.com/cosmos/cosmos-sdk/pull/14364) Add sequence
* [#14468](https://github.com/cosmos/cosmos-sdk/pull/14468) Add Map.IterateRaw API.
* [#14310](https://github.com/cosmos/cosmos-sdk/pull/14310) Add Pair keys
* [#14310](https://github.com/cosmos/cosmos-sdk/pull/14310) Add Pair keys
* [#14397](https://github.com/cosmos/cosmos-sdk/pull/14397) Add IndexedMap

View File

@ -10,6 +10,8 @@ var (
ErrNotFound = errors.New("collections: not found")
// ErrEncoding is returned when something fails during key or value encoding/decoding.
ErrEncoding = errors.New("collections: encoding error")
// ErrConflict is returned when there are conflicts, for example in UniqueIndex.
ErrConflict = errors.New("collections: conflict")
)
// collection is the interface that all collections support. It will eventually

View File

@ -29,7 +29,6 @@ func (t testStore) Has(key []byte) (bool, error) {
func (t testStore) Set(key, value []byte) error {
return t.db.Set(key, value)
}
func (t testStore) Delete(key []byte) error {

View File

@ -1,9 +1,14 @@
package colltest
import (
"cosmossdk.io/collections"
"github.com/stretchr/testify/require"
"encoding/json"
"fmt"
"reflect"
"testing"
"cosmossdk.io/collections"
"github.com/stretchr/testify/require"
)
// TestKeyCodec asserts the correct behaviour of a KeyCodec over the type T.
@ -21,6 +26,7 @@ func TestKeyCodec[T any](t *testing.T, keyCodec collections.KeyCodec[T], key T)
pairKey := collections.Join(key, "TEST")
buffer = make([]byte, pairCodec.Size(pairKey))
written, err = pairCodec.Encode(buffer, pairKey)
require.Equal(t, len(buffer), written, "the pair buffer should have been fully written")
require.NoError(t, err)
read, decodedPairKey, err := pairCodec.Decode(buffer)
require.NoError(t, err)
@ -53,3 +59,101 @@ func TestValueCodec[T any](t *testing.T, encoder collections.ValueCodec[T], valu
_ = encoder.Stringify(value)
}
// MockValueCodec returns a mock of collections.ValueCodec for type T, it
// can be used for collections Values testing. It also supports interfaces.
// For the interfaces cases, in order for an interface to be decoded it must
// have been encoded first. Not concurrency safe.
// EG:
// Let's say the value is interface Animal
// if I want to decode Dog which implements Animal, then I need to first encode
// it in order to make the type known by the MockValueCodec.
func MockValueCodec[T any]() collections.ValueCodec[T] {
typ := reflect.ValueOf(new(T)).Elem().Type()
isInterface := false
if typ.Kind() == reflect.Interface {
isInterface = true
}
return &mockValueCodec[T]{
isInterface: isInterface,
seenTypes: map[string]reflect.Type{},
valueType: fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name()),
}
}
type mockValueJSON struct {
TypeName string `json:"type_name"`
Value json.RawMessage `json:"value"`
}
type mockValueCodec[T any] struct {
isInterface bool
seenTypes map[string]reflect.Type
valueType string
}
func (m mockValueCodec[T]) Encode(value T) ([]byte, error) {
typeName := m.getTypeName(value)
valueBytes, err := json.Marshal(value)
if err != nil {
return nil, err
}
return json.Marshal(mockValueJSON{
TypeName: typeName,
Value: valueBytes,
})
}
func (m mockValueCodec[T]) Decode(b []byte) (t T, err error) {
wrappedValue := mockValueJSON{}
err = json.Unmarshal(b, &wrappedValue)
if err != nil {
return
}
if !m.isInterface {
err = json.Unmarshal(wrappedValue.Value, &t)
return t, err
}
typ, exists := m.seenTypes[wrappedValue.TypeName]
if !exists {
return t, fmt.Errorf("uknown type %s, you're dealing with interfaces... in order to make the interface types known for the MockValueCodec, you need to first encode them", wrappedValue.TypeName)
}
newT := reflect.New(typ).Interface()
err = json.Unmarshal(wrappedValue.Value, newT)
if err != nil {
return t, err
}
iface := new(T)
reflect.ValueOf(iface).Elem().Set(reflect.ValueOf(newT).Elem())
return *iface, nil
}
func (m mockValueCodec[T]) EncodeJSON(value T) ([]byte, error) {
return m.Encode(value)
}
func (m mockValueCodec[T]) DecodeJSON(b []byte) (T, error) {
return m.Decode(b)
}
func (m mockValueCodec[T]) Stringify(value T) string {
return fmt.Sprintf("%#v", value)
}
func (m mockValueCodec[T]) ValueType() string {
return m.valueType
}
func (m mockValueCodec[T]) getTypeName(value T) string {
if !m.isInterface {
return m.valueType
}
typ := reflect.TypeOf(value)
name := fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name())
m.seenTypes[name] = typ
return name
}

View File

@ -0,0 +1,48 @@
package colltest
import "testing"
type animal interface {
name() string
}
type dog struct {
Name string `json:"name"`
BarksLoudly bool `json:"barks_loudly"`
}
type cat struct {
Name string `json:"name"`
Scratches bool `json:"scratches"`
}
func (d *cat) name() string { return d.Name }
func (d dog) name() string { return d.Name }
func TestMockValueCodec(t *testing.T) {
t.Run("primitive type", func(t *testing.T) {
x := MockValueCodec[string]()
TestValueCodec(t, x, "hello")
})
t.Run("struct type", func(t *testing.T) {
x := MockValueCodec[dog]()
TestValueCodec(t, x, dog{
Name: "kernel",
BarksLoudly: true,
})
})
t.Run("interface type", func(t *testing.T) {
x := MockValueCodec[animal]()
TestValueCodec[animal](t, x, dog{
Name: "kernel",
BarksLoudly: true,
})
TestValueCodec[animal](t, x, &cat{
Name: "echo",
Scratches: true,
})
})
}

View File

@ -0,0 +1,49 @@
package colltest
import (
"context"
"cosmossdk.io/core/store"
db "github.com/cosmos/cosmos-db"
)
// MockStore returns a mock store.KVStoreService and a mock context.Context.
// They can be used to test collections.
func MockStore() (store.KVStoreService, context.Context) {
kv := db.NewMemDB()
return &testStore{kv}, context.Background()
}
type testStore struct {
db db.DB
}
func (t testStore) OpenKVStore(ctx context.Context) store.KVStore {
return t
}
func (t testStore) Get(key []byte) ([]byte, error) {
return t.db.Get(key)
}
func (t testStore) Has(key []byte) (bool, error) {
return t.db.Has(key)
}
func (t testStore) Set(key, value []byte) error {
return t.db.Set(key, value)
}
func (t testStore) Delete(key []byte) error {
return t.db.Delete(key)
}
func (t testStore) Iterator(start, end []byte) (store.Iterator, error) {
return t.db.Iterator(start, end)
}
func (t testStore) ReverseIterator(start, end []byte) (store.Iterator, error) {
return t.db.ReverseIterator(start, end)
}
var _ store.KVStore = testStore{}

View File

@ -1,9 +1,10 @@
package collections_test
import (
"testing"
"cosmossdk.io/collections"
"cosmossdk.io/collections/colltest"
"testing"
)
func TestKeyCorrectness(t *testing.T) {

View File

@ -20,13 +20,13 @@ type jsonMapEntry struct {
}
func (m Map[K, V]) validateGenesis(reader io.Reader) error {
return m.doDecodeJson(reader, func(key K, value V) error {
return m.doDecodeJSON(reader, func(key K, value V) error {
return nil
})
}
func (m Map[K, V]) importGenesis(ctx context.Context, reader io.Reader) error {
return m.doDecodeJson(reader, func(key K, value V) error {
return m.doDecodeJSON(reader, func(key K, value V) error {
return m.Set(ctx, key, value)
})
}
@ -95,7 +95,7 @@ func (m Map[K, V]) exportGenesis(ctx context.Context, writer io.Writer) error {
return err
}
func (m Map[K, V]) doDecodeJson(reader io.Reader, onEntry func(key K, value V) error) error {
func (m Map[K, V]) doDecodeJSON(reader io.Reader, onEntry func(key K, value V) error) error {
decoder := json.NewDecoder(reader)
token, err := decoder.Token()
if err != nil {
@ -107,14 +107,14 @@ func (m Map[K, V]) doDecodeJson(reader io.Reader, onEntry func(key K, value V) e
}
for decoder.More() {
var rawJson json.RawMessage
err := decoder.Decode(&rawJson)
var rawJSON json.RawMessage
err := decoder.Decode(&rawJSON)
if err != nil {
return err
}
var mapEntry jsonMapEntry
err = json.Unmarshal(rawJson, &mapEntry)
err = json.Unmarshal(rawJSON, &mapEntry)
if err != nil {
return err
}

135
collections/indexed_map.go Normal file
View File

@ -0,0 +1,135 @@
package collections
import (
"context"
"errors"
"fmt"
)
// Indexes represents a type which groups multiple Index
// of one Value saved with the provided PrimaryKey.
// Indexes is just meant to be a struct containing all
// the indexes to maintain relationship for.
type Indexes[PrimaryKey, Value any] interface {
// IndexesList is implemented by the Indexes type
// and returns all the grouped Index of Value.
IndexesList() []Index[PrimaryKey, Value]
}
// Index represents an index of the Value indexed using the type PrimaryKey.
type Index[PrimaryKey, Value any] interface {
// Reference creates a reference between the provided primary key and value.
// If oldValue is not nil then the Index must update the references
// of the primary key associated with the new value and remove the
// old invalid references.
Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error
// Unreference removes the reference between the primary key and value.
Unreference(ctx context.Context, pk PrimaryKey, value Value) error
}
// IndexedMap works like a Map but creates references between fields of Value and its PrimaryKey.
// These relationships are expressed and maintained using the Indexes type.
// Internally IndexedMap can be seen as a partitioned collection, one partition
// is a Map[PrimaryKey, Value], that maintains the object, the second
// are the Indexes.
type IndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]] struct {
Indexes Idx
m Map[PrimaryKey, Value]
}
// NewIndexedMap instantiates a new IndexedMap. Accepts a SchemaBuilder, a Prefix,
// a humanised name that defines the name of the collection, the primary key codec
// which is basically what IndexedMap uses to encode the primary key to bytes,
// the value codec which is what the IndexedMap uses to encode the value.
// Then it expects the initialised indexes.
func NewIndexedMap[PrimaryKey, Value any, Idx Indexes[PrimaryKey, Value]](
schema *SchemaBuilder,
prefix Prefix,
name string,
pkCodec KeyCodec[PrimaryKey],
valueCodec ValueCodec[Value],
indexes Idx,
) *IndexedMap[PrimaryKey, Value, Idx] {
return &IndexedMap[PrimaryKey, Value, Idx]{
Indexes: indexes,
m: NewMap(schema, prefix, name, pkCodec, valueCodec),
}
}
// Get gets the object given its primary key.
func (m *IndexedMap[PrimaryKey, Value, Idx]) Get(ctx context.Context, pk PrimaryKey) (Value, error) {
return m.m.Get(ctx, pk)
}
// Iterate allows to iterate over the objects given a Ranger of the primary key.
func (m *IndexedMap[PrimaryKey, Value, Idx]) Iterate(ctx context.Context, ranger Ranger[PrimaryKey]) (Iterator[PrimaryKey, Value], error) {
return m.m.Iterate(ctx, ranger)
}
// Has reports if exists a value with the provided primary key.
func (m *IndexedMap[PrimaryKey, Value, Idx]) Has(ctx context.Context, pk PrimaryKey) (bool, error) {
return m.m.Has(ctx, pk)
}
// Set maps the value using the primary key. It will also iterate every index and instruct them to
// add or update the indexes.
func (m *IndexedMap[PrimaryKey, Value, Idx]) Set(ctx context.Context, pk PrimaryKey, value Value) error {
// we need to see if there was a previous instance of the value
oldValue, err := m.m.Get(ctx, pk)
switch {
// update indexes
case err == nil:
err = m.ref(ctx, pk, value, &oldValue)
if err != nil {
return fmt.Errorf("collections: indexing error: %w", err)
}
// create new indexes
case errors.Is(err, ErrNotFound):
err = m.ref(ctx, pk, value, nil)
if err != nil {
return fmt.Errorf("collections: indexing error: %w", err)
}
// cannot move forward error
default:
return err
}
return m.m.Set(ctx, pk, value)
}
// Remove removes the value associated with the primary key from the map. Then
// it iterates over all the indexes and instructs them to remove all the references
// associated with the removed value.
func (m *IndexedMap[PrimaryKey, Value, Idx]) Remove(ctx context.Context, pk PrimaryKey) error {
oldValue, err := m.m.Get(ctx, pk)
if err != nil {
// TODO retain Map behaviour? which does not error in case we remove a non-existing object
return err
}
err = m.unref(ctx, pk, oldValue)
if err != nil {
return err
}
return m.m.Remove(ctx, pk)
}
func (m *IndexedMap[PrimaryKey, Value, Idx]) ref(ctx context.Context, pk PrimaryKey, value Value, oldValue *Value) error {
for _, index := range m.Indexes.IndexesList() {
err := index.Reference(ctx, pk, value, oldValue)
if err != nil {
return err
}
}
return nil
}
func (m *IndexedMap[PrimaryKey, Value, Idx]) unref(ctx context.Context, pk PrimaryKey, value Value) error {
for _, index := range m.Indexes.IndexesList() {
err := index.Unreference(ctx, pk, value)
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,104 @@
package collections_test
import (
"testing"
"cosmossdk.io/collections"
"cosmossdk.io/collections/colltest"
"github.com/stretchr/testify/require"
)
type company struct {
City string
Vat uint64
}
type companyIndexes struct {
// City is an index of the company indexed map. It indexes a company
// given its city. The index is multi, meaning that there can be multiple
// companies from the same city.
City *collections.GenericMultiIndex[string, string, string, company]
// Vat is an index of the company indexed map. It indexes a company
// given its VAT number. The index is unique, meaning that there can be
// only one VAT number for a company.
Vat *collections.GenericUniqueIndex[uint64, string, string, company]
}
func (c companyIndexes) IndexesList() []collections.Index[string, company] {
return []collections.Index[string, company]{c.City, c.Vat}
}
func newTestIndexedMap(schema *collections.SchemaBuilder) *collections.IndexedMap[string, company, companyIndexes] {
return collections.NewIndexedMap(schema, collections.NewPrefix(0), "companies", collections.StringKey, colltest.MockValueCodec[company](),
companyIndexes{
City: collections.NewGenericMultiIndex(schema, collections.NewPrefix(1), "companies_by_city", collections.StringKey, collections.StringKey, func(pk string, value company) ([]collections.IndexReference[string, string], error) {
return []collections.IndexReference[string, string]{collections.NewIndexReference(value.City, pk)}, nil
}),
Vat: collections.NewGenericUniqueIndex(schema, collections.NewPrefix(2), "companies_by_vat", collections.Uint64Key, collections.StringKey, func(pk string, v company) ([]collections.IndexReference[uint64, string], error) {
return []collections.IndexReference[uint64, string]{collections.NewIndexReference(v.Vat, pk)}, nil
}),
},
)
}
func TestIndexedMap(t *testing.T) {
sk, ctx := colltest.MockStore()
schema := collections.NewSchemaBuilder(sk)
im := newTestIndexedMap(schema)
// test insertion
err := im.Set(ctx, "1", company{
City: "milan",
Vat: 0,
})
require.NoError(t, err)
err = im.Set(ctx, "2", company{
City: "milan",
Vat: 1,
})
require.NoError(t, err)
err = im.Set(ctx, "3", company{
City: "milan",
Vat: 4,
})
require.NoError(t, err)
pk, err := im.Indexes.Vat.Get(ctx, 1)
require.NoError(t, err)
require.Equal(t, "2", pk)
// test a set which updates the indexes
err = im.Set(ctx, "2", company{
City: "milan",
Vat: 2,
})
require.NoError(t, err)
pk, err = im.Indexes.Vat.Get(ctx, 2)
require.NoError(t, err)
require.Equal(t, "2", pk)
_, err = im.Indexes.Vat.Get(ctx, 1)
require.ErrorIs(t, err, collections.ErrNotFound)
// test removal
err = im.Remove(ctx, "2")
require.NoError(t, err)
_, err = im.Indexes.Vat.Get(ctx, 2)
require.ErrorIs(t, err, collections.ErrNotFound)
// test iteration
iter, err := im.Iterate(ctx, nil)
require.NoError(t, err)
keys, err := iter.Keys()
require.NoError(t, err)
require.Equal(t, []string{"1", "3"}, keys)
// test get
v, err := im.Get(ctx, "3")
require.NoError(t, err)
require.Equal(t, company{"milan", 4}, v)
}

View File

@ -0,0 +1,129 @@
package collections
import "context"
func NewIndexReference[ReferencingKey, ReferencedKey any](referencing ReferencingKey, referenced ReferencedKey) IndexReference[ReferencingKey, ReferencedKey] {
return IndexReference[ReferencingKey, ReferencedKey]{
Referring: referencing,
Referred: referenced,
}
}
// IndexReference defines a generic index reference.
type IndexReference[ReferencingKey, ReferencedKey any] struct {
// Referring is the key that refers, points to the Referred key.
Referring ReferencingKey
// Referred is the key that is being pointed to by the Referring key.
Referred ReferencedKey
}
// GenericMultiIndex defines a generic Index type that given a primary key
// and the value associated with that primary key returns one or multiple IndexReference.
//
// The referencing key can be anything, usually it is either a part of the primary
// key when we deal with multipart keys, or a field of Value.
//
// The referenced key usually is the primary key, or it can be a part
// of the primary key in the context of multipart keys.
//
// The Referencing and Referenced keys are joined and saved as a Pair in a KeySet
// where the key is Pair[ReferencingKey, ReferencedKey].
// So if we wanted to get all the keys referenced by a generic (concrete) ReferencingKey
// we would just need to iterate over all the keys starting with bytes(ReferencingKey).
//
// Unless you're trying to build your generic multi index, you should be using the indexes package.
type GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct {
refs KeySet[Pair[ReferencingKey, ReferencedKey]]
getRefs func(pk PrimaryKey, v Value) ([]IndexReference[ReferencingKey, ReferencedKey], error)
}
// NewGenericMultiIndex instantiates a GenericMultiIndex, given
// schema, Prefix, humanised name, the key codec used to encode the referencing key
// to bytes, the key codec used to encode the referenced key to bytes and a function
// which given the primary key and a value of an object being saved or removed in IndexedMap
// returns all the possible IndexReference of that object.
//
// The IndexReference is usually just one. But in certain cases can be multiple,
// for example when the Value has an array field, and we want to create a relationship
// between the object and all the elements of the array contained in the object.
func NewGenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any](
schema *SchemaBuilder,
prefix Prefix,
name string,
referencingKeyCodec KeyCodec[ReferencingKey],
referencedKeyCodec KeyCodec[ReferencedKey],
getRefsFunc func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error),
) *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] {
return &GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{
getRefs: getRefsFunc,
refs: NewKeySet(schema, prefix, name, PairKeyCodec(referencingKeyCodec, referencedKeyCodec)),
}
}
// Iterate allows to iterate over the index. It returns a KeySetIterator of Pair[ReferencingKey, ReferencedKey].
// K1 of the Pair is the key (referencing) pointing to K2 (referenced).
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate(
ctx context.Context,
ranger Ranger[Pair[ReferencingKey, ReferencedKey]],
) (KeySetIterator[Pair[ReferencingKey, ReferencedKey]], error) {
return i.refs.Iterate(ctx, ranger)
}
// Has reports if there is a relationship in the index between the referencing and the referenced key.
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Has(
ctx context.Context,
referencing ReferencingKey,
referenced ReferencedKey,
) (bool, error) {
return i.refs.Has(ctx, Join(referencing, referenced))
}
// Reference implements the Index interface.
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference(
ctx context.Context,
pk PrimaryKey,
value Value,
oldValue *Value,
) error {
if oldValue != nil {
err := i.Unreference(ctx, pk, *oldValue)
if err != nil {
return err
}
}
refKeys, err := i.getRefs(pk, value)
if err != nil {
return err
}
for _, ref := range refKeys {
err := i.refs.Set(ctx, Join(ref.Referring, ref.Referred))
if err != nil {
return err
}
}
return nil
}
// Unreference implements the Index interface.
func (i *GenericMultiIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference(
ctx context.Context,
pk PrimaryKey,
value Value,
) error {
refs, err := i.getRefs(pk, value)
if err != nil {
return err
}
for _, ref := range refs {
err = i.refs.Remove(ctx, Join(ref.Referring, ref.Referred))
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,83 @@
package collections
import (
"testing"
"github.com/stretchr/testify/require"
)
type coin struct {
denom string // this will be used as indexing field.
amount uint64
}
type balance struct {
coins []coin
}
func TestGenericMultiIndex(t *testing.T) {
// we are simulating a context in which we have the following mapping:
//
// address (represented as string) => balance (slice of coins).
//
// we want to create an index that creates a relationship between the coin
// denom, which is part of the balance structure, and the address. This means
// we know given a denom who are the addresses holding that denom.
// From GenericMultiIndex point of view, the denom field of the array becomes
// the referencing key which points to the address (string), which is the key
// being referenced.
sk, ctx := deps()
sb := NewSchemaBuilder(sk)
mi := NewGenericMultiIndex(
sb, NewPrefix("denoms"), "denom_to_owner", StringKey, StringKey,
func(pk string, value balance) ([]IndexReference[string, string], error) {
// the referencing keys are all the denoms.
refs := make([]IndexReference[string, string], len(value.coins))
// the index reference being created, generates a relationship
// between denom (the key that references) and pk (address, the key
// that is being referenced).
for i, coin := range value.coins {
refs[i] = NewIndexReference(coin.denom, pk)
}
return refs, nil
},
)
// let's create the relationships
err := mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{
{"atom", 1000}, {"osmo", 5000},
}}, nil)
require.NoError(t, err)
// we must find relations between cosmosaddr1 and the denom atom and osmo
iter, err := mi.Iterate(ctx, nil)
require.NoError(t, err)
keys, err := iter.Keys()
require.NoError(t, err)
require.Len(t, keys, 2)
require.Equal(t, keys[0].K1(), "atom") // assert relationship with atom created
require.Equal(t, keys[1].K1(), "osmo") // assert relationship with osmo created
// if we update the reference to remove osmo as balance then we must not find it anymore
err = mi.Reference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}}, // this is the update which does not have osmo
&balance{coins: []coin{{"atom", 1000}, {"osmo", 5000}}}, // this is the previous record
)
require.NoError(t, err)
exists, err := mi.Has(ctx, "osmo", "cosmosAddr1") // osmo must not exist anymore
require.NoError(t, err)
require.False(t, exists)
exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom still exists
require.NoError(t, err)
require.True(t, exists)
// if we unreference then no relationship is maintained anymore
err = mi.Unreference(ctx, "cosmosAddr1", balance{coins: []coin{{"atom", 1000}}})
require.NoError(t, err)
exists, err = mi.Has(ctx, "atom", "cosmosAddr1") // atom is not part of the index anymore because cosmosAddr1 was removed.
require.NoError(t, err)
require.False(t, exists)
}

View File

@ -0,0 +1,145 @@
package collections
import (
"context"
"fmt"
)
// GenericUniqueIndex defines a generic index which enforces uniqueness constraints
// between ReferencingKey and ReferencedKey, meaning that one referencing key maps
// only one referenced key. The same referenced key can be mapped by multiple referencing keys.
//
// The referencing key can be anything, usually it is either a part of the primary
// key when we deal with multipart keys, or a field of Value.
//
// The referenced key usually is the primary key, or it can be a part
// of the primary key in the context of multipart keys.
//
// The referencing and referenced keys are mapped together using a Map.
//
// Unless you're trying to build your generic unique index, you should be using the indexes package.
type GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any] struct {
refs Map[ReferencingKey, ReferencedKey]
getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error)
}
// NewGenericUniqueIndex instantiates a GenericUniqueIndex. Works in the same way as NewGenericMultiIndex.
func NewGenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value any](
schema *SchemaBuilder,
prefix Prefix,
name string,
referencingKeyCodec KeyCodec[ReferencingKey],
referencedKeyCodec KeyCodec[ReferencedKey],
getRefs func(pk PrimaryKey, value Value) ([]IndexReference[ReferencingKey, ReferencedKey], error),
) *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value] {
return &GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]{
refs: NewMap[ReferencingKey, ReferencedKey](schema, prefix, name, referencingKeyCodec, keyToValueCodec[ReferencedKey]{kc: referencedKeyCodec}),
getRefs: getRefs,
}
}
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Iterate(
ctx context.Context,
ranger Ranger[ReferencingKey],
) (Iterator[ReferencingKey, ReferencedKey], error) {
return i.refs.Iterate(ctx, ranger)
}
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Get(ctx context.Context, ref ReferencingKey) (ReferencedKey, error) {
return i.refs.Get(ctx, ref)
}
// Reference implements Index.
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Reference(
ctx context.Context,
pk PrimaryKey,
newValue Value,
oldValue *Value,
) error {
if oldValue != nil {
err := i.Unreference(ctx, pk, *oldValue)
if err != nil {
return err
}
}
refs, err := i.getRefs(pk, newValue)
if err != nil {
return err
}
for _, ref := range refs {
has, err := i.refs.Has(ctx, ref.Referring)
if err != nil {
return err
}
if has {
return fmt.Errorf("%w: index uniqueness constrain violation: %s", ErrConflict, i.refs.kc.Stringify(ref.Referring))
}
err = i.refs.Set(ctx, ref.Referring, ref.Referred)
if err != nil {
return err
}
}
return nil
}
// Unreference implements Index.
func (i *GenericUniqueIndex[ReferencingKey, ReferencedKey, PrimaryKey, Value]) Unreference(
ctx context.Context,
pk PrimaryKey,
value Value,
) error {
refs, err := i.getRefs(pk, value)
if err != nil {
return err
}
for _, ref := range refs {
err = i.refs.Remove(ctx, ref.Referring)
if err != nil {
return err
}
}
return nil
}
// keyToValueCodec is a ValueCodec that wraps a KeyCodec to make it behave like a ValueCodec.
type keyToValueCodec[K any] struct {
kc KeyCodec[K]
}
func (k keyToValueCodec[K]) EncodeJSON(value K) ([]byte, error) {
return k.kc.EncodeJSON(value)
}
func (k keyToValueCodec[K]) DecodeJSON(b []byte) (K, error) {
return k.kc.DecodeJSON(b)
}
func (k keyToValueCodec[K]) Encode(value K) ([]byte, error) {
buf := make([]byte, k.kc.Size(value))
_, err := k.kc.Encode(buf, value)
return buf, err
}
func (k keyToValueCodec[K]) Decode(b []byte) (K, error) {
r, key, err := k.kc.Decode(b)
if err != nil {
var key K
return key, err
}
if r != len(b) {
var key K
return key, fmt.Errorf("%w: was supposed to fully consume the key '%x', consumed %d out of %d", ErrEncoding, b, r, len(b))
}
return key, nil
}
func (k keyToValueCodec[K]) Stringify(value K) string {
return k.kc.Stringify(value)
}
func (k keyToValueCodec[K]) ValueType() string {
return k.kc.KeyType()
}

View File

@ -0,0 +1,72 @@
package collections
import (
"testing"
"github.com/stretchr/testify/require"
)
type nftBalance struct {
nftIDs []uint64
}
func TestGenericUniqueIndex(t *testing.T) {
// we create the same testing context as with GenericMultiIndex. We have a mapping:
// Address => NFT balance.
// An NFT balance is represented as a slice of IDs, those IDs are unique, meaning that
// they can be held only by one address.
sk, ctx := deps()
sb := NewSchemaBuilder(sk)
ui := NewGenericUniqueIndex(
sb, NewPrefix("nft_to_owner_index"), "ntf_to_owner_index", Uint64Key, StringKey,
func(pk string, value nftBalance) ([]IndexReference[uint64, string], error) {
// the referencing keys are all the NFT unique ids.
refs := make([]IndexReference[uint64, string], len(value.nftIDs))
// for each NFT contained in the balance we create an index reference
// between the NFT unique ID and the owner of the balance.
for i, id := range value.nftIDs {
refs[i] = NewIndexReference(id, pk)
}
return refs, nil
},
)
// let's create the relationships
err := ui.Reference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{0, 1}}, nil)
require.NoError(t, err)
// assert relations were created
iter, err := ui.Iterate(ctx, nil)
require.NoError(t, err)
defer iter.Close()
kv, err := iter.KeyValues()
require.NoError(t, err)
require.Len(t, kv, 2)
require.Equal(t, kv[0].Key, uint64(0))
require.Equal(t, kv[0].Value, "cosmosAddr1")
require.Equal(t, kv[1].Key, uint64(1))
require.Equal(t, kv[1].Value, "cosmosAddr1")
// assert only one address can own a unique NFT
err = ui.Reference(ctx, "cosmosAddr2", nftBalance{nftIDs: []uint64{0}}, nil) // nft with ID 0 is already owned by cosmosAddr1
require.ErrorIs(t, err, ErrConflict)
// during modifications references are updated, we update the index in
// such a way that cosmosAddr1 loses ownership of nft with id 0.
err = ui.Reference(ctx, "cosmosAddr1",
nftBalance{nftIDs: []uint64{1}}, // this is the update nft balance, which contains only id 1
&nftBalance{nftIDs: []uint64{0, 1}}, // this is the old nft balance, which contains both 0 and 1
)
require.NoError(t, err)
// the updated balance does not contain nft with id 0
_, err = ui.Get(ctx, 0)
require.ErrorIs(t, err, ErrNotFound)
// unreferencing clears all the indexes
err = ui.Unreference(ctx, "cosmosAddr1", nftBalance{nftIDs: []uint64{1}})
require.NoError(t, err)
_, err = ui.Get(ctx, 1)
require.ErrorIs(t, err, ErrNotFound)
}

View File

@ -23,14 +23,6 @@ func NewItem[V any](
return item
}
func (i Item[V]) getName() string {
return i.name
}
func (i Item[V]) getPrefix() []byte {
return i.prefix
}
// Get gets the item, if it is not set it returns an ErrNotFound error.
// If value decoding fails then an ErrEncoding is returned.
func (i Item[V]) Get(ctx context.Context) (V, error) {

View File

@ -121,7 +121,6 @@ func (r *Range[K]) Descending() *Range[K] {
// test sentinel error
var (
errRange = errors.New("collections: range error")
errOrder = errors.New("collections: invalid order")
)
@ -161,6 +160,7 @@ func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K])
} else {
endBytes = nextBytesPrefixKey(m.prefix)
}
return newIterator(ctx, startBytes, endBytes, order, m)
}

View File

@ -14,5 +14,4 @@ func TestUint64Key(t *testing.T) {
}
func TestStringKey(t *testing.T) {
}

View File

@ -70,9 +70,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
// Get returns the value associated with the provided key,
// errors with ErrNotFound if the key does not exist, or
// with ErrEncoding if the key or value decoding fails.
func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
var v V
func (m Map[K, V]) Get(ctx context.Context, key K) (v V, err error) {
bytesKey, err := encodeKeyWithPrefix(m.prefix, m.kc, key)
if err != nil {
return v, err
@ -80,13 +78,12 @@ func (m Map[K, V]) Get(ctx context.Context, key K) (V, error) {
kvStore := m.sa(ctx)
valueBytes, err := kvStore.Get(bytesKey)
if valueBytes == nil {
return v, fmt.Errorf("%w: key '%s' of type %s", ErrNotFound, m.kc.Stringify(key), m.vc.ValueType())
}
if err != nil {
return v, err
}
if valueBytes == nil {
return v, fmt.Errorf("%w: key '%s' of type %s", ErrNotFound, m.kc.Stringify(key), m.vc.ValueType())
}
v, err = m.vc.Decode(valueBytes)
if err != nil {
@ -115,8 +112,7 @@ func (m Map[K, V]) Remove(ctx context.Context, key K) error {
return err
}
kvStore := m.sa(ctx)
kvStore.Delete(bytesKey)
return nil
return kvStore.Delete(bytesKey)
}
// Iterate provides an Iterator over K and V. It accepts a Ranger interface.

View File

@ -1,8 +1,9 @@
package collections
import (
"github.com/stretchr/testify/require"
"testing"
"github.com/stretchr/testify/require"
)
func TestPair(t *testing.T) {
@ -63,5 +64,6 @@ func TestPairRange(t *testing.T) {
iter, err = m.Iterate(ctx, NewPrefixedPairRange[string, uint64]("A").Descending().StartExclusive(0).EndInclusive(2))
require.NoError(t, err)
keys, err = iter.Keys()
require.NoError(t, err)
require.Equal(t, []Pair[string, uint64]{Join("A", uint64(2)), Join("A", uint64(1))}, keys)
}

View File

@ -151,26 +151,6 @@ func NewSchemaFromAccessor(accessor func(context.Context) store.KVStore) Schema
}
}
func (s Schema) addCollection(collection collection) {
prefix := collection.getPrefix()
name := collection.getName()
if _, ok := s.collectionsByPrefix[string(prefix)]; ok {
panic(fmt.Errorf("prefix %v already taken within schema", prefix))
}
if _, ok := s.collectionsByName[name]; ok {
panic(fmt.Errorf("name %s already taken within schema", name))
}
if !nameRegex.MatchString(name) {
panic(fmt.Errorf("name must match regex %s, got %s", NameRegex, name))
}
s.collectionsByPrefix[string(prefix)] = collection
s.collectionsByName[name] = collection
}
// DefaultGenesis implements the appmodule.HasGenesis.DefaultGenesis method.
func (s Schema) DefaultGenesis(target appmodule.GenesisTarget) error {
for _, name := range s.collectionsOrdered {

View File

@ -7,7 +7,6 @@ import (
)
func TestUint64Value(t *testing.T) {
t.Run("invalid size", func(t *testing.T) {
_, err := Uint64Value.Decode([]byte{0x1, 0x2})
require.ErrorIs(t, err, ErrEncoding)