From dc20731bdd423d27611aaa64575f8f80e62b5298 Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Thu, 9 Feb 2023 12:25:04 -0500 Subject: [PATCH] fix(orm)!: timestamp encoding doesn't handle nil values properly (#12273) Co-authored-by: Julien Robert --- orm/CHANGELOG.md | 8 +- orm/encoding/encodeutil/util.go | 7 +- orm/encoding/ormfield/bool.go | 15 +- orm/encoding/ormfield/bytes.go | 26 +++- orm/encoding/ormfield/codec_test.go | 35 +++++ orm/encoding/ormfield/enum.go | 14 +- orm/encoding/ormfield/int32.go | 5 +- orm/encoding/ormfield/int64.go | 14 +- orm/encoding/ormfield/string.go | 30 +++- orm/encoding/ormfield/timestamp.go | 220 ++++++++++++++++++++++++---- orm/encoding/ormfield/uint32.go | 12 +- orm/encoding/ormfield/uint64.go | 21 ++- orm/encoding/ormkv/key_codec.go | 9 +- orm/internal/testutil/testutil.go | 8 +- orm/model/ormtable/table_test.go | 51 ++++++- proto/cosmos/orm/v1/orm.proto | 7 +- 16 files changed, 418 insertions(+), 64 deletions(-) diff --git a/orm/CHANGELOG.md b/orm/CHANGELOG.md index 578771668f..e7abd1f389 100644 --- a/orm/CHANGELOG.md +++ b/orm/CHANGELOG.md @@ -36,6 +36,10 @@ Ref: https://keepachangelog.com/en/1.0.0/ ## [Unreleased] -### API-Breaking Changes +### API Breaking Changes -- [14822](https://github.com/cosmos/cosmos-sdk/pull/14822) Migrate to cosmossdk.io/core genesis API \ No newline at end of file +- [14822](https://github.com/cosmos/cosmos-sdk/pull/14822) Migrate to cosmossdk.io/core genesis API + +### State-machine Breaking Changes + +- [12273](https://github.com/cosmos/cosmos-sdk/pull/12273) The timestamp key encoding was reworked to properly handle nil values. Existing users will need to manually migrate their data to the new encoding before upgrading. diff --git a/orm/encoding/encodeutil/util.go b/orm/encoding/encodeutil/util.go index c57bbf6db9..d8fe8a1c2e 100644 --- a/orm/encoding/encodeutil/util.go +++ b/orm/encoding/encodeutil/util.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "reflect" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -42,7 +43,11 @@ func ValuesOf(values ...interface{}) []protoreflect.Value { value := values[i] switch value.(type) { case protoreflect.ProtoMessage: - value = value.(protoreflect.ProtoMessage).ProtoReflect() + if !reflect.ValueOf(value).IsNil() { + value = value.(protoreflect.ProtoMessage).ProtoReflect() + } else { + value = nil + } } res[i] = protoreflect.ValueOf(value) } diff --git a/orm/encoding/ormfield/bool.go b/orm/encoding/ormfield/bool.go index a1189bccdc..8fac99becb 100644 --- a/orm/encoding/ormfield/bool.go +++ b/orm/encoding/ormfield/bool.go @@ -21,17 +21,22 @@ var ( func (b BoolCodec) Encode(value protoreflect.Value, w io.Writer) error { var err error - if value.Bool() { - _, err = w.Write(oneBz) - } else { + if !value.IsValid() || !value.Bool() { _, err = w.Write(zeroBz) + } else { + _, err = w.Write(oneBz) } return err } func (b BoolCodec) Compare(v1, v2 protoreflect.Value) int { - b1 := v1.Bool() - b2 := v2.Bool() + var b1, b2 bool + if v1.IsValid() { + b1 = v1.Bool() + } + if v2.IsValid() { + b2 = v2.Bool() + } if b1 == b2 { return 0 } else if b1 { diff --git a/orm/encoding/ormfield/bytes.go b/orm/encoding/ormfield/bytes.go index 2c984b113d..0eea476b8a 100644 --- a/orm/encoding/ormfield/bytes.go +++ b/orm/encoding/ormfield/bytes.go @@ -22,6 +22,9 @@ func (b BytesCodec) ComputeBufferSize(value protoreflect.Value) (int, error) { } func bytesSize(value protoreflect.Value) int { + if !value.IsValid() { + return 0 + } return len(value.Bytes()) } @@ -35,12 +38,15 @@ func (b BytesCodec) Decode(r Reader) (protoreflect.Value, error) { } func (b BytesCodec) Encode(value protoreflect.Value, w io.Writer) error { + if !value.IsValid() { + return nil + } _, err := w.Write(value.Bytes()) return err } func (b BytesCodec) Compare(v1, v2 protoreflect.Value) int { - return bytes.Compare(v1.Bytes(), v2.Bytes()) + return compareBytes(v1, v2) } // NonTerminalBytesCodec encodes bytes as raw bytes length prefixed by a single @@ -69,7 +75,7 @@ func (b NonTerminalBytesCodec) IsOrdered() bool { } func (b NonTerminalBytesCodec) Compare(v1, v2 protoreflect.Value) int { - return bytes.Compare(v1.Bytes(), v2.Bytes()) + return compareBytes(v1, v2) } func (b NonTerminalBytesCodec) Decode(r Reader) (protoreflect.Value, error) { @@ -88,7 +94,10 @@ func (b NonTerminalBytesCodec) Decode(r Reader) (protoreflect.Value, error) { } func (b NonTerminalBytesCodec) Encode(value protoreflect.Value, w io.Writer) error { - bz := value.Bytes() + var bz []byte + if value.IsValid() { + bz = value.Bytes() + } n := len(bz) var prefix [binary.MaxVarintLen64]byte prefixLen := binary.PutUvarint(prefix[:], uint64(n)) @@ -99,3 +108,14 @@ func (b NonTerminalBytesCodec) Encode(value protoreflect.Value, w io.Writer) err _, err = w.Write(bz) return err } + +func compareBytes(v1, v2 protoreflect.Value) int { + var bz1, bz2 []byte + if v1.IsValid() { + bz1 = v1.Bytes() + } + if v2.IsValid() { + bz2 = v2.Bytes() + } + return bytes.Compare(bz1, bz2) +} diff --git a/orm/encoding/ormfield/codec_test.go b/orm/encoding/ormfield/codec_test.go index 98e67bb8a2..abab4d2d5a 100644 --- a/orm/encoding/ormfield/codec_test.go +++ b/orm/encoding/ormfield/codec_test.go @@ -4,6 +4,9 @@ import ( "bytes" "fmt" "testing" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" "github.com/cosmos/cosmos-sdk/orm/encoding/ormfield" @@ -169,3 +172,35 @@ func TestCompactUInt64(t *testing.T) { assert.Equal(t, y, y2) }) } + +func TestTimestamp(t *testing.T) { + cdc := ormfield.TimestampCodec{} + + // nil value + buf := &bytes.Buffer{} + assert.NilError(t, cdc.Encode(protoreflect.Value{}, buf)) + assert.Equal(t, 1, len(buf.Bytes())) + val, err := cdc.Decode(buf) + assert.NilError(t, err) + assert.Assert(t, !val.IsValid()) + + // no nanos + ts := timestamppb.New(time.Date(2022, 1, 1, 12, 30, 15, 0, time.UTC)) + val = protoreflect.ValueOfMessage(ts.ProtoReflect()) + buf = &bytes.Buffer{} + assert.NilError(t, cdc.Encode(val, buf)) + assert.Equal(t, 6, len(buf.Bytes())) + val2, err := cdc.Decode(buf) + assert.NilError(t, err) + assert.Equal(t, 0, cdc.Compare(val, val2)) + + // nanos + ts = timestamppb.New(time.Date(2022, 1, 1, 12, 30, 15, 235809753, time.UTC)) + val = protoreflect.ValueOfMessage(ts.ProtoReflect()) + buf = &bytes.Buffer{} + assert.NilError(t, cdc.Encode(val, buf)) + assert.Equal(t, 9, len(buf.Bytes())) + val2, err = cdc.Decode(buf) + assert.NilError(t, err) + assert.Equal(t, 0, cdc.Compare(val, val2)) +} diff --git a/orm/encoding/ormfield/enum.go b/orm/encoding/ormfield/enum.go index 097a648cd1..106ac311a8 100644 --- a/orm/encoding/ormfield/enum.go +++ b/orm/encoding/ormfield/enum.go @@ -16,7 +16,10 @@ func (e EnumCodec) Decode(r Reader) (protoreflect.Value, error) { } func (e EnumCodec) Encode(value protoreflect.Value, w io.Writer) error { - x := value.Enum() + var x protoreflect.EnumNumber + if value.IsValid() { + x = value.Enum() + } buf := make([]byte, binary.MaxVarintLen32) n := binary.PutVarint(buf, int64(x)) _, err := w.Write(buf[:n]) @@ -24,8 +27,13 @@ func (e EnumCodec) Encode(value protoreflect.Value, w io.Writer) error { } func (e EnumCodec) Compare(v1, v2 protoreflect.Value) int { - x := v1.Enum() - y := v2.Enum() + var x, y protoreflect.EnumNumber + if v1.IsValid() { + x = v1.Enum() + } + if v2.IsValid() { + y = v2.Enum() + } if x == y { return 0 } else if x < y { diff --git a/orm/encoding/ormfield/int32.go b/orm/encoding/ormfield/int32.go index a3482ba862..8b2dd9331a 100644 --- a/orm/encoding/ormfield/int32.go +++ b/orm/encoding/ormfield/int32.go @@ -27,7 +27,10 @@ func (i Int32Codec) Decode(r Reader) (protoreflect.Value, error) { } func (i Int32Codec) Encode(value protoreflect.Value, w io.Writer) error { - x := value.Int() + var x int64 + if value.IsValid() { + x = value.Int() + } x += int32Offset return binary.Write(w, binary.BigEndian, uint32(x)) } diff --git a/orm/encoding/ormfield/int64.go b/orm/encoding/ormfield/int64.go index baeccce904..cbe13420d7 100644 --- a/orm/encoding/ormfield/int64.go +++ b/orm/encoding/ormfield/int64.go @@ -29,7 +29,10 @@ func (i Int64Codec) Decode(r Reader) (protoreflect.Value, error) { } func (i Int64Codec) Encode(value protoreflect.Value, w io.Writer) error { - x := value.Int() + var x int64 + if value.IsValid() { + x = value.Int() + } if x >= -1 { y := uint64(x) + int64Max + 1 return binary.Write(w, binary.BigEndian, y) @@ -57,8 +60,13 @@ func (i Int64Codec) ComputeBufferSize(protoreflect.Value) (int, error) { } func compareInt(v1, v2 protoreflect.Value) int { - x := v1.Int() - y := v2.Int() + var x, y int64 + if v1.IsValid() { + x = v1.Int() + } + if v2.IsValid() { + y = v2.Int() + } if x == y { return 0 } else if x < y { diff --git a/orm/encoding/ormfield/string.go b/orm/encoding/ormfield/string.go index 3a9955e83e..e052efab31 100644 --- a/orm/encoding/ormfield/string.go +++ b/orm/encoding/ormfield/string.go @@ -16,6 +16,10 @@ func (s StringCodec) FixedBufferSize() int { } func (s StringCodec) ComputeBufferSize(value protoreflect.Value) (int, error) { + if !value.IsValid() { + return 0, nil + } + return len(value.String()), nil } @@ -24,7 +28,7 @@ func (s StringCodec) IsOrdered() bool { } func (s StringCodec) Compare(v1, v2 protoreflect.Value) int { - return strings.Compare(v1.String(), v2.String()) + return compareStrings(v1, v2) } func (s StringCodec) Decode(r Reader) (protoreflect.Value, error) { @@ -33,7 +37,11 @@ func (s StringCodec) Decode(r Reader) (protoreflect.Value, error) { } func (s StringCodec) Encode(value protoreflect.Value, w io.Writer) error { - _, err := w.Write([]byte(value.String())) + var x string + if value.IsValid() { + x = value.String() + } + _, err := w.Write([]byte(x)) return err } @@ -54,7 +62,7 @@ func (s NonTerminalStringCodec) IsOrdered() bool { } func (s NonTerminalStringCodec) Compare(v1, v2 protoreflect.Value) int { - return strings.Compare(v1.String(), v2.String()) + return compareStrings(v1, v2) } func (s NonTerminalStringCodec) Decode(r Reader) (protoreflect.Value, error) { @@ -69,7 +77,10 @@ func (s NonTerminalStringCodec) Decode(r Reader) (protoreflect.Value, error) { } func (s NonTerminalStringCodec) Encode(value protoreflect.Value, w io.Writer) error { - str := value.String() + var str string + if value.IsValid() { + str = value.String() + } bz := []byte(str) for _, b := range bz { if b == 0 { @@ -85,3 +96,14 @@ func (s NonTerminalStringCodec) Encode(value protoreflect.Value, w io.Writer) er } var nullTerminator = []byte{0} + +func compareStrings(v1, v2 protoreflect.Value) int { + var x, y string + if v1.IsValid() { + x = v1.String() + } + if v2.IsValid() { + y = v2.String() + } + return strings.Compare(x, y) +} diff --git a/orm/encoding/ormfield/timestamp.go b/orm/encoding/ormfield/timestamp.go index fa3102921d..4342be5f8b 100644 --- a/orm/encoding/ormfield/timestamp.go +++ b/orm/encoding/ormfield/timestamp.go @@ -1,51 +1,151 @@ package ormfield import ( + "fmt" "io" "google.golang.org/protobuf/reflect/protoreflect" ) -// TimestampCodec DurationCodec encodes a google.protobuf.Timestamp value as 12 bytes using -// Int64Codec for seconds followed by Int32Codec for nanos. This allows for -// sorted iteration. +// TimestampCodec encodes google.protobuf.Timestamp values with the following +// encoding: +// - nil is encoded as []byte{0xFF} +// - seconds (which can range from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z) is encoded as 5 fixed bytes +// - nanos (which can range from 0 to 999,999,999) is encoded as: +// - []byte{0x0} for zero nanos +// - 4 fixed bytes with the bit mask 0xC0 applied to the first byte +// +// When iterating over timestamp indexes, nil values will always be ordered last. +// +// Values for seconds and nanos outside the ranges specified by google.protobuf.Timestamp will be rejected. type TimestampCodec struct{} -var ( - timestampSecondsField = timestampMsgType.Descriptor().Fields().ByName("seconds") - timestampNanosField = timestampMsgType.Descriptor().Fields().ByName("nanos") +const ( + timestampNilValue = 0xFF + timestampZeroNanosValue = 0x0 + timestampSecondsMin = -62135579038 + timestampSecondsMax = 253402318799 + timestampNanosMax = 999999999 ) -func getTimestampSecondsAndNanos(value protoreflect.Value) (protoreflect.Value, protoreflect.Value) { - msg := value.Message() - return msg.Get(timestampSecondsField), msg.Get(timestampNanosField) -} - -func (t TimestampCodec) Decode(r Reader) (protoreflect.Value, error) { - seconds, err := int64Codec.Decode(r) - if err != nil { - return protoreflect.Value{}, err - } - nanos, err := int32Codec.Decode(r) - if err != nil { - return protoreflect.Value{}, err - } - msg := timestampMsgType.New() - msg.Set(timestampSecondsField, seconds) - msg.Set(timestampNanosField, nanos) - return protoreflect.ValueOfMessage(msg), nil -} +var ( + timestampNilBz = []byte{timestampNilValue} + timestampZeroNanosBz = []byte{timestampZeroNanosValue} +) func (t TimestampCodec) Encode(value protoreflect.Value, w io.Writer) error { + // nil case + if !value.IsValid() { + _, err := w.Write(timestampNilBz) + return err + } + seconds, nanos := getTimestampSecondsAndNanos(value) - err := int64Codec.Encode(seconds, w) + secondsInt := seconds.Int() + if secondsInt < timestampSecondsMin || secondsInt > timestampSecondsMax { + return fmt.Errorf("seconds is out of range %d, must be between %d and %d", secondsInt, timestampSecondsMin, timestampSecondsMax) + } + secondsInt -= timestampSecondsMin + var secondsBz [5]byte + // write the seconds buffer from the end to the front + for i := 4; i >= 0; i-- { + secondsBz[i] = byte(secondsInt) + secondsInt >>= 8 + } + _, err := w.Write(secondsBz[:]) if err != nil { return err } - return int32Codec.Encode(nanos, w) + + nanosInt := nanos.Int() + if nanosInt == 0 { + _, err = w.Write(timestampZeroNanosBz) + return err + } + + if nanosInt < 0 || nanosInt > timestampNanosMax { + return fmt.Errorf("nanos is out of range %d, must be between %d and %d", secondsInt, 0, timestampNanosMax) + } + + var nanosBz [4]byte + for i := 3; i >= 0; i-- { + nanosBz[i] = byte(nanosInt) + nanosInt >>= 8 + } + nanosBz[0] = nanosBz[0] | 0xC0 + _, err = w.Write(nanosBz[:]) + return err +} + +func (t TimestampCodec) Decode(r Reader) (protoreflect.Value, error) { + b0, err := r.ReadByte() + if err != nil { + return protoreflect.Value{}, err + } + + if b0 == timestampNilValue { + return protoreflect.Value{}, nil + } + + var secondsBz [4]byte + n, err := r.Read(secondsBz[:]) + if err != nil { + return protoreflect.Value{}, err + } + if n < 4 { + return protoreflect.Value{}, io.EOF + } + + var seconds = int64(b0) + for i := 0; i < 4; i++ { + seconds <<= 8 + seconds |= int64(secondsBz[i]) + } + seconds += timestampSecondsMin + + msg := timestampMsgType.New() + msg.Set(timestampSecondsField, protoreflect.ValueOfInt64(seconds)) + + b0, err = r.ReadByte() + if err != nil { + return protoreflect.Value{}, err + } + + if b0 == timestampZeroNanosValue { + return protoreflect.ValueOfMessage(msg), nil + } + + var nanosBz [3]byte + n, err = r.Read(nanosBz[:]) + if err != nil { + return protoreflect.Value{}, err + } + if n < 3 { + return protoreflect.Value{}, io.EOF + } + + var nanos = int32(b0) & 0x3F // clear first two bits + for i := 0; i < 3; i++ { + nanos <<= 8 + nanos |= int32(nanosBz[i]) + } + + msg.Set(timestampNanosField, protoreflect.ValueOfInt32(nanos)) + return protoreflect.ValueOfMessage(msg), nil } func (t TimestampCodec) Compare(v1, v2 protoreflect.Value) int { + if !v1.IsValid() { + if !v2.IsValid() { + return 0 + } + return 1 + } + + if !v2.IsValid() { + return -1 + } + s1, n1 := getTimestampSecondsAndNanos(v1) s2, n2 := getTimestampSecondsAndNanos(v2) c := compareInt(s1, s2) @@ -61,9 +161,73 @@ func (t TimestampCodec) IsOrdered() bool { } func (t TimestampCodec) FixedBufferSize() int { - return 12 + return 9 } func (t TimestampCodec) ComputeBufferSize(protoreflect.Value) (int, error) { + return 9, nil +} + +// TimestampV0Codec encodes a google.protobuf.Timestamp value as 12 bytes using +// Int64Codec for seconds followed by Int32Codec for nanos. This type does not +// encode nil values correctly, but is retained in order to allow users of the +// previous encoding to successfully migrate from this encoding to the new encoding +// specified by TimestampCodec. +type TimestampV0Codec struct{} + +var ( + timestampSecondsField = timestampMsgType.Descriptor().Fields().ByName("seconds") + timestampNanosField = timestampMsgType.Descriptor().Fields().ByName("nanos") +) + +func getTimestampSecondsAndNanos(value protoreflect.Value) (protoreflect.Value, protoreflect.Value) { + msg := value.Message() + return msg.Get(timestampSecondsField), msg.Get(timestampNanosField) +} + +func (t TimestampV0Codec) Decode(r Reader) (protoreflect.Value, error) { + seconds, err := int64Codec.Decode(r) + if err != nil { + return protoreflect.Value{}, err + } + nanos, err := int32Codec.Decode(r) + if err != nil { + return protoreflect.Value{}, err + } + msg := timestampMsgType.New() + msg.Set(timestampSecondsField, seconds) + msg.Set(timestampNanosField, nanos) + return protoreflect.ValueOfMessage(msg), nil +} + +func (t TimestampV0Codec) Encode(value protoreflect.Value, w io.Writer) error { + seconds, nanos := getTimestampSecondsAndNanos(value) + err := int64Codec.Encode(seconds, w) + if err != nil { + return err + } + return int32Codec.Encode(nanos, w) +} + +func (t TimestampV0Codec) Compare(v1, v2 protoreflect.Value) int { + s1, n1 := getTimestampSecondsAndNanos(v1) + s2, n2 := getTimestampSecondsAndNanos(v2) + c := compareInt(s1, s2) + if c != 0 { + return c + } else { + return compareInt(n1, n2) + } +} + +func (t TimestampV0Codec) IsOrdered() bool { + return true +} + +func (t TimestampV0Codec) FixedBufferSize() int { + return 12 +} + +func (t TimestampV0Codec) ComputeBufferSize(protoreflect.Value) (int, error) { return t.FixedBufferSize(), nil } diff --git a/orm/encoding/ormfield/uint32.go b/orm/encoding/ormfield/uint32.go index 748808f014..0e770ad6b4 100644 --- a/orm/encoding/ormfield/uint32.go +++ b/orm/encoding/ormfield/uint32.go @@ -34,7 +34,11 @@ func (u FixedUint32Codec) Decode(r Reader) (protoreflect.Value, error) { } func (u FixedUint32Codec) Encode(value protoreflect.Value, w io.Writer) error { - return binary.Write(w, binary.BigEndian, uint32(value.Uint())) + var x uint64 + if value.IsValid() { + x = value.Uint() + } + return binary.Write(w, binary.BigEndian, uint32(x)) } // CompactUint32Codec encodes uint32 values using EncodeCompactUint32. @@ -46,7 +50,11 @@ func (c CompactUint32Codec) Decode(r Reader) (protoreflect.Value, error) { } func (c CompactUint32Codec) Encode(value protoreflect.Value, w io.Writer) error { - _, err := w.Write(EncodeCompactUint32(uint32(value.Uint()))) + var x uint64 + if value.IsValid() { + x = value.Uint() + } + _, err := w.Write(EncodeCompactUint32(uint32(x))) return err } diff --git a/orm/encoding/ormfield/uint64.go b/orm/encoding/ormfield/uint64.go index 57bf324c24..e4f6542399 100644 --- a/orm/encoding/ormfield/uint64.go +++ b/orm/encoding/ormfield/uint64.go @@ -34,12 +34,21 @@ func (u FixedUint64Codec) Decode(r Reader) (protoreflect.Value, error) { } func (u FixedUint64Codec) Encode(value protoreflect.Value, w io.Writer) error { - return binary.Write(w, binary.BigEndian, value.Uint()) + var x uint64 + if value.IsValid() { + x = value.Uint() + } + return binary.Write(w, binary.BigEndian, x) } func compareUint(v1, v2 protoreflect.Value) int { - x := v1.Uint() - y := v2.Uint() + var x, y uint64 + if v1.IsValid() { + x = v1.Uint() + } + if v2.IsValid() { + y = v2.Uint() + } if x == y { return 0 } else if x < y { @@ -58,7 +67,11 @@ func (c CompactUint64Codec) Decode(r Reader) (protoreflect.Value, error) { } func (c CompactUint64Codec) Encode(value protoreflect.Value, w io.Writer) error { - _, err := w.Write(EncodeCompactUint64(value.Uint())) + var x uint64 + if value.IsValid() { + x = value.Uint() + } + _, err := w.Write(EncodeCompactUint64(x)) return err } diff --git a/orm/encoding/ormkv/key_codec.go b/orm/encoding/ormkv/key_codec.go index 00ef51ea28..4d16a7bb9b 100644 --- a/orm/encoding/ormkv/key_codec.go +++ b/orm/encoding/ormkv/key_codec.go @@ -104,7 +104,9 @@ func (cdc *KeyCodec) EncodeKey(values []protoreflect.Value) ([]byte, error) { func (cdc *KeyCodec) GetKeyValues(message protoreflect.Message) []protoreflect.Value { res := make([]protoreflect.Value, len(cdc.fieldDescriptors)) for i, f := range cdc.fieldDescriptors { - res[i] = message.Get(f) + if message.Has(f) { + res[i] = message.Get(f) + } } return res } @@ -209,7 +211,10 @@ func (cdc KeyCodec) ComputeKeyBufferSize(values []protoreflect.Value) (int, erro // supported. func (cdc *KeyCodec) SetKeyValues(message protoreflect.Message, values []protoreflect.Value) { for i, f := range cdc.fieldDescriptors { - message.Set(f, values[i]) + value := values[i] + if value.IsValid() { + message.Set(f, value) + } } } diff --git a/orm/internal/testutil/testutil.go b/orm/internal/testutil/testutil.go index 6ea248e83b..ce27964812 100644 --- a/orm/internal/testutil/testutil.go +++ b/orm/internal/testutil/testutil.go @@ -80,6 +80,10 @@ var TestFieldSpecs = []TestFieldSpec{ { "ts", rapid.Custom(func(t *rapid.T) protoreflect.Message { + isNil := rapid.Float32().Draw(t, "isNil") + if isNil >= 0.95 { // draw a nil 5% of the time + return nil + } seconds := rapid.Int64Range(-9999999999, 9999999999).Draw(t, "seconds") nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos") return (×tamppb.Timestamp{ @@ -180,7 +184,9 @@ var GenA = rapid.Custom(func(t *rapid.T) *testpb.ExampleTable { for _, spec := range TestFieldSpecs { field := GetTestField(spec.FieldName) value := spec.Gen.Draw(t, string(spec.FieldName)) - ref.Set(field, protoreflect.ValueOf(value)) + if value != nil { + ref.Set(field, protoreflect.ValueOf(value)) + } } return a }) diff --git a/orm/model/ormtable/table_test.go b/orm/model/ormtable/table_test.go index 40a31a3228..2a094eeaac 100644 --- a/orm/model/ormtable/table_test.go +++ b/orm/model/ormtable/table_test.go @@ -100,11 +100,16 @@ func TestPaginationLimitCountTotal(t *testing.T) { assert.Equal(t, uint64(3), pr.Total) } -func TestImportedMessageIterator(t *testing.T) { +func TestTimestampIndex(t *testing.T) { table, err := ormtable.Build(ormtable.Options{ MessageType: (&testpb.ExampleTimestamp{}).ProtoReflect().Type(), }) - backend := testkv.NewSplitMemBackend() + backend := testkv.NewDebugBackend(testkv.NewSplitMemBackend(), &testkv.EntryCodecDebugger{ + EntryCodec: table, + Print: func(s string) { + t.Log(s) + }, + }) ctx := ormtable.WrapContextDefault(backend) store, err := testpb.NewExampleTimestampTable(table) assert.NilError(t, err) @@ -117,7 +122,7 @@ func TestImportedMessageIterator(t *testing.T) { assert.NilError(t, err) pastPb, middlePb, futurePb := timestamppb.New(past), timestamppb.New(middle), timestamppb.New(future) - timeOrder := [3]*timestamppb.Timestamp{pastPb, middlePb, futurePb} + timeOrder := []*timestamppb.Timestamp{pastPb, middlePb, futurePb} assert.NilError(t, store.Insert(ctx, &testpb.ExampleTimestamp{ Name: "foo", @@ -143,6 +148,46 @@ func TestImportedMessageIterator(t *testing.T) { assert.Equal(t, timeOrder[i].String(), v.Ts.String()) i++ } + + // insert a nil entry + id, err := store.InsertReturningId(ctx, &testpb.ExampleTimestamp{ + Name: "nil", + Ts: nil, + }) + assert.NilError(t, err) + + res, err := store.Get(ctx, id) + assert.Assert(t, res.Ts == nil) + + it, err = store.List(ctx, testpb.ExampleTimestampTsIndexKey{}) + assert.NilError(t, err) + + // make sure nils are ordered last + timeOrder = append(timeOrder, nil) + i = 0 + for it.Next() { + v, err := it.Value() + assert.NilError(t, err) + assert.Assert(t, v != nil) + x := timeOrder[i] + if x == nil { + assert.Assert(t, v.Ts == nil) + } else { + assert.Equal(t, x.String(), v.Ts.String()) + } + i++ + } + it.Close() + + // try iterating over just nil timestamps + it, err = store.List(ctx, testpb.ExampleTimestampTsIndexKey{}.WithTs(nil)) + assert.NilError(t, err) + assert.Assert(t, it.Next()) + res, err = it.Value() + assert.NilError(t, err) + assert.Assert(t, res.Ts == nil) + assert.Assert(t, !it.Next()) + it.Close() } // check that the ormkv.Entry's decode and encode to the same bytes diff --git a/proto/cosmos/orm/v1/orm.proto b/proto/cosmos/orm/v1/orm.proto index 389babd196..e8509392b8 100644 --- a/proto/cosmos/orm/v1/orm.proto +++ b/proto/cosmos/orm/v1/orm.proto @@ -52,8 +52,11 @@ message PrimaryKeyDescriptor { // with a 32-bit unsigned varint in non-terminal segments. // - int32, sint32, int64, sint64, sfixed32, sfixed64 are encoded as fixed width bytes with // an encoding that enables sorted iteration. - // - google.protobuf.Timestamp and google.protobuf.Duration are encoded - // as 12 bytes using an encoding that enables sorted iteration. + // - google.protobuf.Timestamp is encoded such that values with only seconds occupy 6 bytes, + // values including nanos occupy 9 bytes, and nil values occupy 1 byte. When iterating, nil + // values will always be ordered last. Seconds and nanos values must conform to the officially + // specified ranges of 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z and 0 to 999,999,999 respectively. + // - google.protobuf.Duration is encoded as 12 bytes using an encoding that enables sorted iteration. // - enum fields are encoded using varint encoding and do not support sorted // iteration. // - bool fields are encoded as a single byte 0 or 1.