diff --git a/orm/encoding/ormkv/codec.go b/orm/encoding/ormkv/codec.go new file mode 100644 index 0000000000..4ec62c891f --- /dev/null +++ b/orm/encoding/ormkv/codec.go @@ -0,0 +1,28 @@ +package ormkv + +import "google.golang.org/protobuf/reflect/protoreflect" + +// EntryCodec defines an interfaces for decoding and encoding entries in the +// kv-store backing an ORM instance. EntryCodec's enable full logical decoding +// of ORM data. +type EntryCodec interface { + + // DecodeEntry decodes a kv-pair into an Entry. + DecodeEntry(k, v []byte) (Entry, error) + + // EncodeEntry encodes an entry into a kv-pair. + EncodeEntry(entry Entry) (k, v []byte, err error) +} + +// IndexCodec defines an interfaces for encoding and decoding index-keys in the +// kv-store. +type IndexCodec interface { + EntryCodec + + // DecodeIndexKey decodes a kv-pair into index-fields and primary-key field + // values. These fields may or may not overlap depending on the index. + DecodeIndexKey(k, v []byte) (indexFields, primaryKey []protoreflect.Value, err error) + + // EncodeKVFromMessage encodes a kv-pair for the index from a message. + EncodeKVFromMessage(message protoreflect.Message) (k, v []byte, err error) +} diff --git a/orm/encoding/ormkv/entry.go b/orm/encoding/ormkv/entry.go new file mode 100644 index 0000000000..c1cd9793ed --- /dev/null +++ b/orm/encoding/ormkv/entry.go @@ -0,0 +1,152 @@ +package ormkv + +import ( + "fmt" + "strings" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/structpb" +) + +// Entry defines a logical representation of a kv-store entry for ORM instances. +type Entry interface { + fmt.Stringer + + // GetTableName returns the table-name (equivalent to the fully-qualified + // proto message name) this entry corresponds to. + GetTableName() protoreflect.FullName + + // to allow new methods to be added without breakage, this interface + // shouldn't be implemented outside this package, + // see https://go.dev/blog/module-compatibility + doNotImplement() +} + +// PrimaryKeyEntry represents a logically decoded primary-key entry. +type PrimaryKeyEntry struct { + + // TableName is the table this entry represents. + TableName protoreflect.FullName + + // Key represents the primary key values. + Key []protoreflect.Value + + // Value represents the message stored under the primary key. + Value proto.Message +} + +func (p *PrimaryKeyEntry) GetTableName() protoreflect.FullName { + return p.TableName +} + +func (p *PrimaryKeyEntry) String() string { + msg := p.Value + msgStr := "_" + if msg != nil { + msgBz, err := protojson.Marshal(msg) + if err == nil { + msgStr = string(msgBz) + } else { + msgStr = fmt.Sprintf("ERR:%v", err) + } + } + return fmt.Sprintf("PK:%s/%s:%s", p.TableName, fmtValues(p.Key), msgStr) +} + +func fmtValues(values []protoreflect.Value) string { + if len(values) == 0 { + return "_" + } + + parts := make([]string, len(values)) + for i, v := range values { + val, err := structpb.NewValue(v.Interface()) + if err != nil { + parts[i] = "ERR" + continue + } + + bz, err := protojson.Marshal(val) + if err != nil { + parts[i] = "ERR" + continue + } + + parts[i] = string(bz) + } + + return strings.Join(parts, "/") +} + +func (p *PrimaryKeyEntry) doNotImplement() {} + +// IndexKeyEntry represents a logically decoded index entry. +type IndexKeyEntry struct { + + // TableName is the table this entry represents. + TableName protoreflect.FullName + + // Fields are the index fields this entry represents. + Fields []protoreflect.Name + + // IsUnique indicates whether this index is unique or not. + IsUnique bool + + // IndexValues represent the index values. + IndexValues []protoreflect.Value + + // PrimaryKey represents the primary key values, it is empty if this is a + // prefix key + PrimaryKey []protoreflect.Value +} + +func (i *IndexKeyEntry) GetTableName() protoreflect.FullName { + return i.TableName +} + +func (i *IndexKeyEntry) doNotImplement() {} + +func (i *IndexKeyEntry) string() string { + return fmt.Sprintf("%s/%s:%s:%s", i.TableName, fmtFields(i.Fields), fmtValues(i.IndexValues), fmtValues(i.PrimaryKey)) +} + +func fmtFields(fields []protoreflect.Name) string { + strs := make([]string, len(fields)) + for i, field := range fields { + strs[i] = string(field) + } + return strings.Join(strs, "/") +} + +func (i *IndexKeyEntry) String() string { + if i.IsUnique { + return fmt.Sprintf("UNIQ:%s", i.string()) + } else { + + return fmt.Sprintf("IDX:%s", i.string()) + } +} + +// SeqEntry represents a sequence for tables with auto-incrementing primary keys. +type SeqEntry struct { + + // TableName is the table this entry represents. + TableName protoreflect.FullName + + // Value is the uint64 value stored for this sequence. + Value uint64 +} + +func (s *SeqEntry) GetTableName() protoreflect.FullName { + return s.TableName +} + +func (s *SeqEntry) doNotImplement() {} + +func (s *SeqEntry) String() string { + return fmt.Sprintf("SEQ:%s:%d", s.TableName, s.Value) +} + +var _, _, _ Entry = &PrimaryKeyEntry{}, &IndexKeyEntry{}, &SeqEntry{} diff --git a/orm/encoding/ormkv/entry_test.go b/orm/encoding/ormkv/entry_test.go new file mode 100644 index 0000000000..24df6082f1 --- /dev/null +++ b/orm/encoding/ormkv/entry_test.go @@ -0,0 +1,77 @@ +package ormkv_test + +import ( + "testing" + + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/cosmos/cosmos-sdk/orm/encoding/ormkv" + + "gotest.tools/v3/assert" + + "github.com/cosmos/cosmos-sdk/orm/internal/testpb" + "github.com/cosmos/cosmos-sdk/orm/internal/testutil" +) + +var aFullName = (&testpb.A{}).ProtoReflect().Descriptor().FullName() + +func TestPrimaryKeyEntry(t *testing.T) { + entry := &ormkv.PrimaryKeyEntry{ + TableName: aFullName, + Key: testutil.ValuesOf(uint32(1), "abc"), + Value: &testpb.A{I32: -1}, + } + assert.Equal(t, `PK:testpb.A/1/"abc":{"i32":-1}`, entry.String()) + assert.Equal(t, aFullName, entry.GetTableName()) + + // prefix key + entry = &ormkv.PrimaryKeyEntry{ + TableName: aFullName, + Key: testutil.ValuesOf(uint32(1), "abc"), + Value: nil, + } + assert.Equal(t, `PK:testpb.A/1/"abc":_`, entry.String()) + assert.Equal(t, aFullName, entry.GetTableName()) +} + +func TestIndexKeyEntry(t *testing.T) { + entry := &ormkv.IndexKeyEntry{ + TableName: aFullName, + Fields: []protoreflect.Name{"u32", "i32", "str"}, + IsUnique: false, + IndexValues: testutil.ValuesOf(uint32(10), int32(-1), "abc"), + PrimaryKey: testutil.ValuesOf("abc", int32(-1)), + } + assert.Equal(t, `IDX:testpb.A/u32/i32/str:10/-1/"abc":"abc"/-1`, entry.String()) + assert.Equal(t, aFullName, entry.GetTableName()) + + entry = &ormkv.IndexKeyEntry{ + TableName: aFullName, + Fields: []protoreflect.Name{"u32"}, + IsUnique: true, + IndexValues: testutil.ValuesOf(uint32(10)), + PrimaryKey: testutil.ValuesOf("abc", int32(-1)), + } + assert.Equal(t, `UNIQ:testpb.A/u32:10:"abc"/-1`, entry.String()) + assert.Equal(t, aFullName, entry.GetTableName()) + + // prefix key + entry = &ormkv.IndexKeyEntry{ + TableName: aFullName, + Fields: []protoreflect.Name{"u32", "i32", "str"}, + IsUnique: false, + IndexValues: testutil.ValuesOf(uint32(10), int32(-1)), + } + assert.Equal(t, `IDX:testpb.A/u32/i32/str:10/-1:_`, entry.String()) + assert.Equal(t, aFullName, entry.GetTableName()) + + // prefix key + entry = &ormkv.IndexKeyEntry{ + TableName: aFullName, + Fields: []protoreflect.Name{"str", "i32"}, + IsUnique: true, + IndexValues: testutil.ValuesOf("abc", int32(1)), + } + assert.Equal(t, `UNIQ:testpb.A/str/i32:"abc"/1:_`, entry.String()) + assert.Equal(t, aFullName, entry.GetTableName()) +} diff --git a/orm/encoding/ormkv/index_key.go b/orm/encoding/ormkv/index_key.go new file mode 100644 index 0000000000..7f2ae36b6b --- /dev/null +++ b/orm/encoding/ormkv/index_key.go @@ -0,0 +1,120 @@ +package ormkv + +import ( + "bytes" + "io" + + "github.com/cosmos/cosmos-sdk/orm/types/ormerrors" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +// IndexKeyCodec is the codec for (non-unique) index keys. +type IndexKeyCodec struct { + *KeyCodec + tableName protoreflect.FullName + pkFieldOrder []int +} + +var _ IndexCodec = &IndexKeyCodec{} + +// NewIndexKeyCodec creates a new IndexKeyCodec with an optional prefix for the +// provided message descriptor, index and primary key fields. +func NewIndexKeyCodec(prefix []byte, messageDescriptor protoreflect.MessageDescriptor, indexFields, primaryKeyFields []protoreflect.Name) (*IndexKeyCodec, error) { + indexFieldMap := map[protoreflect.Name]int{} + + keyFields := make([]protoreflect.Name, 0, len(indexFields)+len(primaryKeyFields)) + for i, f := range indexFields { + indexFieldMap[f] = i + keyFields = append(keyFields, f) + } + + numIndexFields := len(indexFields) + numPrimaryKeyFields := len(primaryKeyFields) + pkFieldOrder := make([]int, numPrimaryKeyFields) + k := 0 + for j, f := range primaryKeyFields { + if i, ok := indexFieldMap[f]; ok { + pkFieldOrder[j] = i + continue + } + keyFields = append(keyFields, f) + pkFieldOrder[j] = numIndexFields + k + k++ + } + + cdc, err := NewKeyCodec(prefix, messageDescriptor, keyFields) + if err != nil { + return nil, err + } + + return &IndexKeyCodec{ + KeyCodec: cdc, + pkFieldOrder: pkFieldOrder, + tableName: messageDescriptor.FullName(), + }, nil +} + +func (cdc IndexKeyCodec) DecodeIndexKey(k, _ []byte) (indexFields, primaryKey []protoreflect.Value, err error) { + + values, err := cdc.Decode(bytes.NewReader(k)) + // got prefix key + if err == io.EOF { + return values, nil, nil + } else if err != nil { + return nil, nil, err + } + + // got prefix key + if len(values) < len(cdc.fieldCodecs) { + return values, nil, nil + } + + numPkFields := len(cdc.pkFieldOrder) + pkValues := make([]protoreflect.Value, numPkFields) + + for i := 0; i < numPkFields; i++ { + pkValues[i] = values[cdc.pkFieldOrder[i]] + } + + return values, pkValues, nil +} + +func (cdc IndexKeyCodec) DecodeEntry(k, v []byte) (Entry, error) { + idxValues, pk, err := cdc.DecodeIndexKey(k, v) + if err != nil { + return nil, err + } + + return &IndexKeyEntry{ + TableName: cdc.tableName, + Fields: cdc.fieldNames, + IndexValues: idxValues, + PrimaryKey: pk, + }, nil +} + +func (cdc IndexKeyCodec) EncodeEntry(entry Entry) (k, v []byte, err error) { + indexEntry, ok := entry.(*IndexKeyEntry) + if !ok { + return nil, nil, ormerrors.BadDecodeEntry + } + + if indexEntry.TableName != cdc.tableName { + return nil, nil, ormerrors.BadDecodeEntry + } + + bz, err := cdc.KeyCodec.Encode(indexEntry.IndexValues) + if err != nil { + return nil, nil, err + } + + return bz, sentinel, nil +} + +var sentinel = []byte{0} + +func (cdc IndexKeyCodec) EncodeKVFromMessage(message protoreflect.Message) (k, v []byte, err error) { + _, k, err = cdc.EncodeFromMessage(message) + return k, sentinel, err +} diff --git a/orm/encoding/ormkv/index_key_test.go b/orm/encoding/ormkv/index_key_test.go new file mode 100644 index 0000000000..afc019768b --- /dev/null +++ b/orm/encoding/ormkv/index_key_test.go @@ -0,0 +1,63 @@ +package ormkv_test + +import ( + "bytes" + "fmt" + "testing" + + "gotest.tools/v3/assert" + "pgregory.net/rapid" + + "github.com/cosmos/cosmos-sdk/orm/encoding/ormkv" + "github.com/cosmos/cosmos-sdk/orm/internal/testpb" + "github.com/cosmos/cosmos-sdk/orm/internal/testutil" +) + +func TestIndexKeyCodec(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + idxPartCdc := testutil.TestKeyCodecGen(1, 5).Draw(t, "idxPartCdc").(testutil.TestKeyCodec) + pkCodec := testutil.TestKeyCodecGen(1, 5).Draw(t, "pkCdc").(testutil.TestKeyCodec) + prefix := rapid.SliceOfN(rapid.Byte(), 0, 5).Draw(t, "prefix").([]byte) + desc := (&testpb.A{}).ProtoReflect().Descriptor() + indexKeyCdc, err := ormkv.NewIndexKeyCodec( + prefix, + desc, + idxPartCdc.Codec.GetFieldNames(), + pkCodec.Codec.GetFieldNames(), + ) + assert.NilError(t, err) + for i := 0; i < 100; i++ { + a := testutil.GenA.Draw(t, fmt.Sprintf("a%d", i)).(*testpb.A) + key := indexKeyCdc.GetValues(a.ProtoReflect()) + pk := pkCodec.Codec.GetValues(a.ProtoReflect()) + idx1 := &ormkv.IndexKeyEntry{ + TableName: desc.FullName(), + Fields: indexKeyCdc.GetFieldNames(), + IsUnique: false, + IndexValues: key, + PrimaryKey: pk, + } + k, v, err := indexKeyCdc.EncodeEntry(idx1) + assert.NilError(t, err) + + k2, v2, err := indexKeyCdc.EncodeKVFromMessage(a.ProtoReflect()) + assert.NilError(t, err) + assert.Assert(t, bytes.Equal(k, k2)) + assert.Assert(t, bytes.Equal(v, v2)) + + entry2, err := indexKeyCdc.DecodeEntry(k, v) + assert.NilError(t, err) + idx2 := entry2.(*ormkv.IndexKeyEntry) + assert.Equal(t, 0, indexKeyCdc.CompareValues(idx1.IndexValues, idx2.IndexValues)) + assert.Equal(t, 0, pkCodec.Codec.CompareValues(idx1.PrimaryKey, idx2.PrimaryKey)) + assert.Equal(t, false, idx2.IsUnique) + assert.Equal(t, desc.FullName(), idx2.TableName) + assert.DeepEqual(t, idx1.Fields, idx2.Fields) + + idxFields, pk2, err := indexKeyCdc.DecodeIndexKey(k, v) + assert.NilError(t, err) + assert.Equal(t, 0, indexKeyCdc.CompareValues(key, idxFields)) + assert.Equal(t, 0, pkCodec.Codec.CompareValues(pk, pk2)) + } + }) +} diff --git a/orm/encoding/ormkv/key_codec.go b/orm/encoding/ormkv/key_codec.go index 00e5374884..61481e20c4 100644 --- a/orm/encoding/ormkv/key_codec.go +++ b/orm/encoding/ormkv/key_codec.go @@ -20,23 +20,26 @@ type KeyCodec struct { prefix []byte fieldDescriptors []protoreflect.FieldDescriptor + fieldNames []protoreflect.Name fieldCodecs []ormfield.Codec } -// NewKeyCodec returns a new KeyCodec with the provided prefix and -// codecs for the provided fields. -func NewKeyCodec(prefix []byte, fieldDescriptors []protoreflect.FieldDescriptor) (*KeyCodec, error) { - n := len(fieldDescriptors) - var fieldCodecs []ormfield.Codec +// NewKeyCodec returns a new KeyCodec with an optional prefix for the provided +// message descriptor and fields. +func NewKeyCodec(prefix []byte, messageDescriptor protoreflect.MessageDescriptor, fieldNames []protoreflect.Name) (*KeyCodec, error) { + n := len(fieldNames) + fieldCodecs := make([]ormfield.Codec, n) + fieldDescriptors := make([]protoreflect.FieldDescriptor, n) var variableSizers []struct { cdc ormfield.Codec i int } fixedSize := 0 - names := make([]protoreflect.Name, len(fieldDescriptors)) + messageFields := messageDescriptor.Fields() + for i := 0; i < n; i++ { nonTerminal := i != n-1 - field := fieldDescriptors[i] + field := messageFields.ByName(fieldNames[i]) cdc, err := ormfield.GetCodec(field, nonTerminal) if err != nil { return nil, err @@ -49,13 +52,14 @@ func NewKeyCodec(prefix []byte, fieldDescriptors []protoreflect.FieldDescriptor) i int }{cdc, i}) } - fieldCodecs = append(fieldCodecs, cdc) - names[i] = field.Name() + fieldCodecs[i] = cdc + fieldDescriptors[i] = field } return &KeyCodec{ fieldCodecs: fieldCodecs, fieldDescriptors: fieldDescriptors, + fieldNames: fieldNames, prefix: prefix, fixedSize: fixedSize, variableSizers: variableSizers, @@ -269,3 +273,17 @@ func (cdc KeyCodec) CheckValidRangeIterationKeys(start, end []protoreflect.Value return nil } + +// GetFieldDescriptors returns the field descriptors for this codec. +func (cdc *KeyCodec) GetFieldDescriptors() []protoreflect.FieldDescriptor { + return cdc.fieldDescriptors +} + +// GetFieldNames returns the field names for this codec. +func (cdc *KeyCodec) GetFieldNames() []protoreflect.Name { + return cdc.fieldNames +} + +func (cdc *KeyCodec) Prefix() []byte { + return cdc.prefix +} diff --git a/orm/encoding/ormkv/key_codec_test.go b/orm/encoding/ormkv/key_codec_test.go index f920c97cca..9488110241 100644 --- a/orm/encoding/ormkv/key_codec_test.go +++ b/orm/encoding/ormkv/key_codec_test.go @@ -18,7 +18,7 @@ import ( func TestKeyCodec(t *testing.T) { rapid.Check(t, func(t *rapid.T) { - key := testutil.TestKeyCodecGen.Draw(t, "key").(testutil.TestKeyCodec) + key := testutil.TestKeyCodecGen(0, 5).Draw(t, "key").(testutil.TestKeyCodec) for i := 0; i < 100; i++ { keyValues := key.Draw(t, "values") @@ -45,11 +45,9 @@ func assertEncDecKey(t *rapid.T, key testutil.TestKeyCodec, keyValues []protoref } func TestCompareValues(t *testing.T) { - cdc, err := ormkv.NewKeyCodec(nil, []protoreflect.FieldDescriptor{ - testutil.GetTestField("u32"), - testutil.GetTestField("str"), - testutil.GetTestField("i32"), - }) + cdc, err := ormkv.NewKeyCodec(nil, + (&testpb.A{}).ProtoReflect().Descriptor(), + []protoreflect.Name{"u32", "str", "i32"}) assert.NilError(t, err) tests := []struct { @@ -61,113 +59,113 @@ func TestCompareValues(t *testing.T) { }{ { "eq", - ValuesOf(uint32(0), "abc", int32(-3)), - ValuesOf(uint32(0), "abc", int32(-3)), + testutil.ValuesOf(uint32(0), "abc", int32(-3)), + testutil.ValuesOf(uint32(0), "abc", int32(-3)), 0, false, }, { "eq prefix 0", - ValuesOf(), - ValuesOf(), + testutil.ValuesOf(), + testutil.ValuesOf(), 0, false, }, { "eq prefix 1", - ValuesOf(uint32(0)), - ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(0)), 0, false, }, { "eq prefix 2", - ValuesOf(uint32(0), "abc"), - ValuesOf(uint32(0), "abc"), + testutil.ValuesOf(uint32(0), "abc"), + testutil.ValuesOf(uint32(0), "abc"), 0, false, }, { "lt1", - ValuesOf(uint32(0), "abc", int32(-3)), - ValuesOf(uint32(1), "abc", int32(-3)), + testutil.ValuesOf(uint32(0), "abc", int32(-3)), + testutil.ValuesOf(uint32(1), "abc", int32(-3)), -1, true, }, { "lt2", - ValuesOf(uint32(1), "abb", int32(-3)), - ValuesOf(uint32(1), "abc", int32(-3)), + testutil.ValuesOf(uint32(1), "abb", int32(-3)), + testutil.ValuesOf(uint32(1), "abc", int32(-3)), -1, true, }, { "lt3", - ValuesOf(uint32(1), "abb", int32(-4)), - ValuesOf(uint32(1), "abb", int32(-3)), + testutil.ValuesOf(uint32(1), "abb", int32(-4)), + testutil.ValuesOf(uint32(1), "abb", int32(-3)), -1, true, }, { "less prefix 0", - ValuesOf(), - ValuesOf(uint32(1), "abb", int32(-4)), + testutil.ValuesOf(), + testutil.ValuesOf(uint32(1), "abb", int32(-4)), -1, true, }, { "less prefix 1", - ValuesOf(uint32(1)), - ValuesOf(uint32(1), "abb", int32(-4)), + testutil.ValuesOf(uint32(1)), + testutil.ValuesOf(uint32(1), "abb", int32(-4)), -1, true, }, { "less prefix 2", - ValuesOf(uint32(1), "abb"), - ValuesOf(uint32(1), "abb", int32(-4)), + testutil.ValuesOf(uint32(1), "abb"), + testutil.ValuesOf(uint32(1), "abb", int32(-4)), -1, true, }, { "gt1", - ValuesOf(uint32(2), "abb", int32(-4)), - ValuesOf(uint32(1), "abb", int32(-4)), + testutil.ValuesOf(uint32(2), "abb", int32(-4)), + testutil.ValuesOf(uint32(1), "abb", int32(-4)), 1, false, }, { "gt2", - ValuesOf(uint32(2), "abc", int32(-4)), - ValuesOf(uint32(2), "abb", int32(-4)), + testutil.ValuesOf(uint32(2), "abc", int32(-4)), + testutil.ValuesOf(uint32(2), "abb", int32(-4)), 1, false, }, { "gt3", - ValuesOf(uint32(2), "abc", int32(1)), - ValuesOf(uint32(2), "abc", int32(-3)), + testutil.ValuesOf(uint32(2), "abc", int32(1)), + testutil.ValuesOf(uint32(2), "abc", int32(-3)), 1, false, }, { "gt prefix 0", - ValuesOf(uint32(2), "abc", int32(-3)), - ValuesOf(), + testutil.ValuesOf(uint32(2), "abc", int32(-3)), + testutil.ValuesOf(), 1, true, }, { "gt prefix 1", - ValuesOf(uint32(2), "abc", int32(-3)), - ValuesOf(uint32(2)), + testutil.ValuesOf(uint32(2), "abc", int32(-3)), + testutil.ValuesOf(uint32(2)), 1, true, }, { "gt prefix 2", - ValuesOf(uint32(2), "abc", int32(-3)), - ValuesOf(uint32(2), "abc"), + testutil.ValuesOf(uint32(2), "abc", int32(-3)), + testutil.ValuesOf(uint32(2), "abc"), 1, true, }, @@ -189,22 +187,10 @@ func TestCompareValues(t *testing.T) { } } -func ValuesOf(values ...interface{}) []protoreflect.Value { - n := len(values) - res := make([]protoreflect.Value, n) - for i := 0; i < n; i++ { - res[i] = protoreflect.ValueOf(values[i]) - } - return res -} - func TestDecodePrefixKey(t *testing.T) { - cdc, err := ormkv.NewKeyCodec(nil, []protoreflect.FieldDescriptor{ - testutil.GetTestField("u32"), - testutil.GetTestField("str"), - testutil.GetTestField("bz"), - testutil.GetTestField("i32"), - }) + cdc, err := ormkv.NewKeyCodec(nil, + (&testpb.A{}).ProtoReflect().Descriptor(), + []protoreflect.Name{"u32", "str", "bz", "i32"}) assert.NilError(t, err) tests := []struct { @@ -213,7 +199,7 @@ func TestDecodePrefixKey(t *testing.T) { }{ { "1", - ValuesOf(uint32(5), "abc"), + testutil.ValuesOf(uint32(5), "abc"), }, } for _, test := range tests { @@ -228,12 +214,9 @@ func TestDecodePrefixKey(t *testing.T) { } func TestValidRangeIterationKeys(t *testing.T) { - cdc, err := ormkv.NewKeyCodec(nil, []protoreflect.FieldDescriptor{ - testutil.GetTestField("u32"), - testutil.GetTestField("str"), - testutil.GetTestField("bz"), - testutil.GetTestField("i32"), - }) + cdc, err := ormkv.NewKeyCodec(nil, + (&testpb.A{}).ProtoReflect().Descriptor(), + []protoreflect.Name{"u32", "str", "bz", "i32"}) assert.NilError(t, err) tests := []struct { @@ -244,62 +227,62 @@ func TestValidRangeIterationKeys(t *testing.T) { }{ { "1 eq", - ValuesOf(uint32(0)), - ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(0)), true, }, { "1 lt", - ValuesOf(uint32(0)), - ValuesOf(uint32(1)), + testutil.ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(1)), false, }, { "1 gt", - ValuesOf(uint32(1)), - ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(1)), + testutil.ValuesOf(uint32(0)), true, }, { "1,2 lt", - ValuesOf(uint32(0)), - ValuesOf(uint32(0), "abc"), + testutil.ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(0), "abc"), false, }, { "1,2 gt", - ValuesOf(uint32(0), "abc"), - ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(0), "abc"), + testutil.ValuesOf(uint32(0)), false, }, { "1,2,3", - ValuesOf(uint32(0)), - ValuesOf(uint32(0), "abc", []byte{1, 2}), + testutil.ValuesOf(uint32(0)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}), true, }, { "1,2,3,4 lt", - ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(-1)), - ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(-1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1)), false, }, { "too long", - ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(-1)), - ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1), int32(1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(-1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1), int32(1)), true, }, { "1,2,3,4 eq", - ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1)), - ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(1)), true, }, { "1,2,3,4 bz err", - ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(-1)), - ValuesOf(uint32(0), "abc", []byte{1, 2, 3}, int32(1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2}, int32(-1)), + testutil.ValuesOf(uint32(0), "abc", []byte{1, 2, 3}, int32(1)), true, }, } @@ -316,15 +299,13 @@ func TestValidRangeIterationKeys(t *testing.T) { } func TestGetSet(t *testing.T) { - cdc, err := ormkv.NewKeyCodec(nil, []protoreflect.FieldDescriptor{ - testutil.GetTestField("u32"), - testutil.GetTestField("str"), - testutil.GetTestField("i32"), - }) + cdc, err := ormkv.NewKeyCodec(nil, + (&testpb.A{}).ProtoReflect().Descriptor(), + []protoreflect.Name{"u32", "str", "i32"}) assert.NilError(t, err) var a testpb.A - values := ValuesOf(uint32(4), "abc", int32(1)) + values := testutil.ValuesOf(uint32(4), "abc", int32(1)) cdc.SetValues(a.ProtoReflect(), values) values2 := cdc.GetValues(a.ProtoReflect()) assert.Equal(t, 0, cdc.CompareValues(values, values2)) diff --git a/orm/encoding/ormkv/primary_key.go b/orm/encoding/ormkv/primary_key.go new file mode 100644 index 0000000000..fd37558a7d --- /dev/null +++ b/orm/encoding/ormkv/primary_key.go @@ -0,0 +1,137 @@ +package ormkv + +import ( + "bytes" + "io" + + "github.com/cosmos/cosmos-sdk/orm/types/ormerrors" + + "google.golang.org/protobuf/proto" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +// PrimaryKeyCodec is the codec for primary keys. +type PrimaryKeyCodec struct { + *KeyCodec + msgType protoreflect.MessageType + unmarshalOptions proto.UnmarshalOptions +} + +// NewPrimaryKeyCodec creates a new PrimaryKeyCodec for the provided msg and +// fields, with an optional prefix and unmarshal options. +func NewPrimaryKeyCodec(prefix []byte, msgType protoreflect.MessageType, fieldNames []protoreflect.Name, unmarshalOptions proto.UnmarshalOptions) (*PrimaryKeyCodec, error) { + keyCodec, err := NewKeyCodec(prefix, msgType.Descriptor(), fieldNames) + if err != nil { + return nil, err + } + + return &PrimaryKeyCodec{ + KeyCodec: keyCodec, + msgType: msgType, + unmarshalOptions: unmarshalOptions, + }, nil +} + +var _ IndexCodec = PrimaryKeyCodec{} + +func (p PrimaryKeyCodec) DecodeIndexKey(k, _ []byte) (indexFields, primaryKey []protoreflect.Value, err error) { + indexFields, err = p.Decode(bytes.NewReader(k)) + + // got prefix key + if err == io.EOF { + return indexFields, nil, nil + } else if err != nil { + return nil, nil, err + } + + if len(indexFields) == len(p.fieldCodecs) { + // for primary keys the index fields are the primary key + // but only if we don't have a prefix key + primaryKey = indexFields + } + return indexFields, primaryKey, nil + +} + +func (p PrimaryKeyCodec) DecodeEntry(k, v []byte) (Entry, error) { + values, err := p.Decode(bytes.NewReader(k)) + if err != nil { + return nil, err + } + + msg := p.msgType.New().Interface() + err = p.Unmarshal(values, v, msg) + + return &PrimaryKeyEntry{ + TableName: p.msgType.Descriptor().FullName(), + Key: values, + Value: msg, + }, err +} + +func (p PrimaryKeyCodec) EncodeEntry(entry Entry) (k, v []byte, err error) { + pkEntry, ok := entry.(*PrimaryKeyEntry) + if !ok { + return nil, nil, ormerrors.BadDecodeEntry.Wrapf("expected %T, got %T", &PrimaryKeyEntry{}, entry) + } + + if pkEntry.TableName != p.msgType.Descriptor().FullName() { + return nil, nil, ormerrors.BadDecodeEntry.Wrapf( + "wrong table name, got %s, expected %s", + pkEntry.TableName, + p.msgType.Descriptor().FullName(), + ) + } + + k, err = p.KeyCodec.Encode(pkEntry.Key) + if err != nil { + return nil, nil, err + } + + v, err = p.marshal(pkEntry.Key, pkEntry.Value) + return k, v, err +} + +func (p PrimaryKeyCodec) marshal(key []protoreflect.Value, message proto.Message) (v []byte, err error) { + // first clear the priamry key values because these are already stored in + // the key so we don't need to store them again in the value + p.ClearValues(message.ProtoReflect()) + + v, err = proto.MarshalOptions{Deterministic: true}.Marshal(message) + if err != nil { + return nil, err + } + + // set the primary key values again returning the message to its original state + p.SetValues(message.ProtoReflect(), key) + + return v, nil +} + +func (p *PrimaryKeyCodec) ClearValues(message protoreflect.Message) { + for _, f := range p.fieldDescriptors { + message.Clear(f) + } +} + +func (p *PrimaryKeyCodec) Unmarshal(key []protoreflect.Value, value []byte, message proto.Message) error { + err := p.unmarshalOptions.Unmarshal(value, message) + if err != nil { + return err + } + + // rehydrate primary key + p.SetValues(message.ProtoReflect(), key) + return nil +} + +func (p PrimaryKeyCodec) EncodeKVFromMessage(message protoreflect.Message) (k, v []byte, err error) { + ks, k, err := p.KeyCodec.EncodeFromMessage(message) + if err != nil { + return nil, nil, err + } + + v, err = p.marshal(ks, message.Interface()) + return k, v, err +} diff --git a/orm/encoding/ormkv/primary_key_test.go b/orm/encoding/ormkv/primary_key_test.go new file mode 100644 index 0000000000..787c799e1a --- /dev/null +++ b/orm/encoding/ormkv/primary_key_test.go @@ -0,0 +1,60 @@ +package ormkv_test + +import ( + "bytes" + "fmt" + "testing" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" + "gotest.tools/v3/assert" + "pgregory.net/rapid" + + "github.com/cosmos/cosmos-sdk/orm/encoding/ormkv" + "github.com/cosmos/cosmos-sdk/orm/internal/testpb" + "github.com/cosmos/cosmos-sdk/orm/internal/testutil" +) + +func TestPrimaryKeyCodec(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + keyCodec := testutil.TestKeyCodecGen(0, 5).Draw(t, "keyCodec").(testutil.TestKeyCodec) + pkCodec, err := ormkv.NewPrimaryKeyCodec( + keyCodec.Codec.Prefix(), + (&testpb.A{}).ProtoReflect().Type(), + keyCodec.Codec.GetFieldNames(), + proto.UnmarshalOptions{}, + ) + assert.NilError(t, err) + for i := 0; i < 100; i++ { + a := testutil.GenA.Draw(t, fmt.Sprintf("a%d", i)).(*testpb.A) + key := keyCodec.Codec.GetValues(a.ProtoReflect()) + pk1 := &ormkv.PrimaryKeyEntry{ + TableName: aFullName, + Key: key, + Value: a, + } + k, v, err := pkCodec.EncodeEntry(pk1) + assert.NilError(t, err) + + k2, v2, err := pkCodec.EncodeKVFromMessage(a.ProtoReflect()) + assert.NilError(t, err) + assert.Assert(t, bytes.Equal(k, k2)) + assert.Assert(t, bytes.Equal(v, v2)) + + entry2, err := pkCodec.DecodeEntry(k, v) + assert.NilError(t, err) + pk2 := entry2.(*ormkv.PrimaryKeyEntry) + assert.Equal(t, 0, pkCodec.CompareValues(pk1.Key, pk2.Key)) + assert.DeepEqual(t, pk1.Value, pk2.Value, protocmp.Transform()) + + idxFields, pk3, err := pkCodec.DecodeIndexKey(k, v) + assert.NilError(t, err) + assert.Equal(t, 0, pkCodec.CompareValues(pk1.Key, pk3)) + assert.Equal(t, 0, pkCodec.CompareValues(pk1.Key, idxFields)) + + pkCodec.ClearValues(a.ProtoReflect()) + pkCodec.SetValues(a.ProtoReflect(), pk1.Key) + assert.DeepEqual(t, a, pk2.Value, protocmp.Transform()) + } + }) +} diff --git a/orm/encoding/ormkv/seq.go b/orm/encoding/ormkv/seq.go new file mode 100644 index 0000000000..def96bc2ad --- /dev/null +++ b/orm/encoding/ormkv/seq.go @@ -0,0 +1,69 @@ +package ormkv + +import ( + "bytes" + "encoding/binary" + + "github.com/cosmos/cosmos-sdk/orm/types/ormerrors" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +// SeqCodec is the codec for auto-incrementing uint64 primary key sequences. +type SeqCodec struct { + tableName protoreflect.FullName + prefix []byte +} + +// NewSeqCodec creates a new SeqCodec. +func NewSeqCodec(tableName protoreflect.FullName, prefix []byte) *SeqCodec { + return &SeqCodec{tableName: tableName, prefix: prefix} +} + +var _ EntryCodec = &SeqCodec{} + +func (s SeqCodec) DecodeEntry(k, v []byte) (Entry, error) { + if !bytes.Equal(k, s.prefix) { + return nil, ormerrors.UnexpectedDecodePrefix + } + + x, err := s.DecodeValue(v) + if err != nil { + return nil, err + } + + return &SeqEntry{ + TableName: s.tableName, + Value: x, + }, nil +} + +func (s SeqCodec) EncodeEntry(entry Entry) (k, v []byte, err error) { + seqEntry, ok := entry.(*SeqEntry) + if !ok { + return nil, nil, ormerrors.BadDecodeEntry + } + + if seqEntry.TableName != s.tableName { + return nil, nil, ormerrors.BadDecodeEntry + } + + return s.prefix, s.EncodeValue(seqEntry.Value), nil +} + +func (s SeqCodec) Prefix() []byte { + return s.prefix +} + +func (s SeqCodec) EncodeValue(seq uint64) (v []byte) { + bz := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(bz, seq) + return bz[:n] +} + +func (s SeqCodec) DecodeValue(v []byte) (uint64, error) { + if len(v) == 0 { + return 0, nil + } + return binary.ReadUvarint(bytes.NewReader(v)) +} diff --git a/orm/encoding/ormkv/seq_test.go b/orm/encoding/ormkv/seq_test.go new file mode 100644 index 0000000000..836d39ae77 --- /dev/null +++ b/orm/encoding/ormkv/seq_test.go @@ -0,0 +1,47 @@ +package ormkv_test + +import ( + "bytes" + "testing" + + "github.com/cosmos/cosmos-sdk/orm/encoding/ormkv" + + "gotest.tools/v3/assert" + "pgregory.net/rapid" + + "github.com/cosmos/cosmos-sdk/orm/internal/testpb" +) + +func TestSeqCodec(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + prefix := rapid.SliceOfN(rapid.Byte(), 0, 5).Draw(t, "prefix").([]byte) + tableName := (&testpb.A{}).ProtoReflect().Descriptor().FullName() + cdc := ormkv.NewSeqCodec(tableName, prefix) + + seq, err := cdc.DecodeValue(nil) + assert.NilError(t, err) + assert.Equal(t, uint64(0), seq) + + seq, err = cdc.DecodeValue([]byte{}) + assert.NilError(t, err) + assert.Equal(t, uint64(0), seq) + + seq = rapid.Uint64().Draw(t, "seq").(uint64) + + v := cdc.EncodeValue(seq) + seq2, err := cdc.DecodeValue(v) + assert.NilError(t, err) + assert.Equal(t, seq, seq2) + + entry := &ormkv.SeqEntry{ + TableName: tableName, + Value: seq, + } + k, v, err := cdc.EncodeEntry(entry) + assert.NilError(t, err) + entry2, err := cdc.DecodeEntry(k, v) + assert.NilError(t, err) + assert.DeepEqual(t, entry, entry2) + assert.Assert(t, bytes.Equal(cdc.Prefix(), k)) + }) +} diff --git a/orm/encoding/ormkv/unique_key.go b/orm/encoding/ormkv/unique_key.go new file mode 100644 index 0000000000..2f5af9a694 --- /dev/null +++ b/orm/encoding/ormkv/unique_key.go @@ -0,0 +1,170 @@ +package ormkv + +import ( + "bytes" + "io" + + "github.com/cosmos/cosmos-sdk/orm/types/ormerrors" + + "google.golang.org/protobuf/reflect/protoreflect" +) + +// UniqueKeyCodec is the codec for unique indexes. +type UniqueKeyCodec struct { + tableName protoreflect.FullName + pkFieldOrder []struct { + inKey bool + i int + } + keyCodec *KeyCodec + valueCodec *KeyCodec +} + +// NewUniqueKeyCodec creates a new UniqueKeyCodec with an optional prefix for the +// provided message descriptor, index and primary key fields. +func NewUniqueKeyCodec(prefix []byte, messageDescriptor protoreflect.MessageDescriptor, indexFields, primaryKeyFields []protoreflect.Name) (*UniqueKeyCodec, error) { + keyCodec, err := NewKeyCodec(prefix, messageDescriptor, indexFields) + if err != nil { + return nil, err + } + + haveFields := map[protoreflect.Name]int{} + for i, descriptor := range keyCodec.fieldDescriptors { + haveFields[descriptor.Name()] = i + } + + var valueFields []protoreflect.Name + var pkFieldOrder []struct { + inKey bool + i int + } + k := 0 + for _, field := range primaryKeyFields { + if j, ok := haveFields[field]; ok { + pkFieldOrder = append(pkFieldOrder, struct { + inKey bool + i int + }{inKey: true, i: j}) + } else { + valueFields = append(valueFields, field) + pkFieldOrder = append(pkFieldOrder, struct { + inKey bool + i int + }{inKey: false, i: k}) + k++ + } + } + + valueCodec, err := NewKeyCodec(nil, messageDescriptor, valueFields) + if err != nil { + return nil, err + } + + return &UniqueKeyCodec{ + tableName: messageDescriptor.FullName(), + pkFieldOrder: pkFieldOrder, + keyCodec: keyCodec, + valueCodec: valueCodec, + }, nil +} + +var _ IndexCodec = &UniqueKeyCodec{} + +func (u UniqueKeyCodec) DecodeIndexKey(k, v []byte) (indexFields, primaryKey []protoreflect.Value, err error) { + ks, err := u.keyCodec.Decode(bytes.NewReader(k)) + + // got prefix key + if err == io.EOF { + return ks, nil, err + } else if err != nil { + return nil, nil, err + } + + // got prefix key + if len(ks) < len(u.keyCodec.fieldCodecs) { + return ks, nil, err + } + + vs, err := u.valueCodec.Decode(bytes.NewReader(v)) + if err != nil { + return nil, nil, err + } + + pk := u.extractPrimaryKey(ks, vs) + return ks, pk, nil +} + +func (cdc UniqueKeyCodec) extractPrimaryKey(keyValues, valueValues []protoreflect.Value) []protoreflect.Value { + numPkFields := len(cdc.pkFieldOrder) + pkValues := make([]protoreflect.Value, numPkFields) + + for i := 0; i < numPkFields; i++ { + fo := cdc.pkFieldOrder[i] + if fo.inKey { + pkValues[i] = keyValues[fo.i] + } else { + pkValues[i] = valueValues[fo.i] + } + } + + return pkValues +} + +func (u UniqueKeyCodec) DecodeEntry(k, v []byte) (Entry, error) { + idxVals, pk, err := u.DecodeIndexKey(k, v) + if err != nil { + return nil, err + } + + return &IndexKeyEntry{ + TableName: u.tableName, + Fields: u.keyCodec.fieldNames, + IsUnique: true, + IndexValues: idxVals, + PrimaryKey: pk, + }, err +} + +func (u UniqueKeyCodec) EncodeEntry(entry Entry) (k, v []byte, err error) { + indexEntry, ok := entry.(*IndexKeyEntry) + if !ok { + return nil, nil, ormerrors.BadDecodeEntry + } + k, err = u.keyCodec.Encode(indexEntry.IndexValues) + if err != nil { + return nil, nil, err + } + + n := len(indexEntry.PrimaryKey) + if n != len(u.pkFieldOrder) { + return nil, nil, ormerrors.BadDecodeEntry.Wrapf("wrong primary key length") + } + + var values []protoreflect.Value + for i := 0; i < n; i++ { + value := indexEntry.PrimaryKey[i] + fieldOrder := u.pkFieldOrder[i] + if !fieldOrder.inKey { + // goes in values because it is not present in the index key otherwise + values = append(values, value) + } else { + // does not go in values, but we need to verify that the value in index values matches the primary key value + if u.keyCodec.fieldCodecs[fieldOrder.i].Compare(value, indexEntry.IndexValues[fieldOrder.i]) != 0 { + return nil, nil, ormerrors.BadDecodeEntry.Wrapf("value in primary key does not match corresponding value in index key") + } + } + } + + v, err = u.valueCodec.Encode(values) + return k, v, err +} + +func (u UniqueKeyCodec) EncodeKVFromMessage(message protoreflect.Message) (k, v []byte, err error) { + _, k, err = u.keyCodec.EncodeFromMessage(message) + if err != nil { + return nil, nil, err + } + + _, v, err = u.valueCodec.EncodeFromMessage(message) + return k, v, err +} diff --git a/orm/encoding/ormkv/unique_key_test.go b/orm/encoding/ormkv/unique_key_test.go new file mode 100644 index 0000000000..3c113c8e08 --- /dev/null +++ b/orm/encoding/ormkv/unique_key_test.go @@ -0,0 +1,62 @@ +package ormkv_test + +import ( + "bytes" + "fmt" + "testing" + + "gotest.tools/v3/assert" + "pgregory.net/rapid" + + "github.com/cosmos/cosmos-sdk/orm/encoding/ormkv" + "github.com/cosmos/cosmos-sdk/orm/internal/testpb" + "github.com/cosmos/cosmos-sdk/orm/internal/testutil" +) + +func TestUniqueKeyCodec(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + keyCodec := testutil.TestKeyCodecGen(1, 5).Draw(t, "keyCodec").(testutil.TestKeyCodec) + pkCodec := testutil.TestKeyCodecGen(1, 5).Draw(t, "primaryKeyCodec").(testutil.TestKeyCodec) + desc := (&testpb.A{}).ProtoReflect().Descriptor() + uniqueKeyCdc, err := ormkv.NewUniqueKeyCodec( + keyCodec.Codec.Prefix(), + desc, + keyCodec.Codec.GetFieldNames(), + pkCodec.Codec.GetFieldNames(), + ) + assert.NilError(t, err) + for i := 0; i < 100; i++ { + a := testutil.GenA.Draw(t, fmt.Sprintf("a%d", i)).(*testpb.A) + key := keyCodec.Codec.GetValues(a.ProtoReflect()) + pk := pkCodec.Codec.GetValues(a.ProtoReflect()) + uniq1 := &ormkv.IndexKeyEntry{ + TableName: desc.FullName(), + Fields: keyCodec.Codec.GetFieldNames(), + IsUnique: true, + IndexValues: key, + PrimaryKey: pk, + } + k, v, err := uniqueKeyCdc.EncodeEntry(uniq1) + assert.NilError(t, err) + + k2, v2, err := uniqueKeyCdc.EncodeKVFromMessage(a.ProtoReflect()) + assert.NilError(t, err) + assert.Assert(t, bytes.Equal(k, k2)) + assert.Assert(t, bytes.Equal(v, v2)) + + entry2, err := uniqueKeyCdc.DecodeEntry(k, v) + assert.NilError(t, err) + uniq2 := entry2.(*ormkv.IndexKeyEntry) + assert.Equal(t, 0, keyCodec.Codec.CompareValues(uniq1.IndexValues, uniq2.IndexValues)) + assert.Equal(t, 0, pkCodec.Codec.CompareValues(uniq1.PrimaryKey, uniq2.PrimaryKey)) + assert.Equal(t, true, uniq2.IsUnique) + assert.Equal(t, desc.FullName(), uniq2.TableName) + assert.DeepEqual(t, uniq1.Fields, uniq2.Fields) + + idxFields, pk2, err := uniqueKeyCdc.DecodeIndexKey(k, v) + assert.NilError(t, err) + assert.Equal(t, 0, keyCodec.Codec.CompareValues(key, idxFields)) + assert.Equal(t, 0, pkCodec.Codec.CompareValues(pk, pk2)) + } + }) +} diff --git a/orm/encoding/ormkv/util.go b/orm/encoding/ormkv/util.go index 499b9f5c47..7d9fee7e27 100644 --- a/orm/encoding/ormkv/util.go +++ b/orm/encoding/ormkv/util.go @@ -10,7 +10,7 @@ func skipPrefix(r *bytes.Reader, prefix []byte) error { if n > 0 { // we skip checking the prefix for performance reasons because we assume // that it was checked by the caller - _, err := r.Seek(int64(n), io.SeekCurrent); + _, err := r.Seek(int64(n), io.SeekCurrent) return err } return nil diff --git a/orm/internal/testutil/testutil.go b/orm/internal/testutil/testutil.go index 103e19b755..04a9100e3e 100644 --- a/orm/internal/testutil/testutil.go +++ b/orm/internal/testutil/testutil.go @@ -120,31 +120,45 @@ type TestKeyCodec struct { Codec *ormkv.KeyCodec } -var TestKeyCodecGen = rapid.Custom(func(t *rapid.T) TestKeyCodec { - xs := rapid.SliceOfNDistinct(rapid.IntRange(0, len(TestFieldSpecs)-1), 0, 5, func(i int) int { return i }). - Draw(t, "fieldSpecs").([]int) +func TestFieldSpecsGen(minLen, maxLen int) *rapid.Generator { + return rapid.Custom(func(t *rapid.T) []TestFieldSpec { + xs := rapid.SliceOfNDistinct(rapid.IntRange(0, len(TestFieldSpecs)-1), minLen, maxLen, func(i int) int { return i }). + Draw(t, "fieldSpecIndexes").([]int) - var specs []TestFieldSpec - var fields []protoreflect.FieldDescriptor + var specs []TestFieldSpec - for _, x := range xs { - spec := TestFieldSpecs[x] - specs = append(specs, spec) - fields = append(fields, GetTestField(spec.FieldName)) - } + for _, x := range xs { + spec := TestFieldSpecs[x] + specs = append(specs, spec) + } - prefix := rapid.SliceOfN(rapid.Byte(), 0, 5).Draw(t, "prefix").([]byte) + return specs + }) +} - cdc, err := ormkv.NewKeyCodec(prefix, fields) - if err != nil { - panic(err) - } +func TestKeyCodecGen(minLen, maxLen int) *rapid.Generator { + return rapid.Custom(func(t *rapid.T) TestKeyCodec { + specs := TestFieldSpecsGen(minLen, maxLen).Draw(t, "fieldSpecs").([]TestFieldSpec) - return TestKeyCodec{ - Codec: cdc, - KeySpecs: specs, - } -}) + var fields []protoreflect.Name + for _, spec := range specs { + fields = append(fields, spec.FieldName) + } + + prefix := rapid.SliceOfN(rapid.Byte(), 0, 5).Draw(t, "prefix").([]byte) + + desc := (&testpb.A{}).ProtoReflect().Descriptor() + cdc, err := ormkv.NewKeyCodec(prefix, desc, fields) + if err != nil { + panic(err) + } + + return TestKeyCodec{ + Codec: cdc, + KeySpecs: specs, + } + }) +} func (k TestKeyCodec) Draw(t *rapid.T, id string) []protoreflect.Value { n := len(k.KeySpecs) @@ -154,3 +168,23 @@ func (k TestKeyCodec) Draw(t *rapid.T, id string) []protoreflect.Value { } return keyValues } + +var GenA = rapid.Custom(func(t *rapid.T) *testpb.A { + a := &testpb.A{} + ref := a.ProtoReflect() + for _, spec := range TestFieldSpecs { + field := GetTestField(spec.FieldName) + value := spec.Gen.Draw(t, string(spec.FieldName)) + ref.Set(field, protoreflect.ValueOf(value)) + } + return a +}) + +func ValuesOf(values ...interface{}) []protoreflect.Value { + n := len(values) + res := make([]protoreflect.Value, n) + for i := 0; i < n; i++ { + res[i] = protoreflect.ValueOf(values[i]) + } + return res +}