codec: implement protobuf unknown fields checker (#6557)

This commit is contained in:
Emmanuel T Odeke 2020-07-29 08:31:23 -07:00 committed by GitHub
parent a4d1f306c0
commit b0c73ae994
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 13821 additions and 644 deletions

View File

@ -0,0 +1,115 @@
package unknownproto_test
import (
"sync"
"testing"
"github.com/gogo/protobuf/proto"
"github.com/cosmos/cosmos-sdk/codec/unknownproto"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
)
var n1BBlob []byte
func init() {
n1B := &testdata.Nested1B{
Id: 1,
Age: 99,
Nested: &testdata.Nested2B{
Id: 2,
Route: "Wintery route",
Fee: 99,
Nested: &testdata.Nested3B{
Id: 3,
Name: "3A this one that one there those oens",
Age: 4588,
B4: []*testdata.Nested4B{
{
Id: 4,
Age: 88,
Name: "Nested4B",
},
},
},
},
}
var err error
n1BBlob, err = proto.Marshal(n1B)
if err != nil {
panic(err)
}
}
func BenchmarkRejectUnknownFields_serial(b *testing.B) {
benchmarkRejectUnknownFields(b, false)
}
func BenchmarkRejectUnknownFields_parallel(b *testing.B) {
benchmarkRejectUnknownFields(b, true)
}
func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
b.ReportAllocs()
if !parallel {
ckr := new(unknownproto.Checker)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
b.SetBytes(int64(len(n1BBlob)))
}
} else {
var mu sync.Mutex
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ckr := new(unknownproto.Checker)
for pb.Next() {
// To simulate the conditions of multiple transactions being processed in parallel.
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
mu.Lock()
b.SetBytes(int64(len(n1BBlob)))
mu.Unlock()
}
})
}
}
func BenchmarkProtoUnmarshal_serial(b *testing.B) {
benchmarkProtoUnmarshal(b, false)
}
func BenchmarkProtoUnmarshal_parallel(b *testing.B) {
benchmarkProtoUnmarshal(b, true)
}
func benchmarkProtoUnmarshal(b *testing.B, parallel bool) {
b.ReportAllocs()
if !parallel {
for i := 0; i < b.N; i++ {
n1A := new(testdata.Nested1A)
if err := proto.Unmarshal(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
b.SetBytes(int64(len(n1BBlob)))
}
} else {
var mu sync.Mutex
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
n1A := new(testdata.Nested1A)
if err := proto.Unmarshal(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
mu.Lock()
b.SetBytes(int64(len(n1BBlob)))
mu.Unlock()
}
})
}
}

28
codec/unknownproto/doc.go Normal file
View File

@ -0,0 +1,28 @@
/*
unknownproto implements functionality to "type check" protobuf serialized byte sequences
against an expected proto.Message to report:
a) Unknown fields in the stream -- this is indicative of mismatched services, perhaps a malicious actor
b) Mismatched wire types for a field -- this is indicative of mismatched services
Its API signature is similar to proto.Unmarshal([]byte, proto.Message) as
ckr := new(unknownproto.Checker)
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
// Handle the error.
}
and ideally should be added before invoking proto.Unmarshal, if you'd like to enforce the features mentioned above.
By default, for security we report every single field that's unknown, whether a non-critical field or not. To customize
this behavior, please create a Checker and set the AllowUnknownNonCriticals to true, for example:
ckr := &unknownproto.Checker{
AllowUnknownNonCriticals: true,
}
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
// Handle the error.
}
*/
package unknownproto

View File

@ -0,0 +1,32 @@
package unknownproto
import (
"fmt"
"testing"
"google.golang.org/protobuf/encoding/protowire"
)
func TestWireTypeToString(t *testing.T) {
tests := []struct {
typ protowire.Type
want string
}{
{typ: 0, want: "varint"},
{typ: 1, want: "fixed64"},
{typ: 2, want: "bytes"},
{typ: 3, want: "start_group"},
{typ: 4, want: "end_group"},
{typ: 5, want: "fixed32"},
{typ: 95, want: "unknown type: 95"},
}
for _, tt := range tests {
tt := tt
t.Run(fmt.Sprintf("wireType=%d", tt.typ), func(t *testing.T) {
if g, w := wireTypeToString(tt.typ), tt.want; g != w {
t.Fatalf("Mismatch:\nGot: %q\nWant: %q\n", g, w)
}
})
}
}

View File

