diff --git a/schema/decoding/decoding_test.go b/schema/decoding/decoding_test.go new file mode 100644 index 0000000000..988c1b5ea2 --- /dev/null +++ b/schema/decoding/decoding_test.go @@ -0,0 +1,456 @@ +package decoding + +import ( + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "testing" + + "cosmossdk.io/schema" + "cosmossdk.io/schema/appdata" +) + +func TestMiddleware(t *testing.T) { + tl := newTestFixture(t) + listener, err := Middleware(tl.Listener, tl.resolver, MiddlewareOptions{}) + if err != nil { + t.Fatal("unexpected error", err) + } + tl.setListener(listener) + + tl.bankMod.Mint("bob", "foo", 100) + err = tl.bankMod.Send("bob", "alice", "foo", 50) + if err != nil { + t.Fatal("unexpected error", err) + } + + tl.oneMod.SetValue("abc") + + expectedBank := []schema.ObjectUpdate{ + { + TypeName: "supply", + Key: []interface{}{"foo"}, + Value: uint64(100), + }, + { + TypeName: "balances", + Key: []interface{}{"bob", "foo"}, + Value: uint64(100), + }, + { + TypeName: "balances", + Key: []interface{}{"bob", "foo"}, + Value: uint64(50), + }, + { + TypeName: "balances", + Key: []interface{}{"alice", "foo"}, + Value: uint64(50), + }, + } + + if !reflect.DeepEqual(tl.bankUpdates, expectedBank) { + t.Fatalf("expected %v, got %v", expectedBank, tl.bankUpdates) + } + + expectedOne := []schema.ObjectUpdate{ + {TypeName: "item", Value: "abc"}, + } + + if !reflect.DeepEqual(tl.oneValueUpdates, expectedOne) { + t.Fatalf("expected %v, got %v", expectedOne, tl.oneValueUpdates) + } +} + +func TestMiddleware_filtered(t *testing.T) { + tl := newTestFixture(t) + listener, err := Middleware(tl.Listener, tl.resolver, MiddlewareOptions{ + ModuleFilter: func(moduleName string) bool { + return moduleName == "one" + }, + }) + if err != nil { + t.Fatal("unexpected error", err) + } + tl.setListener(listener) + + tl.bankMod.Mint("bob", "foo", 100) + tl.oneMod.SetValue("abc") + + if len(tl.bankUpdates) != 0 { + t.Fatalf("expected no bank updates") + } + + expectedOne := []schema.ObjectUpdate{ + {TypeName: "item", Value: "abc"}, + } + + if !reflect.DeepEqual(tl.oneValueUpdates, expectedOne) { + t.Fatalf("expected %v, got %v", expectedOne, tl.oneValueUpdates) + } +} + +func TestSync(t *testing.T) { + tl := newTestFixture(t) + tl.bankMod.Mint("bob", "foo", 100) + err := tl.bankMod.Send("bob", "alice", "foo", 50) + if err != nil { + t.Fatal("unexpected error", err) + } + + tl.oneMod.SetValue("def") + + err = Sync(tl.Listener, tl.multiStore, tl.resolver, SyncOptions{}) + if err != nil { + t.Fatal("unexpected error", err) + } + + expected := []schema.ObjectUpdate{ + { + TypeName: "balances", + Key: []interface{}{"alice", "foo"}, + Value: uint64(50), + }, + { + TypeName: "balances", + Key: []interface{}{"bob", "foo"}, + Value: uint64(50), + }, + { + TypeName: "supply", + Key: []interface{}{"foo"}, + Value: uint64(100), + }, + } + + if !reflect.DeepEqual(tl.bankUpdates, expected) { + t.Fatalf("expected %v, got %v", expected, tl.bankUpdates) + } + + expectedOne := []schema.ObjectUpdate{ + {TypeName: "item", Value: "def"}, + } + + if !reflect.DeepEqual(tl.oneValueUpdates, expectedOne) { + t.Fatalf("expected %v, got %v", expectedOne, tl.oneValueUpdates) + } +} + +func TestSync_filtered(t *testing.T) { + tl := newTestFixture(t) + tl.bankMod.Mint("bob", "foo", 100) + tl.oneMod.SetValue("def") + + err := Sync(tl.Listener, tl.multiStore, tl.resolver, SyncOptions{ + ModuleFilter: func(moduleName string) bool { + return moduleName == "one" + }, + }) + if err != nil { + t.Fatal("unexpected error", err) + } + + if len(tl.bankUpdates) != 0 { + t.Fatalf("expected no bank updates") + } + + expectedOne := []schema.ObjectUpdate{ + {TypeName: "item", Value: "def"}, + } + + if !reflect.DeepEqual(tl.oneValueUpdates, expectedOne) { + t.Fatalf("expected %v, got %v", expectedOne, tl.oneValueUpdates) + } +} + +type testFixture struct { + appdata.Listener + bankUpdates []schema.ObjectUpdate + oneValueUpdates []schema.ObjectUpdate + resolver DecoderResolver + multiStore *testMultiStore + bankMod *exampleBankModule + oneMod *oneValueModule +} + +func newTestFixture(t *testing.T) *testFixture { + res := &testFixture{} + res.Listener = appdata.Listener{ + InitializeModuleData: func(data appdata.ModuleInitializationData) error { + var expected schema.ModuleSchema + switch data.ModuleName { + case "bank": + expected = exampleBankSchema + case "one": + + expected = oneValueModSchema + default: + t.Fatalf("unexpected module %s", data.ModuleName) + } + + if !reflect.DeepEqual(data.Schema, expected) { + t.Errorf("expected %v, got %v", expected, data.Schema) + } + return nil + }, + OnObjectUpdate: func(data appdata.ObjectUpdateData) error { + switch data.ModuleName { + case "bank": + res.bankUpdates = append(res.bankUpdates, data.Updates...) + case "one": + res.oneValueUpdates = append(res.oneValueUpdates, data.Updates...) + default: + t.Errorf("unexpected module %s", data.ModuleName) + } + return nil + }, + } + res.multiStore = newTestMultiStore() + res.bankMod = &exampleBankModule{ + store: res.multiStore.newTestStore(t, "bank"), + } + res.oneMod = &oneValueModule{ + store: res.multiStore.newTestStore(t, "one"), + } + modSet := map[string]interface{}{ + "bank": res.bankMod, + "one": res.oneMod, + } + res.resolver = ModuleSetDecoderResolver(modSet) + return res +} + +func (f *testFixture) setListener(listener appdata.Listener) { + f.bankMod.store.listener = listener + f.oneMod.store.listener = listener +} + +type testMultiStore struct { + stores map[string]*testStore +} + +type testStore struct { + t *testing.T + modName string + store map[string][]byte + listener appdata.Listener +} + +func newTestMultiStore() *testMultiStore { + return &testMultiStore{ + stores: map[string]*testStore{}, + } +} + +var _ SyncSource = &testMultiStore{} + +func (ms *testMultiStore) IterateAllKVPairs(moduleName string, fn func(key []byte, value []byte) error) error { + s, ok := ms.stores[moduleName] + if !ok { + return fmt.Errorf("don't have state for module %s", moduleName) + } + + var keys []string + for key := range s.store { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + err := fn([]byte(key), s.store[key]) + if err != nil { + return err + } + } + return nil +} + +func (ms *testMultiStore) newTestStore(t *testing.T, modName string) *testStore { + s := &testStore{ + t: t, + modName: modName, + store: map[string][]byte{}, + } + ms.stores[modName] = s + return s +} + +func (t testStore) Get(key []byte) []byte { + return t.store[string(key)] +} + +func (t testStore) GetUInt64(key []byte) uint64 { + bz := t.store[string(key)] + if len(bz) == 0 { + return 0 + } + x, err := strconv.ParseUint(string(bz), 10, 64) + if err != nil { + t.t.Fatalf("unexpected error: %v", err) + } + return x +} + +func (t testStore) Set(key, value []byte) { + if t.listener.OnKVPair != nil { + err := t.listener.OnKVPair(appdata.KVPairData{Updates: []appdata.ModuleKVPairUpdate{ + { + ModuleName: t.modName, + Update: schema.KVPairUpdate{ + Key: key, + Value: value, + }, + }, + }}) + if err != nil { + t.t.Fatalf("unexpected error: %v", err) + } + } + t.store[string(key)] = value +} + +func (t testStore) SetUInt64(key []byte, value uint64) { + t.Set(key, []byte(strconv.FormatUint(value, 10))) +} + +type exampleBankModule struct { + store *testStore +} + +func (e exampleBankModule) Mint(acct, denom string, amount uint64) { + key := supplyKey(denom) + e.store.SetUInt64(key, e.store.GetUInt64(key)+amount) + e.addBalance(acct, denom, amount) +} + +func (e exampleBankModule) Send(from, to, denom string, amount uint64) error { + err := e.subBalance(from, denom, amount) + if err != nil { + return nil + } + e.addBalance(to, denom, amount) + return nil +} + +func (e exampleBankModule) GetBalance(acct, denom string) uint64 { + return e.store.GetUInt64(balanceKey(acct, denom)) +} + +func (e exampleBankModule) GetSupply(denom string) uint64 { + return e.store.GetUInt64(supplyKey(denom)) +} + +func balanceKey(acct, denom string) []byte { + return []byte(fmt.Sprintf("balance/%s/%s", acct, denom)) +} + +func supplyKey(denom string) []byte { + return []byte(fmt.Sprintf("supply/%s", denom)) +} + +func (e exampleBankModule) addBalance(acct, denom string, amount uint64) { + key := balanceKey(acct, denom) + e.store.SetUInt64(key, e.store.GetUInt64(key)+amount) +} + +func (e exampleBankModule) subBalance(acct, denom string, amount uint64) error { + key := balanceKey(acct, denom) + cur := e.store.GetUInt64(key) + if cur < amount { + return fmt.Errorf("insufficient balance") + } + e.store.SetUInt64(key, cur-amount) + return nil +} + +var exampleBankSchema = schema.ModuleSchema{ + ObjectTypes: []schema.ObjectType{ + { + Name: "balances", + KeyFields: []schema.Field{ + { + Name: "account", + Kind: schema.StringKind, + }, + { + Name: "denom", + Kind: schema.StringKind, + }, + }, + ValueFields: []schema.Field{ + { + Name: "amount", + Kind: schema.Uint64Kind, + }, + }, + }, + }, +} + +func (e exampleBankModule) ModuleCodec() (schema.ModuleCodec, error) { + return schema.ModuleCodec{ + Schema: exampleBankSchema, + KVDecoder: func(update schema.KVPairUpdate) ([]schema.ObjectUpdate, error) { + key := string(update.Key) + value, err := strconv.ParseUint(string(update.Value), 10, 64) + if err != nil { + return nil, err + } + if strings.HasPrefix(key, "balance/") { + parts := strings.Split(key, "/") + return []schema.ObjectUpdate{{ + TypeName: "balances", + Key: []interface{}{parts[1], parts[2]}, + Value: value, + }}, nil + } else if strings.HasPrefix(key, "supply/") { + parts := strings.Split(key, "/") + return []schema.ObjectUpdate{{ + TypeName: "supply", + Key: []interface{}{parts[1]}, + Value: value, + }}, nil + } else { + return nil, fmt.Errorf("unexpected key: %s", key) + } + }, + }, nil +} + +var _ schema.HasModuleCodec = exampleBankModule{} + +type oneValueModule struct { + store *testStore +} + +var oneValueModSchema = schema.ModuleSchema{ + ObjectTypes: []schema.ObjectType{ + { + Name: "item", + ValueFields: []schema.Field{ + {Name: "value", Kind: schema.StringKind}, + }, + }, + }, +} + +func (i oneValueModule) ModuleCodec() (schema.ModuleCodec, error) { + return schema.ModuleCodec{ + Schema: oneValueModSchema, + KVDecoder: func(update schema.KVPairUpdate) ([]schema.ObjectUpdate, error) { + if string(update.Key) != "key" { + return nil, fmt.Errorf("unexpected key: %v", update.Key) + } + return []schema.ObjectUpdate{ + {TypeName: "item", Value: string(update.Value)}, + }, nil + }, + }, nil +} + +func (i oneValueModule) SetValue(x string) { + i.store.Set([]byte("key"), []byte(x)) +} + +var _ schema.HasModuleCodec = oneValueModule{} diff --git a/schema/decoding/middleware.go b/schema/decoding/middleware.go new file mode 100644 index 0000000000..57c0783c62 --- /dev/null +++ b/schema/decoding/middleware.go @@ -0,0 +1,106 @@ +package decoding + +import ( + "cosmossdk.io/schema" + "cosmossdk.io/schema/appdata" +) + +type MiddlewareOptions struct { + ModuleFilter func(moduleName string) bool +} + +// Middleware decodes raw data passed to the listener as kv-updates into decoded object updates. Module initialization +// is done lazily as modules are encountered in the kv-update stream. +func Middleware(target appdata.Listener, resolver DecoderResolver, opts MiddlewareOptions) (appdata.Listener, error) { + initializeModuleData := target.InitializeModuleData + onObjectUpdate := target.OnObjectUpdate + + // no-op if not listening to decoded data + if initializeModuleData == nil && onObjectUpdate == nil { + return target, nil + } + + onKVPair := target.OnKVPair + + moduleCodecs := map[string]*schema.ModuleCodec{} + + target.OnKVPair = func(data appdata.KVPairData) error { + // first forward kv pair updates + if onKVPair != nil { + err := onKVPair(data) + if err != nil { + return err + } + } + + for _, kvUpdate := range data.Updates { + // look for an existing codec + pcdc, ok := moduleCodecs[kvUpdate.ModuleName] + if !ok { + if opts.ModuleFilter != nil && !opts.ModuleFilter(kvUpdate.ModuleName) { + // we don't care about this module so store nil and continue + moduleCodecs[kvUpdate.ModuleName] = nil + continue + } + + // look for a new codec + cdc, found, err := resolver.LookupDecoder(kvUpdate.ModuleName) + if err != nil { + return err + } + + if !found { + // store nil to indicate we've seen this module and don't have a codec + // and keep processing the kv updates + moduleCodecs[kvUpdate.ModuleName] = nil + continue + } + + pcdc = &cdc + moduleCodecs[kvUpdate.ModuleName] = pcdc + + if initializeModuleData != nil { + err = initializeModuleData(appdata.ModuleInitializationData{ + ModuleName: kvUpdate.ModuleName, + Schema: cdc.Schema, + }) + if err != nil { + return err + } + } + } + + if pcdc == nil { + // we've already seen this module and can't decode + continue + } + + if onObjectUpdate == nil || pcdc.KVDecoder == nil { + // not listening to updates or can't decode so continue + continue + } + + updates, err := pcdc.KVDecoder(kvUpdate.Update) + if err != nil { + return err + } + + if len(updates) == 0 { + // no updates + continue + } + + err = target.OnObjectUpdate(appdata.ObjectUpdateData{ + ModuleName: kvUpdate.ModuleName, + Updates: updates, + }) + if err != nil { + return err + } + } + + return nil + } + + return target, nil +} diff --git a/schema/decoding/resolver_test.go b/schema/decoding/resolver_test.go index ecea614d19..188de96af4 100644 --- a/schema/decoding/resolver_test.go +++ b/schema/decoding/resolver_test.go @@ -31,14 +31,15 @@ var moduleSet = map[string]interface{}{ "modC": modC{}, } -var resolver = ModuleSetDecoderResolver(moduleSet) +var testResolver = ModuleSetDecoderResolver(moduleSet) func TestModuleSetDecoderResolver_IterateAll(t *testing.T) { objectTypes := map[string]bool{} - err := resolver.IterateAll(func(moduleName string, cdc schema.ModuleCodec) error { + err := testResolver.IterateAll(func(moduleName string, cdc schema.ModuleCodec) error { objectTypes[cdc.Schema.ObjectTypes[0].Name] = true return nil }) + if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -57,7 +58,7 @@ func TestModuleSetDecoderResolver_IterateAll(t *testing.T) { } func TestModuleSetDecoderResolver_LookupDecoder(t *testing.T) { - decoder, found, err := resolver.LookupDecoder("modA") + decoder, found, err := testResolver.LookupDecoder("modA") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -70,7 +71,7 @@ func TestModuleSetDecoderResolver_LookupDecoder(t *testing.T) { t.Fatalf("expected object type A, got %s", decoder.Schema.ObjectTypes[0].Name) } - decoder, found, err = resolver.LookupDecoder("modB") + decoder, found, err = testResolver.LookupDecoder("modB") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -83,7 +84,7 @@ func TestModuleSetDecoderResolver_LookupDecoder(t *testing.T) { t.Fatalf("expected object type B, got %s", decoder.Schema.ObjectTypes[0].Name) } - decoder, found, err = resolver.LookupDecoder("modC") + decoder, found, err = testResolver.LookupDecoder("modC") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -92,7 +93,7 @@ func TestModuleSetDecoderResolver_LookupDecoder(t *testing.T) { t.Fatalf("expected not to find decoder") } - decoder, found, err = resolver.LookupDecoder("modD") + decoder, found, err = testResolver.LookupDecoder("modD") if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/schema/decoding/sync.go b/schema/decoding/sync.go index 85e6b4d74b..19582cfa37 100644 --- a/schema/decoding/sync.go +++ b/schema/decoding/sync.go @@ -1,8 +1,64 @@ package decoding +import ( + "cosmossdk.io/schema" + "cosmossdk.io/schema/appdata" +) + // SyncSource is an interface that allows indexers to start indexing modules with pre-existing state. // It should generally be a wrapper around the key-value store. type SyncSource interface { + // IterateAllKVPairs iterates over all key-value pairs for a given module. IterateAllKVPairs(moduleName string, fn func(key, value []byte) error) error } + +// SyncOptions are the options for Sync. +type SyncOptions struct { + ModuleFilter func(moduleName string) bool +} + +// Sync synchronizes existing state from the sync source to the listener using the resolver to decode data. +func Sync(listener appdata.Listener, source SyncSource, resolver DecoderResolver, opts SyncOptions) error { + initializeModuleData := listener.InitializeModuleData + onObjectUpdate := listener.OnObjectUpdate + + // no-op if not listening to decoded data + if initializeModuleData == nil && onObjectUpdate == nil { + return nil + } + + return resolver.IterateAll(func(moduleName string, cdc schema.ModuleCodec) error { + if opts.ModuleFilter != nil && !opts.ModuleFilter(moduleName) { + // ignore this module + return nil + } + + if initializeModuleData != nil { + err := initializeModuleData(appdata.ModuleInitializationData{ + ModuleName: moduleName, + Schema: cdc.Schema, + }) + if err != nil { + return err + } + } + + if onObjectUpdate == nil || cdc.KVDecoder == nil { + return nil + } + + return source.IterateAllKVPairs(moduleName, func(key, value []byte) error { + updates, err := cdc.KVDecoder(schema.KVPairUpdate{Key: key, Value: value}) + if err != nil { + return err + } + + if len(updates) == 0 { + return nil + } + + return onObjectUpdate(appdata.ObjectUpdateData{ModuleName: moduleName, Updates: updates}) + }) + }) +}