diff --git a/orm/internal/codegen/index.go b/orm/internal/codegen/index.go index 8a76c7c973..a7024c34f2 100644 --- a/orm/internal/codegen/index.go +++ b/orm/internal/codegen/index.go @@ -20,9 +20,9 @@ func (t tableGen) genIndexKeys() { // start with primary key.. t.P("// primary key starting index..") - t.genIndex(t.table.PrimaryKey.Fields, t.ormTable.ID()) + t.genIndex(t.table.PrimaryKey.Fields, 0, true) for _, idx := range t.table.Index { - t.genIndex(idx.Fields, idx.Id) + t.genIndex(idx.Fields, idx.Id, false) } } @@ -94,9 +94,15 @@ func (t tableGen) indexStructName(fields []string) string { return t.msg.GoIdent.GoName + joinedNames + "IndexKey" } -func (t tableGen) genIndex(fields string, id uint32) { +func (t tableGen) genIndex(fields string, id uint32, isPrimaryKey bool) { fieldsSlc := strings.Split(fields, ",") idxKeyName := t.indexStructName(fieldsSlc) + + if isPrimaryKey { + t.P("type ", t.msg.GoIdent.GoName, "PrimaryKey = ", idxKeyName) + t.P() + } + t.P("type ", idxKeyName, " struct {") t.P("vs []interface{}") t.P("}") diff --git a/orm/internal/codegen/singleton.go b/orm/internal/codegen/singleton.go index c7ca53a36f..ddf739ed63 100644 --- a/orm/internal/codegen/singleton.go +++ b/orm/internal/codegen/singleton.go @@ -60,12 +60,9 @@ func (s singletonGen) genMethods() { varName := s.param(s.msg.GoIdent.GoName) // Get s.P(receiver, "Get(ctx ", contextPkg.Ident("Context"), ") (*", s.msg.GoIdent.GoName, ", error) {") - s.P("var ", varName, " ", s.msg.GoIdent.GoName) - s.P("found, err := x.table.Get(ctx, &", varName, ")") - s.P("if !found {") - s.P("return nil, err") - s.P("}") - s.P("return &", varName, ", err") + s.P(varName, " := &", s.msg.GoIdent.GoName, "{}") + s.P("_, err := x.table.Get(ctx, ", varName, ")") + s.P("return ", varName, ", err") s.P("}") s.P() diff --git a/orm/internal/codegen/table.go b/orm/internal/codegen/table.go index 5dd8d4725d..cd6e96a8b3 100644 --- a/orm/internal/codegen/table.go +++ b/orm/internal/codegen/table.go @@ -4,8 +4,6 @@ import ( "fmt" "strings" - "github.com/iancoleman/strcase" - "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" @@ -189,11 +187,12 @@ func (t tableGen) genStoreImpl() { // has t.P("func (", receiverVar, " ", t.messageStoreReceiverName(t.msg), ") ", hasName, "{") - t.P("return ", receiverVar, ".table.Has(ctx, &", t.msg.GoIdent.GoName, "{") + t.P("return ", receiverVar, ".table.GetIndexByID(", idx.Id, ").(", + tablePkg.Ident("UniqueIndex"), ").Has(ctx,") for _, field := range fields { - t.P(strcase.ToCamel(field), ": ", field, ",") + t.P(field, ",") } - t.P("})") + t.P(")") t.P("}") t.P() @@ -201,23 +200,24 @@ func (t tableGen) genStoreImpl() { varName := t.param(t.msg.GoIdent.GoName) varTypeName := t.msg.GoIdent.GoName t.P("func (", receiverVar, " ", t.messageStoreReceiverName(t.msg), ") ", getName, "{") - t.P(varName, " := &", varTypeName, "{") + t.P("var ", varName, " ", varTypeName) + t.P("found, err := ", receiverVar, ".table.GetIndexByID(", idx.Id, ").(", + tablePkg.Ident("UniqueIndex"), ").Get(ctx, &", varName, ",") for _, field := range fields { - t.P(strcase.ToCamel(field), ": ", field, ",") + t.P(field, ",") } - t.P("}") - t.P("found, err := ", receiverVar, ".table.Get(ctx, ", varName, ")") + t.P(")") t.P("if !found {") t.P("return nil, err") t.P("}") - t.P("return ", varName, ", nil") + t.P("return &", varName, ", nil") t.P("}") t.P() } // List t.P(receiver, "List(ctx ", contextPkg.Ident("Context"), ", prefixKey ", t.indexKeyInterfaceName(), ", opts ...", ormListPkg.Ident("Option"), ") (", t.iteratorName(), ", error) {") - t.P("opts = append(opts, ", ormListPkg.Ident("Prefix"), "(prefixKey.values()))") + t.P("opts = append(opts, ", ormListPkg.Ident("Prefix"), "(prefixKey.values()...))") t.P("it, err := ", receiverVar, ".table.GetIndexByID(prefixKey.id()).Iterator(ctx, opts...)") t.P("return ", t.iteratorName(), "{it}, err") t.P("}") @@ -225,7 +225,7 @@ func (t tableGen) genStoreImpl() { // ListRange t.P(receiver, "ListRange(ctx ", contextPkg.Ident("Context"), ", from, to ", t.indexKeyInterfaceName(), ", opts ...", ormListPkg.Ident("Option"), ") (", t.iteratorName(), ", error) {") - t.P("opts = append(opts, ", ormListPkg.Ident("Start"), "(from.values()), ", ormListPkg.Ident("End"), "(to))") + t.P("opts = append(opts, ", ormListPkg.Ident("Start"), "(from.values()...), ", ormListPkg.Ident("End"), "(to.values()...))") t.P("it, err := ", receiverVar, ".table.GetIndexByID(from.id()).Iterator(ctx, opts...)") t.P("return ", t.iteratorName(), "{it}, err") t.P("}") diff --git a/orm/internal/testpb/bank.cosmos_orm.go b/orm/internal/testpb/bank.cosmos_orm.go index 578eea326e..e8eafbb42d 100644 --- a/orm/internal/testpb/bank.cosmos_orm.go +++ b/orm/internal/testpb/bank.cosmos_orm.go @@ -4,6 +4,7 @@ package testpb import ( context "context" + ormdb "github.com/cosmos/cosmos-sdk/orm/model/ormdb" ormlist "github.com/cosmos/cosmos-sdk/orm/model/ormlist" ormtable "github.com/cosmos/cosmos-sdk/orm/model/ormtable" @@ -40,11 +41,13 @@ type BalanceIndexKey interface { } // primary key starting index.. +type BalancePrimaryKey = BalanceAddressDenomIndexKey + type BalanceAddressDenomIndexKey struct { vs []interface{} } -func (x BalanceAddressDenomIndexKey) id() uint32 { return 1 } +func (x BalanceAddressDenomIndexKey) id() uint32 { return 0 } func (x BalanceAddressDenomIndexKey) values() []interface{} { return x.vs } func (x BalanceAddressDenomIndexKey) balanceIndexKey() {} @@ -105,13 +108,13 @@ func (this balanceStore) Get(ctx context.Context, address string, denom string) } func (this balanceStore) List(ctx context.Context, prefixKey BalanceIndexKey, opts ...ormlist.Option) (BalanceIterator, error) { - opts = append(opts, ormlist.Prefix(prefixKey.values())) + opts = append(opts, ormlist.Prefix(prefixKey.values()...)) it, err := this.table.GetIndexByID(prefixKey.id()).Iterator(ctx, opts...) return BalanceIterator{it}, err } func (this balanceStore) ListRange(ctx context.Context, from, to BalanceIndexKey, opts ...ormlist.Option) (BalanceIterator, error) { - opts = append(opts, ormlist.Start(from.values()), ormlist.End(to)) + opts = append(opts, ormlist.Start(from.values()...), ormlist.End(to.values()...)) it, err := this.table.GetIndexByID(from.id()).Iterator(ctx, opts...) return BalanceIterator{it}, err } @@ -158,11 +161,13 @@ type SupplyIndexKey interface { } // primary key starting index.. +type SupplyPrimaryKey = SupplyDenomIndexKey + type SupplyDenomIndexKey struct { vs []interface{} } -func (x SupplyDenomIndexKey) id() uint32 { return 2 } +func (x SupplyDenomIndexKey) id() uint32 { return 0 } func (x SupplyDenomIndexKey) values() []interface{} { return x.vs } func (x SupplyDenomIndexKey) supplyIndexKey() {} @@ -205,13 +210,13 @@ func (this supplyStore) Get(ctx context.Context, denom string) (*Supply, error) } func (this supplyStore) List(ctx context.Context, prefixKey SupplyIndexKey, opts ...ormlist.Option) (SupplyIterator, error) { - opts = append(opts, ormlist.Prefix(prefixKey.values())) + opts = append(opts, ormlist.Prefix(prefixKey.values()...)) it, err := this.table.GetIndexByID(prefixKey.id()).Iterator(ctx, opts...) return SupplyIterator{it}, err } func (this supplyStore) ListRange(ctx context.Context, from, to SupplyIndexKey, opts ...ormlist.Option) (SupplyIterator, error) { - opts = append(opts, ormlist.Start(from.values()), ormlist.End(to)) + opts = append(opts, ormlist.Start(from.values()...), ormlist.End(to.values()...)) it, err := this.table.GetIndexByID(from.id()).Iterator(ctx, opts...) return SupplyIterator{it}, err } diff --git a/orm/internal/testpb/bank.pulsar.go b/orm/internal/testpb/bank.pulsar.go index 36cfc62a53..73ac47da7b 100644 --- a/orm/internal/testpb/bank.pulsar.go +++ b/orm/internal/testpb/bank.pulsar.go @@ -1,15 +1,18 @@ +// Code generated by protoc-gen-go-pulsar. DO NOT EDIT. package testpb import ( fmt "fmt" - runtime "github.com/cosmos/cosmos-proto/runtime" - _ "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoiface "google.golang.org/protobuf/runtime/protoiface" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" io "io" reflect "reflect" sync "sync" + + runtime "github.com/cosmos/cosmos-proto/runtime" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoiface "google.golang.org/protobuf/runtime/protoiface" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + + _ "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" ) var ( diff --git a/orm/internal/testpb/test_schema.cosmos_orm.go b/orm/internal/testpb/test_schema.cosmos_orm.go index 4a34cfcee3..573103371d 100644 --- a/orm/internal/testpb/test_schema.cosmos_orm.go +++ b/orm/internal/testpb/test_schema.cosmos_orm.go @@ -4,6 +4,7 @@ package testpb import ( context "context" + ormdb "github.com/cosmos/cosmos-sdk/orm/model/ormdb" ormlist "github.com/cosmos/cosmos-sdk/orm/model/ormlist" ormtable "github.com/cosmos/cosmos-sdk/orm/model/ormtable" @@ -42,11 +43,13 @@ type ExampleTableIndexKey interface { } // primary key starting index.. +type ExampleTablePrimaryKey = ExampleTableU32I64StrIndexKey + type ExampleTableU32I64StrIndexKey struct { vs []interface{} } -func (x ExampleTableU32I64StrIndexKey) id() uint32 { return 1 } +func (x ExampleTableU32I64StrIndexKey) id() uint32 { return 0 } func (x ExampleTableU32I64StrIndexKey) values() []interface{} { return x.vs } func (x ExampleTableU32I64StrIndexKey) exampleTableIndexKey() {} @@ -153,32 +156,32 @@ func (this exampleTableStore) Get(ctx context.Context, u32 uint32, i64 int64, st } func (this exampleTableStore) HasByU64Str(ctx context.Context, u64 uint64, str string) (found bool, err error) { - return this.table.Has(ctx, &ExampleTable{ - U64: u64, - Str: str, - }) + return this.table.GetIndexByID(1).(ormtable.UniqueIndex).Has(ctx, + u64, + str, + ) } func (this exampleTableStore) GetByU64Str(ctx context.Context, u64 uint64, str string) (*ExampleTable, error) { - exampleTable := &ExampleTable{ - U64: u64, - Str: str, - } - found, err := this.table.Get(ctx, exampleTable) + var exampleTable ExampleTable + found, err := this.table.GetIndexByID(1).(ormtable.UniqueIndex).Get(ctx, &exampleTable, + u64, + str, + ) if !found { return nil, err } - return exampleTable, nil + return &exampleTable, nil } func (this exampleTableStore) List(ctx context.Context, prefixKey ExampleTableIndexKey, opts ...ormlist.Option) (ExampleTableIterator, error) { - opts = append(opts, ormlist.Prefix(prefixKey.values())) + opts = append(opts, ormlist.Prefix(prefixKey.values()...)) it, err := this.table.GetIndexByID(prefixKey.id()).Iterator(ctx, opts...) return ExampleTableIterator{it}, err } func (this exampleTableStore) ListRange(ctx context.Context, from, to ExampleTableIndexKey, opts ...ormlist.Option) (ExampleTableIterator, error) { - opts = append(opts, ormlist.Start(from.values()), ormlist.End(to)) + opts = append(opts, ormlist.Start(from.values()...), ormlist.End(to.values()...)) it, err := this.table.GetIndexByID(from.id()).Iterator(ctx, opts...) return ExampleTableIterator{it}, err } @@ -227,11 +230,13 @@ type ExampleAutoIncrementTableIndexKey interface { } // primary key starting index.. +type ExampleAutoIncrementTablePrimaryKey = ExampleAutoIncrementTableIdIndexKey + type ExampleAutoIncrementTableIdIndexKey struct { vs []interface{} } -func (x ExampleAutoIncrementTableIdIndexKey) id() uint32 { return 3 } +func (x ExampleAutoIncrementTableIdIndexKey) id() uint32 { return 0 } func (x ExampleAutoIncrementTableIdIndexKey) values() []interface{} { return x.vs } func (x ExampleAutoIncrementTableIdIndexKey) exampleAutoIncrementTableIndexKey() {} @@ -287,30 +292,30 @@ func (this exampleAutoIncrementTableStore) Get(ctx context.Context, id uint64) ( } func (this exampleAutoIncrementTableStore) HasByX(ctx context.Context, x string) (found bool, err error) { - return this.table.Has(ctx, &ExampleAutoIncrementTable{ - X: x, - }) + return this.table.GetIndexByID(1).(ormtable.UniqueIndex).Has(ctx, + x, + ) } func (this exampleAutoIncrementTableStore) GetByX(ctx context.Context, x string) (*ExampleAutoIncrementTable, error) { - exampleAutoIncrementTable := &ExampleAutoIncrementTable{ - X: x, - } - found, err := this.table.Get(ctx, exampleAutoIncrementTable) + var exampleAutoIncrementTable ExampleAutoIncrementTable + found, err := this.table.GetIndexByID(1).(ormtable.UniqueIndex).Get(ctx, &exampleAutoIncrementTable, + x, + ) if !found { return nil, err } - return exampleAutoIncrementTable, nil + return &exampleAutoIncrementTable, nil } func (this exampleAutoIncrementTableStore) List(ctx context.Context, prefixKey ExampleAutoIncrementTableIndexKey, opts ...ormlist.Option) (ExampleAutoIncrementTableIterator, error) { - opts = append(opts, ormlist.Prefix(prefixKey.values())) + opts = append(opts, ormlist.Prefix(prefixKey.values()...)) it, err := this.table.GetIndexByID(prefixKey.id()).Iterator(ctx, opts...) return ExampleAutoIncrementTableIterator{it}, err } func (this exampleAutoIncrementTableStore) ListRange(ctx context.Context, from, to ExampleAutoIncrementTableIndexKey, opts ...ormlist.Option) (ExampleAutoIncrementTableIterator, error) { - opts = append(opts, ormlist.Start(from.values()), ormlist.End(to)) + opts = append(opts, ormlist.Start(from.values()...), ormlist.End(to.values()...)) it, err := this.table.GetIndexByID(from.id()).Iterator(ctx, opts...) return ExampleAutoIncrementTableIterator{it}, err } @@ -340,12 +345,9 @@ type exampleSingletonStore struct { var _ ExampleSingletonStore = exampleSingletonStore{} func (x exampleSingletonStore) Get(ctx context.Context) (*ExampleSingleton, error) { - var exampleSingleton ExampleSingleton - found, err := x.table.Get(ctx, &exampleSingleton) - if !found { - return nil, err - } - return &exampleSingleton, err + exampleSingleton := &ExampleSingleton{} + _, err := x.table.Get(ctx, exampleSingleton) + return exampleSingleton, err } func (x exampleSingletonStore) Save(ctx context.Context, exampleSingleton *ExampleSingleton) error { diff --git a/orm/internal/testpb/test_schema.pulsar.go b/orm/internal/testpb/test_schema.pulsar.go index 80f6517893..6de676cbec 100644 --- a/orm/internal/testpb/test_schema.pulsar.go +++ b/orm/internal/testpb/test_schema.pulsar.go @@ -1,19 +1,22 @@ +// Code generated by protoc-gen-go-pulsar. DO NOT EDIT. package testpb import ( binary "encoding/binary" fmt "fmt" + io "io" + reflect "reflect" + sort "sort" + sync "sync" + runtime "github.com/cosmos/cosmos-proto/runtime" - _ "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoiface "google.golang.org/protobuf/runtime/protoiface" protoimpl "google.golang.org/protobuf/runtime/protoimpl" durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" - io "io" - reflect "reflect" - sort "sort" - sync "sync" + + _ "github.com/cosmos/cosmos-sdk/api/cosmos/orm/v1alpha1" ) var _ protoreflect.List = (*_ExampleTable_17_list)(nil) diff --git a/orm/model/ormdb/module.go b/orm/model/ormdb/module.go index 5b99d2309f..5111200cdb 100644 --- a/orm/model/ormdb/module.go +++ b/orm/model/ormdb/module.go @@ -31,12 +31,7 @@ type ModuleSchema struct { } // ModuleDB defines the ORM database type to be used by modules. -type ModuleDB interface { - ormkv.EntryCodec - - // GetTable returns the table for the provided message type or nil. - GetTable(message proto.Message) ormtable.Table -} +type ModuleDB = ormtable.Schema type moduleDB struct { prefix []byte diff --git a/orm/model/ormtable/auto_increment.go b/orm/model/ormtable/auto_increment.go index a373edf8bb..5e1b6ab5ba 100644 --- a/orm/model/ormtable/auto_increment.go +++ b/orm/model/ormtable/auto_increment.go @@ -216,3 +216,10 @@ func (t autoIncrementTable) ExportJSON(ctx context.Context, writer io.Writer) er return t.doExportJSON(ctx, writer) } + +func (t *autoIncrementTable) GetTable(message proto.Message) Table { + if message.ProtoReflect().Descriptor().FullName() == t.MessageType().Descriptor().FullName() { + return t + } + return nil +} diff --git a/orm/model/ormtable/auto_increment_test.go b/orm/model/ormtable/auto_increment_test.go index 4c92b7a8a5..c3404b3380 100644 --- a/orm/model/ormtable/auto_increment_test.go +++ b/orm/model/ormtable/auto_increment_test.go @@ -40,11 +40,14 @@ func TestAutoIncrementScenario(t *testing.T) { } func runAutoIncrementScenario(t *testing.T, table ormtable.Table, context context.Context) { - err := table.Save(context, &testpb.ExampleAutoIncrementTable{Id: 5}) + store, err := testpb.NewExampleAutoIncrementTableStore(table) + assert.NilError(t, err) + + err = store.Save(context, &testpb.ExampleAutoIncrementTable{Id: 5}) assert.ErrorContains(t, err, "update") ex1 := &testpb.ExampleAutoIncrementTable{X: "foo", Y: 5} - assert.NilError(t, table.Save(context, ex1)) + assert.NilError(t, store.Save(context, ex1)) assert.Equal(t, uint64(1), ex1.Id) buf := &bytes.Buffer{} diff --git a/orm/model/ormtable/singleton.go b/orm/model/ormtable/singleton.go index c701ada20f..cab54f7b22 100644 --- a/orm/model/ormtable/singleton.go +++ b/orm/model/ormtable/singleton.go @@ -5,6 +5,8 @@ import ( "encoding/json" "io" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/encoding/protojson" ) @@ -91,3 +93,10 @@ func (t singleton) jsonMarshalOptions() protojson.MarshalOptions { Resolver: t.typeResolver, } } + +func (t *singleton) GetTable(message proto.Message) Table { + if message.ProtoReflect().Descriptor().FullName() == t.MessageType().Descriptor().FullName() { + return t + } + return nil +} diff --git a/orm/model/ormtable/singleton_test.go b/orm/model/ormtable/singleton_test.go index 83dc0143ee..3fe6ddeecd 100644 --- a/orm/model/ormtable/singleton_test.go +++ b/orm/model/ormtable/singleton_test.go @@ -15,38 +15,35 @@ import ( ) func TestSingleton(t *testing.T) { - val := &testpb.ExampleSingleton{} - singleton, err := ormtable.Build(ormtable.Options{ - MessageType: val.ProtoReflect().Type(), + table, err := ormtable.Build(ormtable.Options{ + MessageType: (&testpb.ExampleSingleton{}).ProtoReflect().Type(), }) assert.NilError(t, err) - store := ormtable.WrapContextDefault(testkv.NewSplitMemBackend()) + ctx := ormtable.WrapContextDefault(testkv.NewSplitMemBackend()) - found, err := singleton.Has(store, val) + store, err := testpb.NewExampleSingletonStore(table) assert.NilError(t, err) - assert.Assert(t, !found) - assert.NilError(t, singleton.Save(store, val)) - found, err = singleton.Has(store, val) + + val, err := store.Get(ctx) assert.NilError(t, err) - assert.Assert(t, found) + assert.Assert(t, val != nil) // singletons are always set + assert.NilError(t, store.Save(ctx, &testpb.ExampleSingleton{})) val.Foo = "abc" val.Bar = 3 - assert.NilError(t, singleton.Save(store, val)) + assert.NilError(t, store.Save(ctx, val)) - var val2 testpb.ExampleSingleton - found, err = singleton.Get(store, &val2) + val2, err := store.Get(ctx) assert.NilError(t, err) - assert.DeepEqual(t, val, &val2, protocmp.Transform()) + assert.DeepEqual(t, val, val2, protocmp.Transform()) buf := &bytes.Buffer{} - assert.NilError(t, singleton.ExportJSON(store, buf)) - assert.NilError(t, singleton.ValidateJSON(bytes.NewReader(buf.Bytes()))) + assert.NilError(t, table.ExportJSON(ctx, buf)) + assert.NilError(t, table.ValidateJSON(bytes.NewReader(buf.Bytes()))) store2 := ormtable.WrapContextDefault(testkv.NewSplitMemBackend()) - assert.NilError(t, singleton.ImportJSON(store2, bytes.NewReader(buf.Bytes()))) + assert.NilError(t, table.ImportJSON(store2, bytes.NewReader(buf.Bytes()))) - var val3 testpb.ExampleSingleton - found, err = singleton.Get(store, &val3) + val3, err := store.Get(ctx) assert.NilError(t, err) - assert.DeepEqual(t, val, &val3, protocmp.Transform()) + assert.DeepEqual(t, val, val3, protocmp.Transform()) } diff --git a/orm/model/ormtable/table.go b/orm/model/ormtable/table.go index 32e00eb9cf..88625077d8 100644 --- a/orm/model/ormtable/table.go +++ b/orm/model/ormtable/table.go @@ -128,4 +128,15 @@ type Table interface { // ID is the ID of this table within the schema of its FileDescriptor. ID() uint32 + + Schema +} + +// Schema is an interface for things that contain tables and can encode and +// decode kv-store pairs. +type Schema interface { + ormkv.EntryCodec + + // GetTable returns the table for the provided message type or nil. + GetTable(message proto.Message) Table } diff --git a/orm/model/ormtable/table_impl.go b/orm/model/ormtable/table_impl.go index 95c76021c0..0ccda1eb90 100644 --- a/orm/model/ormtable/table_impl.go +++ b/orm/model/ormtable/table_impl.go @@ -32,6 +32,13 @@ type tableImpl struct { customJSONValidator func(message proto.Message) error } +func (t *tableImpl) GetTable(message proto.Message) Table { + if message.ProtoReflect().Descriptor().FullName() == t.MessageType().Descriptor().FullName() { + return t + } + return nil +} + func (t tableImpl) PrimaryKey() UniqueIndex { return t.primaryKeyIndex } @@ -399,6 +406,7 @@ func (t tableImpl) Get(ctx context.Context, message proto.Message) (found bool, } var _ Table = &tableImpl{} +var _ Schema = &tableImpl{} type saveMode int diff --git a/orm/model/ormtable/table_test.go b/orm/model/ormtable/table_test.go index abb18e7a00..02813a1b7d 100644 --- a/orm/model/ormtable/table_test.go +++ b/orm/model/ormtable/table_test.go @@ -80,6 +80,7 @@ func checkEncodeDecodeEntries(t *testing.T, table ormtable.Table, store kv.Reado func runTestScenario(t *testing.T, table ormtable.Table, backend ormtable.Backend) { ctx := ormtable.WrapContextDefault(backend) + store, err := testpb.NewExampleTableStore(table) // let's create 10 data items we'll use later and give them indexes data := []*testpb.ExampleTable{ @@ -110,50 +111,50 @@ func runTestScenario(t *testing.T, table ormtable.Table, backend ormtable.Backen } // insert one record - err := table.Insert(ctx, data[0]) + err = store.Insert(ctx, data[0]) // trivial prefix query has one record - it, err := table.Iterator(ctx) + it, err := store.List(ctx, testpb.ExampleTablePrimaryKey{}) assert.NilError(t, err) assertIteratorItems(it, 0) // insert one record - err = table.Insert(ctx, data[1]) + err = store.Insert(ctx, data[1]) // trivial prefix query has two records - it, err = table.Iterator(ctx) + it, err = store.List(ctx, testpb.ExampleTablePrimaryKey{}) assert.NilError(t, err) assertIteratorItems(it, 0, 1) // insert the other records assert.NilError(t, err) for i := 2; i < len(data); i++ { - err = table.Insert(ctx, data[i]) + err = store.Insert(ctx, data[i]) assert.NilError(t, err) } // let's do a prefix query on the primary key - it, err = table.Iterator(ctx, ormlist.Prefix(uint32(8))) + it, err = store.List(ctx, testpb.ExampleTablePrimaryKey{}.WithU32(8)) assert.NilError(t, err) assertIteratorItems(it, 7, 8, 9) // let's try a reverse prefix query - it, err = table.Iterator(ctx, ormlist.Prefix(uint32(4)), ormlist.Reverse()) + it, err = store.List(ctx, testpb.ExampleTablePrimaryKey{}.WithU32(4), ormlist.Reverse()) assert.NilError(t, err) defer it.Close() assertIteratorItems(it, 2, 1, 0) // let's try a range query - it, err = table.Iterator(ctx, - ormlist.Start(uint32(4), int64(-1)), - ormlist.End(uint32(7)), + it, err = store.ListRange(ctx, + testpb.ExampleTablePrimaryKey{}.WithU32I64(4, -1), + testpb.ExampleTablePrimaryKey{}.WithU32(7), ) assert.NilError(t, err) defer it.Close() assertIteratorItems(it, 2, 3, 4, 5, 6) // and another range query - it, err = table.Iterator(ctx, - ormlist.Start(uint32(5), int64(-3)), - ormlist.End(uint32(8), int64(1), "abc"), + it, err = store.ListRange(ctx, + testpb.ExampleTablePrimaryKey{}.WithU32I64(5, -3), + testpb.ExampleTablePrimaryKey{}.WithU32I64Str(8, 1, "abc"), ) assert.NilError(t, err) defer it.Close() @@ -162,31 +163,33 @@ func runTestScenario(t *testing.T, table ormtable.Table, backend ormtable.Backen // now a reverse range query on a different index strU32Index := table.GetIndex("str,u32") assert.Assert(t, strU32Index != nil) - it, err = strU32Index.Iterator(ctx, - ormlist.Start("abc"), - ormlist.End("abd"), + it, err = store.ListRange(ctx, + testpb.ExampleTableStrU32IndexKey{}.WithStr("abc"), + testpb.ExampleTableStrU32IndexKey{}.WithStr("abd"), ormlist.Reverse(), ) assertIteratorItems(it, 9, 3, 1, 8, 7, 2, 0) // another prefix query forwards - it, err = strU32Index.Iterator(ctx, ormlist.Prefix("abe", uint32(7))) + + it, err = store.List(ctx, + testpb.ExampleTableStrU32IndexKey{}.WithStrU32("abe", 7), + ) assertIteratorItems(it, 5, 6) // and backwards - it, err = strU32Index.Iterator(ctx, ormlist.Prefix("abc", uint32(4)), ormlist.Reverse()) + it, err = store.List(ctx, + testpb.ExampleTableStrU32IndexKey{}.WithStrU32("abc", 4), + ormlist.Reverse(), + ) assertIteratorItems(it, 2, 0) // try an unique index - u64StrIndex := table.GetUniqueIndex("u64,str") - assert.Assert(t, u64StrIndex != nil) - found, err := u64StrIndex.Has(ctx, uint64(12), "abc") + found, err := store.HasByU64Str(ctx, 12, "abc") assert.NilError(t, err) assert.Assert(t, found) - var a testpb.ExampleTable - found, err = u64StrIndex.Get(ctx, &a, uint64(12), "abc") + a, err := store.GetByU64Str(ctx, 12, "abc") assert.NilError(t, err) - assert.Assert(t, found) - assert.DeepEqual(t, data[8], &a, protocmp.Transform()) + assert.DeepEqual(t, data[8], a, protocmp.Transform()) // let's try paginating some stuff @@ -359,32 +362,32 @@ func runTestScenario(t *testing.T, table ormtable.Table, backend ormtable.Backen for i := 0; i < 5; i++ { data[i].U64 = data[i].U64 * 2 data[i].Bz = []byte(data[i].Str) - err = table.Update(ctx, data[i]) + err = store.Update(ctx, data[i]) assert.NilError(t, err) } - it, err = table.Iterator(ctx) + it, err = store.List(ctx, testpb.ExampleTablePrimaryKey{}) assert.NilError(t, err) // we should still get everything in the same order assertIteratorItems(it, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) // let's use SAVE_MODE_DEFAULT and add something data = append(data, &testpb.ExampleTable{U32: 9}) - err = table.Save(ctx, data[10]) + err = store.Save(ctx, data[10]) assert.NilError(t, err) pkIndex := table.GetUniqueIndex("u32,i64,str") - found, err = pkIndex.Get(ctx, &a, uint32(9), int64(0), "") + a, err = store.Get(ctx, 9, 0, "") assert.NilError(t, err) - assert.Assert(t, found) - assert.DeepEqual(t, data[10], &a, protocmp.Transform()) + assert.Assert(t, a != nil) + assert.DeepEqual(t, data[10], a, protocmp.Transform()) // and update it data[10].B = true assert.NilError(t, table.Save(ctx, data[10])) - found, err = pkIndex.Get(ctx, &a, uint32(9), int64(0), "") + a, err = store.Get(ctx, 9, 0, "") assert.NilError(t, err) - assert.Assert(t, found) - assert.DeepEqual(t, data[10], &a, protocmp.Transform()) + assert.Assert(t, a != nil) + assert.DeepEqual(t, data[10], a, protocmp.Transform()) // and iterate - it, err = table.Iterator(ctx) + it, err = store.List(ctx, testpb.ExampleTablePrimaryKey{}) assert.NilError(t, err) assertIteratorItems(it, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) @@ -409,7 +412,7 @@ func runTestScenario(t *testing.T, table ormtable.Table, backend ormtable.Backen assert.NilError(t, err) assert.Assert(t, !found) // and missing from the iterator - it, err = table.Iterator(ctx) + it, err = store.List(ctx, testpb.ExampleTablePrimaryKey{}) assert.NilError(t, err) assertIteratorItems(it, 0, 1, 2, 3, 4, 6, 7, 8, 9, 10) }