@ -0,0 +1,365 @@
package unknownproto
import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"io/ioutil"
"reflect"
"sync"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
"google.golang.org/protobuf/encoding/protowire"
"github.com/cosmos/cosmos-sdk/codec/types"
)
const bit11NonCritical = 1 << 10
type descriptorIface interface {
Descriptor() ([]byte, []int)
}
type Checker struct {
// AllowUnknownNonCriticals when set will skip over non-critical fields that are unknown.
AllowUnknownNonCriticals bool
}
func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
if len(b) == 0 {
return nil
}
desc, ok := msg.(descriptorIface)
if !ok {
return fmt.Errorf("%T does not have a Descriptor() method", msg)
}
fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
if err != nil {
return err
}
for len(b) > 0 {
tagNum, wireType, n := protowire.ConsumeField(b)
if n < 0 {
return errors.New("invalid length")
}
fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)]
switch {
case ok:
// Assert that the wireTypes match.
if !canEncodeType(wireType, fieldDescProto.GetType()) {
return &errMismatchedWireType{
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
GotWireType: wireType,
WantWireType: protowire.Type(fieldDescProto.WireType()),
}
}
default:
if !ckr.AllowUnknownNonCriticals || tagNum&bit11NonCritical == 0 {
// The tag is critical, so report it.
return &errUnknownField{
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
WireType: wireType,
}
}
}
// Skip over the 2 bytes that store fieldNumber and wireType bytes.
fieldBytes := b[2:n]
b = b[n:]
// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
if fieldDescProto == nil || fieldDescProto.IsScalar() {
continue
}
protoMessageName := fieldDescProto.GetTypeName()
if protoMessageName == "" {
switch typ := fieldDescProto.GetType(); typ {
case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES:
// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
// TYPE_BYTES and TYPE_STRING as per
// https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
default:
return fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
}
continue
}
// Let's recursively traverse and typecheck the field.
if protoMessageName == ".google.protobuf.Any" {
// Firstly typecheck types.Any to ensure nothing snuck in.
if err := ckr.RejectUnknownFields(fieldBytes, (*types.Any)(nil)); err != nil {
return err
}
// And finally we can extract the TypeURL containing the protoMessageName.
any := new(types.Any)
if err := proto.Unmarshal(fieldBytes, any); err != nil {
return err
}
protoMessageName = any.TypeUrl
fieldBytes = any.Value
}
msg, err := protoMessageForTypeName(protoMessageName[1:])
if err != nil {
return err
}
if err := ckr.RejectUnknownFields(fieldBytes, msg); err != nil {
return err
}
}
return nil
}
var protoMessageForTypeNameMu sync.RWMutex
var protoMessageForTypeNameCache = make(map[string]proto.Message)
// protoMessageForTypeName takes in a fully qualified name e.g. testdata.TestVersionFD1
// and returns a corresponding empty protobuf message that serves the prototype for typechecking.
func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
protoMessageForTypeNameMu.RLock()
msg, ok := protoMessageForTypeNameCache[protoMessageName]
protoMessageForTypeNameMu.RUnlock()
if ok {
return msg, nil
}
concreteGoType := proto.MessageType(protoMessageName)
if concreteGoType == nil {
return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName)
}
value := reflect.New(concreteGoType).Elem()
msg, ok = value.Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName)
}
// Now cache it.
protoMessageForTypeNameMu.Lock()
protoMessageForTypeNameCache[protoMessageName] = msg
protoMessageForTypeNameMu.Unlock()
return msg, nil
}
// checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type.
// it is implemented this way so as to have constant time lookups and avoid the overhead
// from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{
// "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
0: {
descriptor.FieldDescriptorProto_TYPE_INT32: true,
descriptor.FieldDescriptorProto_TYPE_INT64: true,
descriptor.FieldDescriptorProto_TYPE_UINT32: true,
descriptor.FieldDescriptorProto_TYPE_UINT64: true,
descriptor.FieldDescriptorProto_TYPE_SINT32: true,
descriptor.FieldDescriptorProto_TYPE_SINT64: true,
descriptor.FieldDescriptorProto_TYPE_BOOL: true,
descriptor.FieldDescriptorProto_TYPE_ENUM: true,
},
// "1 64-bit: fixed64, sfixed64, double"
1: {
descriptor.FieldDescriptorProto_TYPE_FIXED64: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptor.FieldDescriptorProto_TYPE_DOUBLE: true,
},
// "2 Length-delimited: string, bytes, embedded messages, packed repeated fields"
2: {
descriptor.FieldDescriptorProto_TYPE_STRING: true,
descriptor.FieldDescriptorProto_TYPE_BYTES: true,
descriptor.FieldDescriptorProto_TYPE_MESSAGE: true,
},
// "3 Start group: groups (deprecated)"
3: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
},
// "4 End group: groups (deprecated)"
4: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
},
// "5 32-bit: fixed32, sfixed32, float"
5: {
descriptor.FieldDescriptorProto_TYPE_FIXED32: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
descriptor.FieldDescriptorProto_TYPE_FLOAT: true,
},
}
// canEncodeType returns true if the wireType is suitable for encoding the descriptor type.
// See https://developers.google.com/protocol-buffers/docs/encoding#structure.
func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool {
if iwt := int(wireType); iwt < 0 || iwt >= len(checks) {
return false
}
return checks[wireType][descType]
}
// errMismatchedWireType describes a mismatch between
// expected and got wireTypes for a specific tag number.
type errMismatchedWireType struct {
Type string
GotWireType protowire.Type
WantWireType protowire.Type
TagNum protowire.Number
}
// String implements fmt.Stringer.
func (mwt *errMismatchedWireType) String() string {
return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}",
mwt.Type, mwt.TagNum, wireTypeToString(mwt.GotWireType), wireTypeToString(mwt.WantWireType))
}
// Error implements the error interface.
func (mwt *errMismatchedWireType) Error() string {
return mwt.String()
}
var _ error = (*errMismatchedWireType)(nil)
func wireTypeToString(wt protowire.Type) string {
switch wt {
case 0:
return "varint"
case 1:
return "fixed64"
case 2:
return "bytes"
case 3:
return "start_group"
case 4:
return "end_group"
case 5:
return "fixed32"
default:
return fmt.Sprintf("unknown type: %d", wt)
}
}
// errUnknownField represents an error indicating that we encountered
// a field that isn't available in the target proto.Message.
type errUnknownField struct {
Type string
TagNum protowire.Number
WireType protowire.Type
}
// String implements fmt.Stringer.
func (twt *errUnknownField) String() string {
return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}",
twt.Type, twt.TagNum, wireTypeToString(twt.WireType))
}
// Error implements the error interface.
func (twt *errUnknownField) Error() string {
return twt.String()
}
var _ error = (*errUnknownField)(nil)
var (
protoFileToDesc = make(map[string]*descriptor.FileDescriptorProto)
protoFileToDescMu sync.RWMutex
)
func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto {
mdesc := mdescs[indices[0]]
for _, index := range indices[1:] {
mdesc = mdesc.NestedType[index]
}
return mdesc
}
// Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// for every single message, thus the need for a hand-rolled custom version that's performant and cacheable.
func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) {
gzippedPb, indices := desc.Descriptor()
protoFileToDescMu.RLock()
cached, ok := protoFileToDesc[string(gzippedPb)]
protoFileToDescMu.RUnlock()
if ok {
return cached, unnestDesc(cached.MessageType, indices), nil
}
// Time to gunzip the content of the FileDescriptor and then proto unmarshal them.
gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb))
if err != nil {
return nil, nil, err
}
protoBlob, err := ioutil.ReadAll(gzr)
if err != nil {
return nil, nil, err
}
fdesc := new(descriptor.FileDescriptorProto)
if err := proto.Unmarshal(protoBlob, fdesc); err != nil {
return nil, nil, err
}
// Now cache the FileDescriptor.
protoFileToDescMu.Lock()
protoFileToDesc[string(gzippedPb)] = fdesc
protoFileToDescMu.Unlock()
// Unnest the type if necessary.
return fdesc, unnestDesc(fdesc.MessageType, indices), nil
}
type descriptorMatch struct {
cache map[int32]*descriptor.FieldDescriptorProto
desc *descriptor.DescriptorProto
}
var descprotoCacheMu sync.RWMutex
var descprotoCache = make(map[reflect.Type]*descriptorMatch)
// getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors.
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) {
key := reflect.ValueOf(msg).Type()
descprotoCacheMu.RLock()
got, ok := descprotoCache[key]
descprotoCacheMu.RUnlock()
if ok {
return got.cache, got.desc, nil
}
// Now compute and cache the index.
_, md, err := extractFileDescMessageDesc(desc)
if err != nil {
return nil, nil, err
}
tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto)
for _, field := range md.Field {
tagNumToTypeIndex[field.GetNumber()] = field
}
descprotoCacheMu.Lock()
descprotoCache[key] = &descriptorMatch{
cache: tagNumToTypeIndex,
desc: md,
}
descprotoCacheMu.Unlock()
return tagNumToTypeIndex, md, nil
}

