From 7843204d63b73be82ef34f35eca8454564b12f3d Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 11:08:02 +0000 Subject: [PATCH] feat(baseapp): Add Hybrid Protobuf handlers to MsgServiceRouter (backport #18071) (#18338) Co-authored-by: testinginprod <98415576+testinginprod@users.noreply.github.com> Co-authored-by: unknown unknown --- CHANGELOG.md | 1 + baseapp/grpcrouter.go | 32 ++- baseapp/grpcrouter_test.go | 2 +- baseapp/internal/protocompat/protocompat.go | 14 ++ baseapp/msg_service_router.go | 204 ++++++++++++-------- baseapp/msg_service_router_test.go | 35 ++++ server/mock/app.go | 50 ++++- 7 files changed, 238 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d1068cfab..337ccee732 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (server) [#18162](https://github.com/cosmos/cosmos-sdk/pull/18162) Start gRPC & API server in standalone mode. * (baseapp) [#16581](https://github.com/cosmos/cosmos-sdk/pull/16581) Implement Optimistic Execution as an experimental feature (not enabled by default). +* (baseapp) [#18071](https://github.com/cosmos/cosmos-sdk/pull/18071) Add hybrid handlers to `MsgServiceRouter`. ### Improvements diff --git a/baseapp/grpcrouter.go b/baseapp/grpcrouter.go index 2c3980ccbb..9955ecb460 100644 --- a/baseapp/grpcrouter.go +++ b/baseapp/grpcrouter.go @@ -6,10 +6,8 @@ import ( abci "github.com/cometbft/cometbft/abci/types" gogogrpc "github.com/cosmos/gogoproto/grpc" - "github.com/cosmos/gogoproto/proto" "google.golang.org/grpc" "google.golang.org/grpc/encoding" - "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/runtime/protoiface" "github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat" @@ -23,9 +21,9 @@ import ( type GRPCQueryRouter struct { // routes maps query handlers used in ABCIQuery. routes map[string]GRPCQueryHandler - // handlerByMessageName maps the request name to the handler. It is a hybrid handler which seamlessly + // hybridHandlers maps the request name to the handler. It is a hybrid handler which seamlessly // handles both gogo and protov2 messages. - handlerByMessageName map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error + hybridHandlers map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error // binaryCodec is used to encode/decode binary protobuf messages. binaryCodec codec.BinaryCodec // cdc is the gRPC codec used by the router to correctly unmarshal messages. @@ -45,8 +43,8 @@ var _ gogogrpc.Server = &GRPCQueryRouter{} // NewGRPCQueryRouter creates a new GRPCQueryRouter func NewGRPCQueryRouter() *GRPCQueryRouter { return &GRPCQueryRouter{ - routes: map[string]GRPCQueryHandler{}, - handlerByMessageName: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{}, + routes: map[string]GRPCQueryHandler{}, + hybridHandlers: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{}, } } @@ -76,7 +74,7 @@ func (qrt *GRPCQueryRouter) RegisterService(sd *grpc.ServiceDesc, handler interf if err != nil { panic(err) } - err = qrt.registerHandlerByMessageName(sd, method, handler) + err = qrt.registerHybridHandler(sd, method, handler) if err != nil { panic(err) } @@ -131,36 +129,30 @@ func (qrt *GRPCQueryRouter) registerABCIQueryHandler(sd *grpc.ServiceDesc, metho return nil } -func (qrt *GRPCQueryRouter) HandlersByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error { - return qrt.handlerByMessageName[name] +func (qrt *GRPCQueryRouter) HybridHandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error { + return qrt.hybridHandlers[name] } -func (qrt *GRPCQueryRouter) registerHandlerByMessageName(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { +func (qrt *GRPCQueryRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { // extract message name from method descriptor - methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName)) - desc, err := proto.HybridResolver.FindDescriptorByName(methodFullName) + inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method) if err != nil { - return fmt.Errorf("cannot find method descriptor %s", methodFullName) + return err } - methodDesc, ok := desc.(protoreflect.MethodDescriptor) - if !ok { - return fmt.Errorf("invalid method descriptor %s", methodFullName) - } - inputName := methodDesc.Input().FullName() methodHandler, err := protocompat.MakeHybridHandler(qrt.binaryCodec, sd, method, handler) if err != nil { return err } - qrt.handlerByMessageName[string(inputName)] = append(qrt.handlerByMessageName[string(inputName)], methodHandler) + qrt.hybridHandlers[string(inputName)] = append(qrt.hybridHandlers[string(inputName)], methodHandler) return nil } // SetInterfaceRegistry sets the interface registry for the router. This will // also register the interface reflection gRPC service. func (qrt *GRPCQueryRouter) SetInterfaceRegistry(interfaceRegistry codectypes.InterfaceRegistry) { - qrt.binaryCodec = codec.NewProtoCodec(interfaceRegistry) // instantiate the codec qrt.cdc = codec.NewProtoCodec(interfaceRegistry).GRPCCodec() + qrt.binaryCodec = codec.NewProtoCodec(interfaceRegistry) // Once we have an interface registry, we can register the interface // registry reflection gRPC service. reflection.RegisterReflectionServiceServer(qrt, reflection.NewReflectionServiceServer(interfaceRegistry)) diff --git a/baseapp/grpcrouter_test.go b/baseapp/grpcrouter_test.go index de68fb0f46..d747fbb747 100644 --- a/baseapp/grpcrouter_test.go +++ b/baseapp/grpcrouter_test.go @@ -56,7 +56,7 @@ func TestGRPCQueryRouter(t *testing.T) { func TestGRPCRouterHybridHandlers(t *testing.T) { assertRouterBehaviour := func(helper *baseapp.QueryServiceTestHelper) { // test getting the handler by name - handlers := helper.GRPCQueryRouter.HandlersByRequestName("testpb.EchoRequest") + handlers := helper.GRPCQueryRouter.HybridHandlerByRequestName("testpb.EchoRequest") require.NotNil(t, handlers) require.Len(t, handlers, 1) handler := handlers[0] diff --git a/baseapp/internal/protocompat/protocompat.go b/baseapp/internal/protocompat/protocompat.go index e536fae961..bad4787d4e 100644 --- a/baseapp/internal/protocompat/protocompat.go +++ b/baseapp/internal/protocompat/protocompat.go @@ -205,3 +205,17 @@ func isProtov2(md grpc.MethodDesc) (isV2Type bool, err error) { _, _ = md.Handler(nil, nil, pullRequestType, doNotExecute) return } + +// RequestFullNameFromMethodDesc returns the fully-qualified name of the request message of the provided service's method. +func RequestFullNameFromMethodDesc(sd *grpc.ServiceDesc, method grpc.MethodDesc) (protoreflect.FullName, error) { + methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName)) + desc, err := gogoproto.HybridResolver.FindDescriptorByName(methodFullName) + if err != nil { + return "", fmt.Errorf("cannot find method descriptor %s", methodFullName) + } + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + if !ok { + return "", fmt.Errorf("invalid method descriptor %s", methodFullName) + } + return methodDesc.Input().FullName(), nil +} diff --git a/baseapp/msg_service_router.go b/baseapp/msg_service_router.go index b3876075af..126e0f65e5 100644 --- a/baseapp/msg_service_router.go +++ b/baseapp/msg_service_router.go @@ -7,9 +7,12 @@ import ( gogogrpc "github.com/cosmos/gogoproto/grpc" "github.com/cosmos/gogoproto/proto" "google.golang.org/grpc" + "google.golang.org/protobuf/runtime/protoiface" errorsmod "cosmossdk.io/errors" + "github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat" + "github.com/cosmos/cosmos-sdk/codec" codectypes "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -26,6 +29,7 @@ type MessageRouter interface { type MsgServiceRouter struct { interfaceRegistry codectypes.InterfaceRegistry routes map[string]MsgServiceHandler + hybridHandlers map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error circuitBreaker CircuitBreaker } @@ -34,7 +38,8 @@ var _ gogogrpc.Server = &MsgServiceRouter{} // NewMsgServiceRouter creates a new MsgServiceRouter. func NewMsgServiceRouter() *MsgServiceRouter { return &MsgServiceRouter{ - routes: map[string]MsgServiceHandler{}, + routes: map[string]MsgServiceHandler{}, + hybridHandlers: map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error{}, } } @@ -66,100 +71,143 @@ func (msr *MsgServiceRouter) HandlerByTypeURL(typeURL string) MsgServiceHandler func (msr *MsgServiceRouter) RegisterService(sd *grpc.ServiceDesc, handler interface{}) { // Adds a top-level query handler based on the gRPC service name. for _, method := range sd.Methods { - fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName) - methodHandler := method.Handler + err := msr.registerMsgServiceHandler(sd, method, handler) + if err != nil { + panic(err) + } + err = msr.registerHybridHandler(sd, method, handler) + if err != nil { + panic(err) + } + } +} - var requestTypeName string +func (msr *MsgServiceRouter) HybridHandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error { + return msr.hybridHandlers[msgName] +} - // NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry. - // This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself. - // We use a no-op interceptor to avoid actually calling into the handler itself. - _, _ = methodHandler(nil, context.Background(), func(i interface{}) error { - msg, ok := i.(sdk.Msg) - if !ok { - // We panic here because there is no other alternative and the app cannot be initialized correctly - // this should only happen if there is a problem with code generation in which case the app won't - // work correctly anyway. - panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i)) - } +func (msr *MsgServiceRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { + inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method) + if err != nil { + return err + } + cdc := codec.NewProtoCodec(msr.interfaceRegistry) + hybridHandler, err := protocompat.MakeHybridHandler(cdc, sd, method, handler) + if err != nil { + return err + } + // if circuit breaker is not nil, then we decorate the hybrid handler with the circuit breaker + if msr.circuitBreaker == nil { + msr.hybridHandlers[string(inputName)] = hybridHandler + return nil + } + // decorate the hybrid handler with the circuit breaker + circuitBreakerHybridHandler := func(ctx context.Context, req, resp protoiface.MessageV1) error { + messageName := codectypes.MsgTypeURL(req) + allowed, err := msr.circuitBreaker.IsAllowed(ctx, messageName) + if err != nil { + return err + } + if !allowed { + return fmt.Errorf("circuit breaker disallows execution of message %s", messageName) + } + return hybridHandler(ctx, req, resp) + } + msr.hybridHandlers[string(inputName)] = circuitBreakerHybridHandler + return nil +} - requestTypeName = sdk.MsgTypeURL(msg) - return nil - }, noopInterceptor) +func (msr *MsgServiceRouter) registerMsgServiceHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { + fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName) + methodHandler := method.Handler - // Check that the service Msg fully-qualified method name has already - // been registered (via RegisterInterfaces). If the user registers a - // service without registering according service Msg type, there might be - // some unexpected behavior down the road. Since we can't return an error - // (`Server.RegisterService` interface restriction) we panic (at startup). - reqType, err := msr.interfaceRegistry.Resolve(requestTypeName) - if err != nil || reqType == nil { - panic( - fmt.Errorf( - "type_url %s has not been registered yet. "+ - "Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+ - "method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+ - "`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen", - requestTypeName, - ), - ) + var requestTypeName string + + // NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry. + // This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself. + // We use a no-op interceptor to avoid actually calling into the handler itself. + _, _ = methodHandler(nil, context.Background(), func(i interface{}) error { + msg, ok := i.(sdk.Msg) + if !ok { + // We panic here because there is no other alternative and the app cannot be initialized correctly + // this should only happen if there is a problem with code generation in which case the app won't + // work correctly anyway. + panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i)) } - // Check that each service is only registered once. If a service is - // registered more than once, then we should error. Since we can't - // return an error (`Server.RegisterService` interface restriction) we - // panic (at startup). - _, found := msr.routes[requestTypeName] - if found { - panic( - fmt.Errorf( - "msg service %s has already been registered. Please make sure to only register each service once. "+ - "This usually means that there are conflicting modules registering the same msg service", - fqMethod, - ), - ) + requestTypeName = sdk.MsgTypeURL(msg) + return nil + }, noopInterceptor) + + // Check that the service Msg fully-qualified method name has already + // been registered (via RegisterInterfaces). If the user registers a + // service without registering according service Msg type, there might be + // some unexpected behavior down the road. Since we can't return an error + // (`Server.RegisterService` interface restriction) we panic (at startup). + reqType, err := msr.interfaceRegistry.Resolve(requestTypeName) + if err != nil || reqType == nil { + return fmt.Errorf( + "type_url %s has not been registered yet. "+ + "Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+ + "method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+ + "`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen", + requestTypeName, + ) + } + + // Check that each service is only registered once. If a service is + // registered more than once, then we should error. Since we can't + // return an error (`Server.RegisterService` interface restriction) we + // panic (at startup). + _, found := msr.routes[requestTypeName] + if found { + return fmt.Errorf( + "msg service %s has already been registered. Please make sure to only register each service once. "+ + "This usually means that there are conflicting modules registering the same msg service", + fqMethod, + ) + } + + msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { + ctx = ctx.WithEventManager(sdk.NewEventManager()) + interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx) + return handler(goCtx, msg) } - msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { - ctx = ctx.WithEventManager(sdk.NewEventManager()) - interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx) - return handler(goCtx, msg) + if m, ok := msg.(sdk.HasValidateBasic); ok { + if err := m.ValidateBasic(); err != nil { + return nil, err } + } - if m, ok := msg.(sdk.HasValidateBasic); ok { - if err := m.ValidateBasic(); err != nil { - return nil, err - } - } - - if msr.circuitBreaker != nil { - msgURL := sdk.MsgTypeURL(msg) - isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL) - if err != nil { - return nil, err - } - - if !isAllowed { - return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL) - } - } - - // Call the method handler from the service description with the handler object. - // We don't do any decoding here because the decoding was already done. - res, err := methodHandler(handler, ctx, noopDecoder, interceptor) + if msr.circuitBreaker != nil { + msgURL := sdk.MsgTypeURL(msg) + isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL) if err != nil { return nil, err } - resMsg, ok := res.(proto.Message) - if !ok { - return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg) + if !isAllowed { + return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL) } - - return sdk.WrapServiceResult(ctx, resMsg, err) } + + // Call the method handler from the service description with the handler object. + // We don't do any decoding here because the decoding was already done. + res, err := methodHandler(handler, ctx, noopDecoder, interceptor) + if err != nil { + return nil, err + } + + resMsg, ok := res.(proto.Message) + if !ok { + return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg) + } + + return sdk.WrapServiceResult(ctx, resMsg, err) } + return nil } // SetInterfaceRegistry sets the interface registry for the router. diff --git a/baseapp/msg_service_router_test.go b/baseapp/msg_service_router_test.go index 4d8cff45a4..8ddb490e99 100644 --- a/baseapp/msg_service_router_test.go +++ b/baseapp/msg_service_router_test.go @@ -86,6 +86,41 @@ func TestRegisterMsgServiceTwice(t *testing.T) { }) } +func TestHybridHandlerByMsgName(t *testing.T) { + // Setup baseapp and router. + var ( + appBuilder *runtime.AppBuilder + registry codectypes.InterfaceRegistry + ) + err := depinject.Inject( + depinject.Configs( + makeMinimalConfig(), + depinject.Supply(log.NewTestLogger(t)), + ), &appBuilder, ®istry) + require.NoError(t, err) + db := dbm.NewMemDB() + app := appBuilder.Build(db, nil) + testdata.RegisterInterfaces(registry) + + testdata.RegisterMsgServer( + app.MsgServiceRouter(), + testdata.MsgServerImpl{}, + ) + + handler := app.MsgServiceRouter().HybridHandlerByMsgName("testpb.MsgCreateDog") + + require.NotNil(t, handler) + require.NoError(t, app.Init()) + ctx := app.NewContext(true) + resp := new(testdata.MsgCreateDogResponse) + err = handler(ctx, &testdata.MsgCreateDog{ + Dog: &testdata.Dog{Name: "Spot"}, + Owner: "me", + }, resp) + require.NoError(t, err) + require.Equal(t, resp.Name, "Spot") +} + func TestMsgService(t *testing.T) { priv, _, _ := testdata.KeyTestPubAddr() diff --git a/server/mock/app.go b/server/mock/app.go index bbb74168fc..3b42051660 100644 --- a/server/mock/app.go +++ b/server/mock/app.go @@ -10,6 +10,10 @@ import ( abci "github.com/cometbft/cometbft/abci/types" db "github.com/cosmos/cosmos-db" "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" "cosmossdk.io/log" storetypes "cosmossdk.io/store/types" @@ -45,7 +49,7 @@ func NewApp(rootDir string, logger log.Logger) (servertypes.ABCI, error) { router.SetInterfaceRegistry(interfaceRegistry) newDesc := &grpc.ServiceDesc{ - ServiceName: "test", + ServiceName: "Test", Methods: []grpc.MethodDesc{ { MethodName: "Test", @@ -170,3 +174,47 @@ func MsgTestHandler(srv interface{}, ctx context.Context, dec func(interface{}) func (m MsgServerImpl) Test(ctx context.Context, msg *KVStoreTx) (*sdk.Result, error) { return KVStoreHandler(m.capKeyMainStore)(sdk.UnwrapSDKContext(ctx), msg) } + +func init() { + err := registerFauxDescriptor() + if err != nil { + panic(err) + } +} + +func registerFauxDescriptor() error { + fauxDescriptor, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ + Name: proto.String("faux_proto/test.proto"), + Dependency: nil, + PublicDependency: nil, + WeakDependency: nil, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("KVStoreTx"), + }, + }, + EnumType: nil, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("Test"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Test"), + InputType: proto.String("KVStoreTx"), + OutputType: proto.String("KVStoreTx"), + }, + }, + }, + }, + Extension: nil, + Options: nil, + SourceCodeInfo: nil, + Syntax: nil, + Edition: nil, + }, nil) + if err != nil { + return err + } + + return protoregistry.GlobalFiles.RegisterFile(fauxDescriptor) +}