/* * * Copyright 2016 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ /* Package gogoreflection implements server reflection service. The service implemented is defined in: https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto. To register server reflection on a gRPC server: import "google.golang.org/grpc/reflection" s := grpc.NewServer() pb.RegisterYourOwnServer(s, &server{}) // Register reflection service on gRPC server. reflection.Register(s) s.Serve(lis) */ package gogoreflection // import "google.golang.org/grpc/reflection" import ( "bytes" "compress/gzip" "errors" "fmt" "io" "reflect" "sort" "sync" gogoproto "github.com/cosmos/gogoproto/proto" dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "google.golang.org/grpc" "google.golang.org/grpc/codes" rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc/status" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "cosmossdk.io/core/log" ) type serverReflectionServer struct { rpb.UnimplementedServerReflectionServer s *grpc.Server messages []string initSymbols sync.Once serviceNames []string symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files log log.Logger } // Register registers the server reflection service on the given gRPC server. func Register(s *grpc.Server, messages []string, logger log.Logger) { rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ s: s, messages: messages, log: logger, }) } // protoMessage is used for type assertion on proto messages. // Generated proto message implements function Descriptor(), but Descriptor() // is not part of interface proto.Message. This interface is needed to // call Descriptor(). type protoMessage interface { Descriptor() ([]byte, []int) } func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { s.initSymbols.Do(func() { s.symbols = map[string]*dpb.FileDescriptorProto{} services, fds := s.getServices(s.messages) s.serviceNames = services processed := map[string]struct{}{} for _, fd := range fds { s.processFile(fd, processed) } sort.Strings(s.serviceNames) }) return s.serviceNames, s.symbols } func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { filename := fd.GetName() if _, ok := processed[filename]; ok { return } processed[filename] = struct{}{} prefix := fd.GetPackage() for _, msg := range fd.MessageType { s.processMessage(fd, prefix, msg) } for _, en := range fd.EnumType { s.processEnum(fd, prefix, en) } for _, ext := range fd.Extension { s.processField(fd, prefix, ext) } for _, svc := range fd.Service { svcName := fqn(prefix, svc.GetName()) s.symbols[svcName] = fd for _, meth := range svc.Method { name := fqn(svcName, meth.GetName()) s.symbols[name] = fd } } for _, dep := range fd.Dependency { fdenc := getFileDescriptor(dep) fdDep, err := decodeFileDesc(fdenc) if err != nil { continue } s.processFile(fdDep, processed) } } func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { msgName := fqn(prefix, msg.GetName()) s.symbols[msgName] = fd for _, nested := range msg.NestedType { s.processMessage(fd, msgName, nested) } for _, en := range msg.EnumType { s.processEnum(fd, msgName, en) } for _, ext := range msg.Extension { s.processField(fd, msgName, ext) } for _, fld := range msg.Field { s.processField(fd, msgName, fld) } for _, oneof := range msg.OneofDecl { oneofName := fqn(msgName, oneof.GetName()) s.symbols[oneofName] = fd } } func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { enName := fqn(prefix, en.GetName()) s.symbols[enName] = fd for _, val := range en.Value { valName := fqn(enName, val.GetName()) s.symbols[valName] = fd } } func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { fldName := fqn(prefix, fld.GetName()) s.symbols[fldName] = fd } func fqn(prefix, name string) string { if prefix == "" { return name } return prefix + "." + name } // fileDescForType gets the file descriptor for the given type. // The given type should be a proto message. func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { m, ok := reflect.Zero(reflect.PointerTo(st)).Interface().(protoMessage) if !ok { return nil, fmt.Errorf("failed to create message from type: %v", st) } enc, _ := m.Descriptor() return decodeFileDesc(enc) } // decodeFileDesc does decompression and unmarshalling on the given // file descriptor byte slice. func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { raw, err := decompress(enc) if err != nil { return nil, fmt.Errorf("failed to decompress enc: %w", err) } fd := new(dpb.FileDescriptorProto) if err := gogoproto.Unmarshal(raw, fd); err != nil { return nil, fmt.Errorf("bad descriptor: %w", err) } return fd, nil } // decompress does gzip decompression. func decompress(b []byte) ([]byte, error) { r, err := gzip.NewReader(bytes.NewReader(b)) if err != nil { return nil, fmt.Errorf("bad gzipped descriptor: %w", err) } out, err := io.ReadAll(r) if err != nil { return nil, fmt.Errorf("bad gzipped descriptor: %w", err) } return out, nil } func typeForName(name string) (reflect.Type, error) { pt := getMessageType(name) if pt == nil { return nil, fmt.Errorf("unknown type: %q", name) } st := pt.Elem() return st, nil } func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { m, ok := reflect.Zero(reflect.PointerTo(st)).Interface().(gogoproto.Message) if !ok { return nil, fmt.Errorf("failed to create message from type: %v", st) } extDesc := getExtension(ext, m) if extDesc == nil { return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) } return decodeFileDesc(getFileDescriptor(extDesc.Filename)) } func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { m, ok := reflect.Zero(reflect.PointerTo(st)).Interface().(gogoproto.Message) if !ok { return nil, fmt.Errorf("failed to create message from type: %v", st) } out := getExtensionsNumbers(m) return out, nil } // fileDescWithDependencies returns a slice of serialized fileDescriptors in // wire format ([]byte). The fileDescriptors will include fd and all the // transitive dependencies of fd with names not in sentFileDescriptors. func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { r := [][]byte{} queue := []*dpb.FileDescriptorProto{fd} for len(queue) > 0 { currentfd := queue[0] queue = queue[1:] if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { sentFileDescriptors[currentfd.GetName()] = true currentfdEncoded, err := gogoproto.Marshal(currentfd) if err != nil { return nil, err } r = append(r, currentfdEncoded) } for _, dep := range currentfd.Dependency { fdenc := getFileDescriptor(dep) fdDep, err := decodeFileDesc(fdenc) if err != nil { continue } queue = append(queue, fdDep) } } return r, nil } // fileDescEncodingByFilename finds the file descriptor for given filename, // finds all of its previously unsent transitive dependencies, does marshaling // on them, and returns the marshaled result. func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { enc := getFileDescriptor(name) if enc == nil { return nil, fmt.Errorf("unknown file: %v", name) } fd, err := decodeFileDesc(enc) if err != nil { return nil, err } return fileDescWithDependencies(fd, sentFileDescriptors) } // fileDescEncodingContainingSymbol finds the file descriptor containing the // given symbol, finds all of its previously unsent transitive dependencies, // does marshaling on them, and returns the marshaled result. The given symbol // can be a type, a service or a method. func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { _, symbols := s.getSymbols() fd := symbols[name] if fd == nil { // Check if it's a type name that was not present in the // transitive dependencies of the registered services. if st, err := typeForName(name); err == nil { fd, err = s.fileDescForType(st) if err != nil { return nil, err } } } if fd == nil { return nil, fmt.Errorf("unknown symbol: %v", name) } return fileDescWithDependencies(fd, sentFileDescriptors) } // fileDescEncodingContainingExtension finds the file descriptor containing // given extension, finds all of its previously unsent transitive dependencies, // does marshaling on them, and returns the marshaled result. func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { st, err := typeForName(typeName) if err != nil { return nil, err } fd, err := fileDescContainingExtension(st, extNum) if err != nil { return nil, err } return fileDescWithDependencies(fd, sentFileDescriptors) } // allExtensionNumbersForTypeName returns all extension numbers for the given type. func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { st, err := typeForName(name) if err != nil { return nil, err } extNums, err := s.allExtensionNumbersForType(st) if err != nil { return nil, err } return extNums, nil } // ServerReflectionInfo is the reflection service handler. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { sentFileDescriptors := make(map[string]bool) for { in, err := stream.Recv() if errors.Is(err, io.EOF) { return nil } if err != nil { return err } out := &rpb.ServerReflectionResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha ValidHost: in.Host, //nolint:staticcheck // SA1019: we want to keep using v1alpha OriginalRequest: in, } switch req := in.MessageRequest.(type) { case *rpb.ServerReflectionRequest_FileByFilename: b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) //nolint:staticcheck // SA1019: we want to keep using v1alpha if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha ErrorCode: int32(codes.NotFound), ErrorMessage: err.Error(), }, } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, //nolint:staticcheck // SA1019: we want to keep using v1alpha } } case *rpb.ServerReflectionRequest_FileContainingSymbol: b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors) //nolint:staticcheck // SA1019: we want to keep using v1alpha if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha ErrorCode: int32(codes.NotFound), ErrorMessage: err.Error(), }, } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, //nolint:staticcheck // SA1019: we want to keep using v1alpha } } case *rpb.ServerReflectionRequest_FileContainingExtension: typeName := req.FileContainingExtension.ContainingType //nolint:staticcheck // SA1019: we want to keep using v1alpha extNum := req.FileContainingExtension.ExtensionNumber //nolint:staticcheck // SA1019: we want to keep using v1alpha b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha ErrorCode: int32(codes.NotFound), ErrorMessage: err.Error(), }, } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, //nolint:staticcheck // SA1019: we want to keep using v1alpha } } case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType) //nolint:staticcheck // SA1019: we want to keep using v1alpha if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha ErrorCode: int32(codes.NotFound), ErrorMessage: err.Error(), }, } } else { out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{ AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha BaseTypeName: req.AllExtensionNumbersOfType, //nolint:staticcheck // SA1019: we want to keep using v1alpha ExtensionNumber: extNums, }, } } case *rpb.ServerReflectionRequest_ListServices: svcNames, _ := s.getSymbols() serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) //nolint:staticcheck // SA1019: we want to keep using v1alpha for i, n := range svcNames { serviceResponses[i] = &rpb.ServiceResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha Name: n, } } out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ ListServicesResponse: &rpb.ListServiceResponse{ //nolint:staticcheck // SA1019: we want to keep using v1alpha Service: serviceResponses, }, } default: return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest) } if err := stream.Send(out); err != nil { return err } } } // getServices gets the unique list of services given a list of methods. func (s *serverReflectionServer) getServices(messages []string) (svcs []string, fds []*dpb.FileDescriptorProto) { registry, err := gogoproto.MergedRegistry() if err != nil { s.log.Error("unable to load merged registry", "err", err) return nil, nil } seenSvc := map[protoreflect.FullName]struct{}{} for _, messageName := range messages { md, err := registry.FindDescriptorByName(protoreflect.FullName(messageName)) if err != nil { s.log.Error("unable to load message descriptor", "message", messageName, "err", err) continue } svc, ok := findServiceForMessage(registry, md.(protoreflect.MessageDescriptor)) if !ok { // if a service is not found for the message, simply skip // this is likely the message isn't part of a service and using appmodulev2.Handler instead. continue } if _, seen := seenSvc[svc.FullName()]; !seen { svcs = append(svcs, string(svc.FullName())) file := svc.ParentFile() fds = append(fds, protodesc.ToFileDescriptorProto(file)) } seenSvc[svc.FullName()] = struct{}{} } return svcs, fds } func findServiceForMessage(registry *protoregistry.Files, messageDesc protoreflect.MessageDescriptor) (protoreflect.ServiceDescriptor, bool) { var ( service protoreflect.ServiceDescriptor found bool ) registry.RangeFiles(func(fileDescriptor protoreflect.FileDescriptor) bool { for i := 0; i < fileDescriptor.Services().Len(); i++ { serviceDesc := fileDescriptor.Services().Get(i) for j := 0; j < serviceDesc.Methods().Len(); j++ { methodDesc := serviceDesc.Methods().Get(j) if methodDesc.Input() == messageDesc || methodDesc.Output() == messageDesc { service = serviceDesc found = true return false } } } return true }) return service, found }