View File

@ -0,0 +1,756 @@
package unknownproto
import (
"reflect"
"testing"
"github.com/gogo/protobuf/proto"
"github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
)
func TestRejectUnknownFieldsRepeated(t *testing.T) {
tests := []struct {
name string
in proto.Message
recv proto.Message
wantErr error
allowUnknownNonCriticals bool
}{
{
name: "Unknown field in midst of repeated values",
in: &testdata.TestVersion2{
C: []*testdata.TestVersion2{
{
C: []*testdata.TestVersion2{
{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
A: &testdata.TestVersion2{
B: &testdata.TestVersion2{
H: []*testdata.TestVersion2{
{
X: 0x01,
},
},
},
},
},
},
},
{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
A: &testdata.TestVersion2{
B: &testdata.TestVersion2{
H: []*testdata.TestVersion2{
{
X: 0x02,
},
},
},
},
},
},
},
{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
NewField: 411,
},
},
},
},
},
},
},
recv: new(testdata.TestVersion1),
wantErr: &errUnknownField{
Type: "*testdata.TestVersion1",
TagNum: 25,
WireType: 0,
},
},
{
name: "Unknown field in midst of repeated values, allowUnknownNonCriticals set",
allowUnknownNonCriticals: true,
in: &testdata.TestVersion2{
C: []*testdata.TestVersion2{
{
C: []*testdata.TestVersion2{
{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
A: &testdata.TestVersion2{
B: &testdata.TestVersion2{
H: []*testdata.TestVersion2{
{
X: 0x01,
},
},
},
},
},
},
},
{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
A: &testdata.TestVersion2{
B: &testdata.TestVersion2{
H: []*testdata.TestVersion2{
{
X: 0x02,
},
},
},
},
},
},
},
{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
NewField: 411,
},
},
},
},
},
},
},
recv: new(testdata.TestVersion1),
wantErr: &errUnknownField{
Type: "*testdata.TestVersion1",
TagNum: 25,
WireType: 0,
},
},
{
name: "Unknown field in midst of repeated values, non-critical field to be rejected",
in: &testdata.TestVersion3{
C: []*testdata.TestVersion3{
{
C: []*testdata.TestVersion3{
{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
A: &testdata.TestVersion3{
B: &testdata.TestVersion3{
X: 0x01,
},
},
},
},
},
{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
A: &testdata.TestVersion3{
B: &testdata.TestVersion3{
X: 0x02,
},
},
},
},
},
{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
NonCriticalField: "non-critical",
},
},
},
},
},
},
},
recv: new(testdata.TestVersion1),
wantErr: &errUnknownField{
Type: "*testdata.TestVersion1",
TagNum: 1031,
WireType: 2,
},
},
{
name: "Unknown field in midst of repeated values, non-critical field ignored",
allowUnknownNonCriticals: true,
in: &testdata.TestVersion3{
C: []*testdata.TestVersion3{
{
C: []*testdata.TestVersion3{
{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
A: &testdata.TestVersion3{
B: &testdata.TestVersion3{
X: 0x01,
},
},
},
},
},
{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
A: &testdata.TestVersion3{
B: &testdata.TestVersion3{
X: 0x02,
},
},
},
},
},
{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
NonCriticalField: "non-critical",
},
},
},
},
},
},
},
recv: new(testdata.TestVersion1),
wantErr: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
protoBlob, err := proto.Marshal(tt.in)
if err != nil {
t.Fatal(err)
}
ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%v\n\nWant:\n%v", gotErr, tt.wantErr)
}
})
}
}
func TestRejectUnknownFields_allowUnknownNonCriticals(t *testing.T) {
tests := []struct {
name string
in proto.Message
allowUnknownNonCriticals bool
wantErr error
}{
{
name: "Field that's in the reserved range, should fail by default",
in: &testdata.Customer2{
Id: 289,
Reserved: 99,
},
wantErr: &errUnknownField{
Type: "*testdata.Customer1",
TagNum: 1047,
WireType: 0,
},
},
{
name: "Field that's in the reserved range, toggle allowUnknownNonCriticals",
allowUnknownNonCriticals: true,
in: &testdata.Customer2{
Id: 289,
Reserved: 99,
},
wantErr: nil,
},
{
name: "Unkown fields that are critical, but with allowUnknownNonCriticals set",
allowUnknownNonCriticals: true,
in: &testdata.Customer2{
Id: 289,
City: testdata.Customer2_PaloAlto,
},
wantErr: &errUnknownField{
Type: "*testdata.Customer1",
TagNum: 6,
WireType: 0,
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
blob, err := proto.Marshal(tt.in)
if err != nil {
t.Fatalf("Failed to marshal input: %v", err)
}
ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
c1 := new(testdata.Customer1)
gotErr := ckr.RejectUnknownFields(blob, c1)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
})
}
}
func TestRejectUnknownFieldsNested(t *testing.T) {
tests := []struct {
name string
in proto.Message
recv proto.Message
wantErr error
}{
{
name: "TestVersion3 from TestVersionFD1",
in: &testdata.TestVersion2{
X: 5,
Sum: &testdata.TestVersion2_E{
E: 100,
},
H: []*testdata.TestVersion2{
{X: 999},
{X: -55},
{
X: 102,
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
X: 4,
},
},
},
},
Customer1: &testdata.Customer1{
Id: 45,
Name: "customer1",
SubscriptionFee: 99,
},
},
recv: new(testdata.TestVersionFD1),
wantErr: &errUnknownField{
Type: "*testdata.TestVersionFD1",
TagNum: 12,
WireType: 2,
},
},
{
name: "Alternating oneofs",
in: &testdata.TestVersion3{
Sum: &testdata.TestVersion3_E{
E: 99,
},
},
recv: new(testdata.TestVersion3LoneOneOfValue),
wantErr: nil,
},
{
name: "Alternating oneofs mismatched field",
in: &testdata.TestVersion3{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
X: 99,
},
},
},
recv: new(testdata.TestVersion3LoneOneOfValue),
wantErr: &errUnknownField{
Type: "*testdata.TestVersion3LoneOneOfValue",
TagNum: 7,
WireType: 2,
},
},
{
name: "Discrepancy in a deeply nested one of field",
in: &testdata.TestVersion3{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
Sum: &testdata.TestVersion3_F{
F: &testdata.TestVersion3{
X: 19,
Sum: &testdata.TestVersion3_E{
E: 99,
},
},
},
},
},
},
recv: new(testdata.TestVersion3LoneNesting),
wantErr: &errUnknownField{
Type: "*testdata.TestVersion3LoneNesting",
TagNum: 6,
WireType: 0,
},
},
{
name: "unknown field types.Any in G",
in: &testdata.TestVersion3{
G: &types.Any{
TypeUrl: "/testdata.TestVersion1",
Value: mustMarshal(&testdata.TestVersion2{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
NewField: 999,
},
},
}),
},
},
recv: new(testdata.TestVersion3),
wantErr: &errUnknownField{
Type: "*testdata.TestVersion1",
TagNum: 25,
},
},
{
name: "types.Any with extra fields",
in: &testdata.TestVersionFD1WithExtraAny{
G: &testdata.AnyWithExtra{
Any: &types.Any{
TypeUrl: "/testdata.TestVersion1",
Value: mustMarshal(&testdata.TestVersion2{
Sum: &testdata.TestVersion2_F{
F: &testdata.TestVersion2{
NewField: 999,
},
},
}),
},
B: 3,
C: 2,
},
},
recv: new(testdata.TestVersion3),
wantErr: &errUnknownField{
Type: "*types.Any",
TagNum: 3,
WireType: 0,
},
},
{
name: "mismatched types.Any in G",
in: &testdata.TestVersion1{
G: &types.Any{
TypeUrl: "/testdata.TestVersion4LoneNesting",
Value: mustMarshal(&testdata.TestVersion3LoneNesting_Inner1{
Inner: &testdata.TestVersion3LoneNesting_Inner1_InnerInner{
Id: "ID",
City: "Gotham",
},
}),
},
},
recv: new(testdata.TestVersion1),
wantErr: &errMismatchedWireType{
Type: "*testdata.TestVersion3",
TagNum: 1,
GotWireType: 2,
WantWireType: 0,
},
},
{
name: "From nested proto message, message index 0",
in: &testdata.TestVersion3LoneNesting{
Inner1: &testdata.TestVersion3LoneNesting_Inner1{
Id: 10,
Name: "foo",
Inner: &testdata.TestVersion3LoneNesting_Inner1_InnerInner{
Id: "ID",
City: "Palo Alto",
},
},
},
recv: new(testdata.TestVersion4LoneNesting),
wantErr: &errMismatchedWireType{
Type: "*testdata.TestVersion4LoneNesting_Inner1_InnerInner",
TagNum: 1,
GotWireType: 2,
WantWireType: 0,
},
},
{
name: "From nested proto message, message index 1",
in: &testdata.TestVersion3LoneNesting{
Inner2: &testdata.TestVersion3LoneNesting_Inner2{
Id: "ID",
Country: "Maldives",
Inner: &testdata.TestVersion3LoneNesting_Inner2_InnerInner{
Id: "ID",
City: "Unknown",
},
},
},
recv: new(testdata.TestVersion4LoneNesting),
wantErr: &errMismatchedWireType{
Type: "*testdata.TestVersion4LoneNesting_Inner2_InnerInner",
TagNum: 2,
GotWireType: 2,
WantWireType: 0,
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
protoBlob, err := proto.Marshal(tt.in)
if err != nil {
t.Fatal(err)
}
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
})
}
}
func TestRejectUnknownFieldsFlat(t *testing.T) {
tests := []struct {
name string
in proto.Message
wantErr error
}{
{
name: "Oneof with same field number, shouldn't complain",
in: &testdata.Customer3{
Id: 68,
Name: "ACME3",
Payment: &testdata.Customer3_CreditCardNo{
CreditCardNo: "123-XXXX-XXX881",
},
},
wantErr: nil,
},
{
name: "Oneof with different field number, should fail",
in: &testdata.Customer3{
Id: 68,
Name: "ACME3",
Payment: &testdata.Customer3_ChequeNo{
ChequeNo: "123XXXXXXX881",
},
},
wantErr: &errUnknownField{
Type: "*testdata.Customer1",
TagNum: 8, WireType: 2,
},
},
{
name: "Any in a field, the extra field will be serialized so should fail",
in: &testdata.Customer2{
Miscellaneous: &types.Any{},
},
wantErr: &errUnknownField{
Type: "*testdata.Customer1",
TagNum: 10,
WireType: 2,
},
},
{
name: "With a nested struct as a field",
in: &testdata.Customer3{
Id: 289,
Original: &testdata.Customer1{
Id: 991,
},
},
wantErr: &errUnknownField{
Type: "*testdata.Customer1",
TagNum: 9,
WireType: 2,
},
},
{
name: "An extra field that's non-existent in Customer1",
in: &testdata.Customer2{
Id: 289,
Name: "Customer1",
Industry: 5299,
Fewer: 199.9,
},
wantErr: &errMismatchedWireType{
Type: "*testdata.Customer1",
TagNum: 2, GotWireType: 0, WantWireType: 2,
},
},
{
name: "Using a field that's in the reserved range, should fail by default",
in: &testdata.Customer2{
Id: 289,
Reserved: 99,
},
wantErr: &errUnknownField{
Type: "*testdata.Customer1",
TagNum: 1047,
WireType: 0,
},
},
{
name: "Only fields matching",
in: &testdata.Customer2{
Id: 289,
Name: "Customer1",
},
wantErr: &errMismatchedWireType{
Type: "*testdata.Customer1",
TagNum: 3, GotWireType: 2, WantWireType: 5,
},
},
{
name: "Extra field that's non-existent in Customer1, along with Reserved set",
in: &testdata.Customer2{
Id: 289,
Name: "Customer1",
Industry: 5299,
Fewer: 199.9,
Reserved: 819,
},
wantErr: &errMismatchedWireType{
Type: "*testdata.Customer1",
TagNum: 2, GotWireType: 0, WantWireType: 2,
},
},
{
name: "Using enumerated field",
in: &testdata.Customer2{
Id: 289,
Name: "Customer1",
Industry: 5299,
City: testdata.Customer2_PaloAlto,
},
wantErr: &errMismatchedWireType{
Type: "*testdata.Customer1",
TagNum: 2,
GotWireType: 0, WantWireType: 2,
},
},
{
name: "multiple extraneous fields",
in: &testdata.Customer2{
Id: 289,
Name: "Customer1",
Industry: 5299,
City: testdata.Customer2_PaloAlto,
Fewer: 45,
},
wantErr: &errMismatchedWireType{
TagNum: 2, GotWireType: 0, WantWireType: 2,
Type: "*testdata.Customer1",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
blob, err := proto.Marshal(tt.in)
if err != nil {
t.Fatalf("Failed to marshal input: %v", err)
}
c1 := new(testdata.Customer1)
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(blob, c1)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
})
}
}
func TestMismatchedTypes_Nested(t *testing.T) {
tests := []struct {
name string
in proto.Message
recv proto.Message
wantErr error
}{
{
name: "mismatched types.Any in G",
in: &testdata.TestVersion1{
G: &types.Any{
TypeUrl: "/testdata.TestVersion4LoneNesting",
Value: mustMarshal(&testdata.TestVersion3LoneNesting_Inner1{
Inner: &testdata.TestVersion3LoneNesting_Inner1_InnerInner{
Id: "ID",
City: "Gotham",
},
}),
},
},
recv: new(testdata.TestVersion1),
wantErr: &errMismatchedWireType{
Type: "*testdata.TestVersion3",
TagNum: 1,
GotWireType: 2,
WantWireType: 0,
},
},
{
name: "From nested proto message, message index 0",
in: &testdata.TestVersion3LoneNesting{
Inner1: &testdata.TestVersion3LoneNesting_Inner1{
Id: 10,
Name: "foo",
Inner: &testdata.TestVersion3LoneNesting_Inner1_InnerInner{
Id: "ID",
City: "Palo Alto",
},
},
},
recv: new(testdata.TestVersion4LoneNesting),
wantErr: &errMismatchedWireType{
Type: "*testdata.TestVersion4LoneNesting_Inner1_InnerInner",
TagNum: 1,
GotWireType: 2,
WantWireType: 0,
},
},
{
name: "From nested proto message, message index 1",
in: &testdata.TestVersion3LoneNesting{
Inner2: &testdata.TestVersion3LoneNesting_Inner2{
Id: "ID",
Country: "Maldives",
Inner: &testdata.TestVersion3LoneNesting_Inner2_InnerInner{
Id: "ID",
City: "Unknown",
},
},
},
recv: new(testdata.TestVersion4LoneNesting),
wantErr: &errMismatchedWireType{
Type: "*testdata.TestVersion4LoneNesting_Inner2_InnerInner",
TagNum: 2,
GotWireType: 2,
WantWireType: 0,
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
protoBlob, err := proto.Marshal(tt.in)
if err != nil {
t.Fatal(err)
}
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
})
}
}
func mustMarshal(msg proto.Message) []byte {
blob, err := proto.Marshal(msg)
if err != nil {
panic(err)
}
return blob
}

