package decode import ( "errors" "fmt" "strings" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/anypb" ) const bit11NonCritical = 1 << 10 var ( anyDesc = (&anypb.Any{}).ProtoReflect().Descriptor() anyFullName = anyDesc.FullName() ) // RejectUnknownFieldsStrict operates by the same rules as RejectUnknownFields, but returns an error if any unknown // non-critical fields are encountered. func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, resolver protodesc.Resolver) error { _, err := RejectUnknownFields(bz, msg, false, resolver) return err } // RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an // option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the // hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be // used to treat a message with non-critical field different in different security contexts (such as transaction signing). // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. // An AnyResolver must be provided for traversing inside google.protobuf.Any's. func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) { // recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28 return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000) } func doRejectUnknownFields( bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver, recursionLimit int, ) (hasUnknownNonCriticals bool, err error) { if len(bz) == 0 { return hasUnknownNonCriticals, nil } if recursionLimit <= 0 { return false, errors.New("recursion limit reached") } fields := desc.Fields() for len(bz) > 0 { tagNum, wireType, m := protowire.ConsumeTag(bz) if m < 0 { return hasUnknownNonCriticals, errors.New("invalid length") } fieldDesc := fields.ByNumber(tagNum) if fieldDesc == nil { isCriticalField := tagNum&bit11NonCritical == 0 if !isCriticalField { hasUnknownNonCriticals = true } if isCriticalField || !allowUnknownNonCriticals { // The tag is critical, so report it. return hasUnknownNonCriticals, ErrUnknownField.Wrapf( "%s: {TagNum: %d, WireType:%q}", desc.FullName(), tagNum, WireTypeToString(wireType)) } } // Skip over the bytes that store fieldNumber and wireType bytes. bz = bz[m:] n := protowire.ConsumeFieldValue(tagNum, wireType, bz) if n < 0 { err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w", tagNum, WireTypeToString(wireType), protowire.ParseError(n)) return hasUnknownNonCriticals, err } fieldBytes := bz[:n] bz = bz[n:] // An unknown but non-critical field if fieldDesc == nil { continue } fieldMessage := fieldDesc.Message() // not message or group kind if fieldMessage == nil { continue } // if a message descriptor is a placeholder resolve it using the injected resolver. // this can happen when a descriptor has been registered in the // "google.golang.org/protobuf" registry but not in "github.com/cosmos/gogoproto". // fixes: https://github.com/cosmos/cosmos-sdk/issues/22574 if fieldMessage.IsPlaceholder() { gogoDesc, err := resolver.FindDescriptorByName(fieldMessage.FullName()) if err != nil { return hasUnknownNonCriticals, fmt.Errorf("could not resolve placeholder descriptor: %v: %w", fieldMessage, err) } fieldMessage = gogoDesc.(protoreflect.MessageDescriptor) } // consume length prefix of nested message _, o := protowire.ConsumeVarint(fieldBytes) if o < 0 { err = fmt.Errorf("could not consume length prefix fieldBytes for nested message: %v: %w", fieldMessage, protowire.ParseError(o)) return hasUnknownNonCriticals, err } else if o > len(fieldBytes) { err = fmt.Errorf("length prefix > len(fieldBytes) for nested message: %v", fieldMessage) return hasUnknownNonCriticals, err } fieldBytes = fieldBytes[o:] var err error if fieldMessage.FullName() == anyFullName { // Firstly typecheck types.Any to ensure nothing snuck in. hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1) hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild if err != nil { return hasUnknownNonCriticals, err } var a anypb.Any if err = proto.Unmarshal(fieldBytes, &a); err != nil { return hasUnknownNonCriticals, err } msgName := protoreflect.FullName(strings.TrimPrefix(a.TypeUrl, "/")) msgDesc, err := resolver.FindDescriptorByName(msgName) if err != nil { return hasUnknownNonCriticals, err } fieldMessage = msgDesc.(protoreflect.MessageDescriptor) fieldBytes = a.Value } hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1) hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild if err != nil { return hasUnknownNonCriticals, err } } return hasUnknownNonCriticals, nil } // errUnknownField represents an error indicating that we encountered // a field that isn't available in the target proto.Message. type errUnknownField struct { Desc protoreflect.MessageDescriptor TagNum protowire.Number WireType protowire.Type } // String implements fmt.Stringer. func (twt *errUnknownField) String() string { return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", twt.Desc.FullName(), twt.TagNum, WireTypeToString(twt.WireType)) } // Error implements the error interface. func (twt *errUnknownField) Error() string { return twt.String() } var _ error = (*errUnknownField)(nil) // WireTypeToString returns a string representation of the given protowire.Type. func WireTypeToString(wt protowire.Type) string { switch wt { case 0: return "varint" case 1: return "fixed64" case 2: return "bytes" case 3: return "start_group" case 4: return "end_group" case 5: return "fixed32" default: return fmt.Sprintf("unknown type: %d", wt) } }