1
go.mod
View File

@ -41,6 +41,7 @@ require (
github.com/tendermint/tendermint v0.33.6
github.com/tendermint/tm-db v0.5.1
google.golang.org/grpc v1.30.0
google.golang.org/protobuf v1.24.0
gopkg.in/yaml.v2 v2.3.0
)

7
go.sum
View File

@ -185,6 +185,7 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
@ -706,6 +707,8 @@ google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBr
google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/genproto v0.0.0-20200324203455-a04cca1dde73 h1:+yTMTeazSO5iBqU9NR53hgriivQQbYa5Uuaj8r3qKII=
google.golang.org/genproto v0.0.0-20200324203455-a04cca1dde73/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM=
@ -731,8 +734,12 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zimw=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.24.0 h1:UhZDfRO8JRQru4/+LlLE0BRKGF8L+PICnvYZmx/fEGA=
google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
syntax = "proto3";
package testdata;
import "google/protobuf/any.proto";
import "gogoproto/gogo.proto";
import "google/protobuf/any.proto";
option go_package = "github.com/cosmos/cosmos-sdk/testutil/testdata";
@ -70,3 +70,274 @@ message BadMultiSignature {
repeated bytes signatures = 1;
bytes malicious_field = 5;
}
message Customer1 {
int32 id = 1;
string name = 2;
float subscription_fee = 3;
string payment = 7;
}
message Customer2 {
int32 id = 1;
int32 industry = 2;
string name = 3;
float fewer = 4;
int64 reserved = 1047;
enum City {
Laos = 0;
LosAngeles = 1;
PaloAlto = 2;
Moscow = 3;
Nairobi = 4;
}
City city = 6;
google.protobuf.Any miscellaneous = 10;
}
message Nested4A {
int32 id = 1;
string name = 2;
}
message Nested3A {
int32 id = 1;
string name = 2;
repeated Nested4A a4 = 4;
map<int64, Nested4A> index = 5;
}
message Nested2A {
int32 id = 1;
string name = 2;
Nested3A nested = 3;
}
message Nested1A {
int32 id = 1;
Nested2A nested = 2;
}
message Nested4B {
int32 id = 1;
int32 age = 2;
string name = 3;
}
message Nested3B {
int32 id = 1;
int32 age = 2;
string name = 3;
repeated Nested4B b4 = 4;
}
message Nested2B {
int32 id = 1;
double fee = 2;
Nested3B nested = 3;
string route = 4;
}
message Nested1B {
int32 id = 1;
Nested2B nested = 2;
int32 age = 3;
}
message Customer3 {
int32 id = 1;
string name = 2;
float sf = 3;
float surcharge = 4;
string destination = 5;
oneof payment {
string credit_card_no = 7;
string cheque_no = 8;
}
Customer1 original = 9;
}
message TestVersion1 {
int64 x = 1;
TestVersion1 a = 2;
TestVersion1 b = 3; // [(gogoproto.nullable) = false] generates invalid recursive structs;
repeated TestVersion1 c = 4;
repeated TestVersion1 d = 5 [(gogoproto.nullable) = false];
oneof sum {
int32 e = 6;
TestVersion1 f = 7;
}
google.protobuf.Any g = 8;
repeated TestVersion1 h = 9; // [(gogoproto.castrepeated) = "TestVersion1"];
// google.protobuf.Timestamp i = 10;
// google.protobuf.Timestamp j = 11; // [(gogoproto.stdtime) = true];
Customer1 k = 12 [(gogoproto.embed) = true];
}
message TestVersion2 {
int64 x = 1;
TestVersion2 a = 2;
TestVersion2 b = 3; // [(gogoproto.nullable) = false];
repeated TestVersion2 c = 4;
repeated TestVersion2 d = 5; // [(gogoproto.nullable) = false];
oneof sum {
int32 e = 6;
TestVersion2 f = 7;
}
google.protobuf.Any g = 8;
repeated TestVersion1 h = 9; // [(gogoproto.castrepeated) = "TestVersion1"];
// google.protobuf.Timestamp i = 10;
// google.protobuf.Timestamp j = 11; // [(gogoproto.stdtime) = true];
Customer1 k = 12 [(gogoproto.embed) = true];
uint64 new_field = 25;
}
message TestVersion3 {
int64 x = 1;
TestVersion3 a = 2;
TestVersion3 b = 3; // [(gogoproto.nullable) = false];
repeated TestVersion3 c = 4;
repeated TestVersion3 d = 5; // [(gogoproto.nullable) = false];
oneof sum {
int32 e = 6;
TestVersion3 f = 7;
}
google.protobuf.Any g = 8;
repeated TestVersion1 h = 9; //[(gogoproto.castrepeated) = "TestVersion1"];
// google.protobuf.Timestamp i = 10;
// google.protobuf.Timestamp j = 11; // [(gogoproto.stdtime) = true];
Customer1 k = 12 [(gogoproto.embed) = true];
string non_critical_field = 1031;
}
message TestVersion3LoneOneOfValue {
int64 x = 1;
TestVersion3 a = 2;
TestVersion3 b = 3; // [(gogoproto.nullable) = false];
repeated TestVersion3 c = 4;
repeated TestVersion3 d = 5; // [(gogoproto.nullable) = false];
oneof sum {
int32 e = 6;
}
google.protobuf.Any g = 8;
repeated TestVersion1 h = 9; //[(gogoproto.castrepeated) = "TestVersion1"];
// google.protobuf.Timestamp i = 10;
// google.protobuf.Timestamp j = 11; // [(gogoproto.stdtime) = true];
Customer1 k = 12 [(gogoproto.embed) = true];
string non_critical_field = 1031;
}
message TestVersion3LoneNesting {
int64 x = 1;
TestVersion3 a = 2;
TestVersion3 b = 3; // [(gogoproto.nullable) = false];
repeated TestVersion3 c = 4;
repeated TestVersion3 d = 5; // [(gogoproto.nullable) = false];
oneof sum {
TestVersion3LoneNesting f = 7;
}
google.protobuf.Any g = 8;
repeated TestVersion1 h = 9; //[(gogoproto.castrepeated) = "TestVersion1"];
// google.protobuf.Timestamp i = 10;
// google.protobuf.Timestamp j = 11; // [(gogoproto.stdtime) = true];
Customer1 k = 12 [(gogoproto.embed) = true];
string non_critical_field = 1031;
message Inner1 {
int64 id = 1;
string name = 2;
message InnerInner {
string id = 1;
string city = 2;
}
InnerInner inner = 3;
}
Inner1 inner1 = 14;
message Inner2 {
string id = 1;
string country = 2;
message InnerInner {
string id = 1;
string city = 2;
}
InnerInner inner = 3;
}
Inner2 inner2 = 15;
}
message TestVersion4LoneNesting {
int64 x = 1;
TestVersion3 a = 2;
TestVersion3 b = 3; // [(gogoproto.nullable) = false];
repeated TestVersion3 c = 4;
repeated TestVersion3 d = 5; // [(gogoproto.nullable) = false];
oneof sum {
TestVersion3LoneNesting f = 7;
}
google.protobuf.Any g = 8;
repeated TestVersion1 h = 9; //[(gogoproto.castrepeated) = "TestVersion1"];
// google.protobuf.Timestamp i = 10;
// google.protobuf.Timestamp j = 11; // [(gogoproto.stdtime) = true];
Customer1 k = 12 [(gogoproto.embed) = true];
string non_critical_field = 1031;
message Inner1 {
int64 id = 1;
string name = 2;
message InnerInner {
int64 id = 1;
string city = 2;
}
InnerInner inner = 3;
}
Inner1 inner1 = 14;
message Inner2 {
string id = 1;
string country = 2;
message InnerInner {
string id = 1;
int64 value = 2;
}
InnerInner inner = 3;
}
Inner2 inner2 = 15;
}
message TestVersionFD1 {
int64 x = 1;
TestVersion1 a = 2;
oneof sum {
int32 e = 6;
TestVersion1 f = 7;
}
google.protobuf.Any g = 8;
repeated TestVersion1 h = 9; // [(gogoproto.castrepeated) = "TestVersion1"];
}
message TestVersionFD1WithExtraAny {
int64 x = 1;
TestVersion1 a = 2;
oneof sum {
int32 e = 6;
TestVersion1 f = 7;
}
AnyWithExtra g = 8;
repeated TestVersion1 h = 9; // [(gogoproto.castrepeated) = "TestVersion1"];
}
message AnyWithExtra {
google.protobuf.Any a = 1 [(gogoproto.embed) = true];
int64 b = 3;
int64 c = 4;
}

View File

@ -24,6 +24,8 @@ var _ = math.Inf
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
// GenesisState is currently only used to ensure that the InitGenesis gets run
// by the module manager
type GenesisState struct {
PortID string `protobuf:"bytes,1,opt,name=port_id,json=portId,proto3" json:"port_id,omitempty" yaml:"port_id"`
}

File diff suppressed because it is too large Load Diff