cosmos-sdk/baseapp/internal/protocompat/protocompat.go
Julien Robert cfe194013c
test: fix failing tests after protocompat changes (#22889)
Co-authored-by: Marko <marko@baricevic.me>
2024-12-16 16:06:22 +00:00

253 lines
9.1 KiB
Go

package protocompat
import (
"context"
"fmt"
"reflect"
gogoproto "github.com/cosmos/gogoproto/proto"
"google.golang.org/grpc"
proto2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
"github.com/cosmos/cosmos-sdk/codec"
)
var (
// gogoType is the reflect.Type of the gogoproto.Message interface; isProtov2 uses it
// to recognise handlers whose request message is a gogoproto message.
gogoType = reflect.TypeOf((*gogoproto.Message)(nil)).Elem()
// protov2Type is the reflect.Type of the protov2 (google.golang.org/protobuf) Message
// interface; isProtov2 uses it to recognise handlers whose request message is protov2.
protov2Type = reflect.TypeOf((*proto2.Message)(nil)).Elem()
// protov2MarshalOpts are the options used whenever a protov2 message is marshaled in
// this package; Deterministic is enabled.
protov2MarshalOpts = proto2.MarshalOptions{Deterministic: true}
)
// Handler is the function signature produced by MakeHybridHandler. It executes the wrapped
// service method for request and writes the result into response; both are passed in by the
// caller. A non-nil error means the request could not be handled.
type Handler = func(ctx context.Context, request, response protoiface.MessageV1) error
// MakeHybridHandler builds a Handler for the given service method that accepts both gogo and
// protov2 messages, regardless of which of the two the underlying handler was written for.
//
// The method is looked up by its fully-qualified name ("<service>.<method>") in the gogoproto
// hybrid resolver; an error is returned if the lookup fails, if the resolved descriptor is not
// a method descriptor, or if the handler's message flavour cannot be determined.
func MakeHybridHandler(cdc codec.BinaryCodec, sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) (Handler, error) {
	fullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))

	resolved, err := gogoproto.HybridResolver.FindDescriptorByName(fullName)
	if err != nil {
		return nil, err
	}

	md, isMethod := resolved.(protoreflect.MethodDescriptor)
	if !isMethod {
		return nil, fmt.Errorf("invalid method descriptor %s", fullName)
	}

	// Dispatch on the flavour of the message the concrete handler decodes into.
	useProtov2, err := isProtov2(method)
	switch {
	case err != nil:
		return nil, err
	case useProtov2:
		return makeProtoV2HybridHandler(md, cdc, method, handler)
	}
	return makeGogoHybridHandler(md, cdc, method, handler)
}
// makeProtoV2HybridHandler returns a handler that can handle both gogo and protov2 messages
// on top of a service handler that works with protov2 messages.
//
// Protov2 requests are copied into the handler's request message with proto2.Merge, and the
// handler's response is merged into outResp. Gogoproto requests are only accepted when a
// gogoproto type is registered for the method's output message; they are converted by
// marshaling with cdc and unmarshaling with proto2, and the protov2 response is converted back
// into the gogoproto outResp the same way.
//
// The returned Handler reports an error (rather than panicking) when the request is nil, or
// when outResp or the handler's response is not of the message flavour the request requires.
// outResp is validated before the underlying handler is invoked.
func makeProtoV2HybridHandler(prefMethod protoreflect.MethodDescriptor, cdc codec.BinaryCodec, method grpc.MethodDesc, handler any) (Handler, error) {
	// invokeProtov2 runs the handler with a protov2 request and merges its response into outResp.
	invokeProtov2 := func(ctx context.Context, req proto2.Message, outResp protoiface.MessageV1) error {
		out, ok := outResp.(proto2.Message)
		if !ok {
			return fmt.Errorf("invalid response type %T, method %s expects a protov2 response for a protov2 request", outResp, prefMethod.FullName())
		}
		resp, err := method.Handler(handler, ctx, func(msg any) error {
			// we hand the handler a copy of the message, for safety reasons.
			// MakeHybridHandler only selects this path when isProtov2 reported that the
			// decoder target is a protov2 message.
			proto2.Merge(msg.(proto2.Message), req)
			return nil
		}, nil)
		if err != nil {
			return err
		}
		respMsg, ok := resp.(proto2.Message)
		if !ok {
			return fmt.Errorf("invalid handler response type %T, method %s expects a protov2 response", resp, prefMethod.FullName())
		}
		// merge on the resp
		proto2.Merge(out, respMsg)
		return nil
	}

	// it's a protov2 handler, if a gogo counterparty is not found we cannot handle gogo messages.
	gogoExists := gogoproto.MessageType(string(prefMethod.Output().FullName())) != nil
	if !gogoExists {
		return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
			protov2Request, ok := inReq.(proto2.Message)
			if !ok {
				return fmt.Errorf("invalid request type %T, method %s does not accept gogoproto messages", inReq, prefMethod.FullName())
			}
			return invokeProtov2(ctx, protov2Request, outResp)
		}, nil
	}

	return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
		// we check if the request is a protov2 message.
		switch m := inReq.(type) {
		case proto2.Message:
			return invokeProtov2(ctx, m, outResp)
		case gogoproto.Message:
			// since the request came as gogoproto, we expect the response
			// to also be gogoproto.
			out, ok := outResp.(gogoproto.Message)
			if !ok {
				return fmt.Errorf("invalid response type %T, method %s expects a gogoproto response for a gogoproto request", outResp, prefMethod.FullName())
			}
			// we need to marshal and unmarshal the request.
			requestBytes, err := cdc.Marshal(m)
			if err != nil {
				return err
			}
			resp, err := method.Handler(handler, ctx, func(msg any) error {
				// unmarshal request into the message.
				return proto2.Unmarshal(requestBytes, msg.(proto2.Message))
			}, nil)
			if err != nil {
				return err
			}
			// the response is a protov2 message, so we cannot just return it.
			respMsg, ok := resp.(proto2.Message)
			if !ok {
				return fmt.Errorf("invalid handler response type %T, method %s expects a protov2 response", resp, prefMethod.FullName())
			}
			respBytes, err := protov2MarshalOpts.Marshal(respMsg)
			if err != nil {
				return err
			}
			// unmarshal response into a gogo message.
			return cdc.Unmarshal(respBytes, out)
		default:
			// reached when inReq is a nil interface value.
			return fmt.Errorf("invalid request type %T, method %s expects a protov2 or gogoproto message", inReq, prefMethod.FullName())
		}
	}, nil
}
// makeGogoHybridHandler returns a handler that can handle both gogo and protov2 messages
// on top of a service handler that works with gogoproto messages.
//
// Gogoproto requests are copied into the handler's request message with setPointer, and the
// handler's response is copied into outResp the same way. Protov2 requests are only accepted
// when a protov2 type is registered in protoregistry.GlobalTypes for the method's output
// message; they are converted by marshaling with proto2 and unmarshaling with cdc, and the
// gogoproto response is converted back into the protov2 outResp the same way.
//
// The returned Handler reports an error (rather than panicking) when the request is nil, or
// when outResp or the handler's response is not of the message flavour the request requires.
// For protov2 requests, outResp is validated before the underlying handler is invoked.
func makeGogoHybridHandler(prefMethod protoreflect.MethodDescriptor, cdc codec.BinaryCodec, method grpc.MethodDesc, handler any) (Handler, error) {
	// invokeGogo runs the handler with a gogoproto request and copies its response into outResp.
	invokeGogo := func(ctx context.Context, req any, outResp protoiface.MessageV1) error {
		resp, err := method.Handler(handler, ctx, func(msg any) error {
			// we hand the handler a copy of the message, for safety reasons.
			return setPointer(msg, req)
		}, nil)
		if err != nil {
			return err
		}
		return setPointer(outResp, resp)
	}

	// it's a gogo handler, we check if the existing protov2 counterparty exists.
	if _, err := protoregistry.GlobalTypes.FindMessageByName(prefMethod.Output().FullName()); err != nil {
		// this can only be a gogo message.
		return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
			if _, isV2 := inReq.(proto2.Message); isV2 {
				return fmt.Errorf("invalid request type %T, method %s does not accept protov2 messages", inReq, prefMethod.FullName())
			}
			return invokeGogo(ctx, inReq, outResp)
		}, nil
	}

	// this is a gogo handler, and we have a protov2 counterparty.
	return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
		switch m := inReq.(type) {
		case proto2.Message:
			// since the request came as protov2, we expect the response
			// to also be protov2.
			out, ok := outResp.(proto2.Message)
			if !ok {
				return fmt.Errorf("invalid response type %T, method %s expects a protov2 response for a protov2 request", outResp, prefMethod.FullName())
			}
			// we need to marshal and unmarshal the request.
			requestBytes, err := protov2MarshalOpts.Marshal(m)
			if err != nil {
				return err
			}
			resp, err := method.Handler(handler, ctx, func(msg any) error {
				// unmarshal request into the message.
				return cdc.Unmarshal(requestBytes, msg.(gogoproto.Message))
			}, nil)
			if err != nil {
				return err
			}
			// the response is a gogo message, so we cannot just return it.
			gogoResp, ok := resp.(gogoproto.Message)
			if !ok {
				return fmt.Errorf("invalid handler response type %T, method %s expects a gogoproto response", resp, prefMethod.FullName())
			}
			respBytes, err := cdc.Marshal(gogoResp)
			if err != nil {
				return err
			}
			// now we unmarshal back into a protov2 message.
			return proto2.Unmarshal(respBytes, out)
		case gogoproto.Message:
			return invokeGogo(ctx, m, outResp)
		default:
			// reached when inReq is a nil interface value.
			return fmt.Errorf("invalid request type %T, method %s expects a protov2 or gogoproto message", inReq, prefMethod.FullName())
		}
	}, nil
}
// isProtov2 returns true if the given method accepts protov2 messages.
// Returns false if it does not.
//
// It uses the decoder function passed to the method handler to determine the type: the
// concrete implementer passes the message it expects bytes to be unmarshaled into to the
// decoder, so inspecting that argument reveals whether the handler works with protov2 or
// gogoproto messages. An error is returned when the decoder target is nil or implements
// neither interface.
func isProtov2(md grpc.MethodDesc) (isV2Type bool, err error) {
	// pullRequestType records the flavour of the decoder target in the named results. It always
	// returns nil so the handler proceeds to the interceptor, which stops the execution.
	pullRequestType := func(msg interface{}) error {
		typ := reflect.TypeOf(msg)
		switch {
		case typ == nil:
			// reflect.TypeOf yields a nil Type for a nil interface value; calling Implements
			// on it would panic, so report it as an invalid request type instead.
			err = fmt.Errorf("invalid request type %T, expected protov2 or gogo message", msg)
		case typ.Implements(protov2Type):
			isV2Type = true
		case typ.Implements(gogoType):
			isV2Type = false
		default:
			err = fmt.Errorf("invalid request type %T, expected protov2 or gogo message", msg)
		}
		return nil
	}

	// doNotExecute is a dummy handler that stops the request execution.
	doNotExecute := func(_ context.Context, _ any, _ *grpc.UnaryServerInfo, _ grpc.UnaryHandler) (any, error) {
		return nil, nil
	}

	// we are allowed to pass in a nil context and nil request, since we are not actually executing the request.
	// this is made possible by the doNotExecute function which immediately returns without calling other handlers.
	// The handler's own results are irrelevant here: the outcome is captured by pullRequestType.
	_, _ = md.Handler(nil, nil, pullRequestType, doNotExecute)
	return isV2Type, err
}
// RequestFullNameFromMethodDesc returns the fully-qualified name of the request message of the provided service's method.
//
// The method is looked up by "<service name>.<method name>" in the gogoproto hybrid resolver.
// An error is returned if the lookup fails (the resolver's error is wrapped, so it remains
// available through errors.Is / errors.As) or if the resolved descriptor is not a method
// descriptor.
func RequestFullNameFromMethodDesc(sd *grpc.ServiceDesc, method grpc.MethodDesc) (protoreflect.FullName, error) {
	methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))
	desc, err := gogoproto.HybridResolver.FindDescriptorByName(methodFullName)
	if err != nil {
		return "", fmt.Errorf("cannot find method descriptor %s: %w", methodFullName, err)
	}
	methodDesc, ok := desc.(protoreflect.MethodDescriptor)
	if !ok {
		return "", fmt.Errorf("invalid method descriptor %s", methodFullName)
	}
	return methodDesc.Input().FullName(), nil
}
// ResponseFullNameFromMethodDesc returns the fully-qualified name of the response message of the provided service's method.
//
// The method is looked up by "<service name>.<method name>" in the gogoproto hybrid resolver.
// An error is returned if the lookup fails (the resolver's error is wrapped, so it remains
// available through errors.Is / errors.As) or if the resolved descriptor is not a method
// descriptor.
func ResponseFullNameFromMethodDesc(sd *grpc.ServiceDesc, method grpc.MethodDesc) (protoreflect.FullName, error) {
	methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))
	desc, err := gogoproto.HybridResolver.FindDescriptorByName(methodFullName)
	if err != nil {
		return "", fmt.Errorf("cannot find method descriptor %s: %w", methodFullName, err)
	}
	methodDesc, ok := desc.(protoreflect.MethodDescriptor)
	if !ok {
		return "", fmt.Errorf("invalid method descriptor %s", methodFullName)
	}
	return methodDesc.Output().FullName(), nil
}
// setPointer copies the value pointed to by src into the value pointed to by dst (*dst = *src).
//
// since proto.Merge breaks due to the custom cosmos sdk any, we are forced to do this ugly setPointer hack.
// ref: https://github.com/cosmos/cosmos-sdk/issues/22779
//
// Both arguments must be non-nil pointers to the same type. An error is returned, and dst is
// left untouched, when either argument is nil, is not a pointer, or when the pointed-to types
// differ.
func setPointer(dst, src any) error {
	dstValue := reflect.ValueOf(dst)
	srcValue := reflect.ValueOf(src)
	if !dstValue.IsValid() || !srcValue.IsValid() {
		return fmt.Errorf("dst and src must be valid")
	}
	// IsNil and Elem panic on kinds that are not pointers (e.g. structs or ints), so reject
	// those up front instead of crashing.
	if dstValue.Kind() != reflect.Pointer || srcValue.Kind() != reflect.Pointer {
		return fmt.Errorf("dst and src must be pointers, got %T and %T", dst, src)
	}
	if dstValue.IsNil() || srcValue.IsNil() {
		return fmt.Errorf("dst and src must be non-nil")
	}
	dstElem := dstValue.Elem()
	srcElem := srcValue.Elem()
	if dstElem.Type() != srcElem.Type() {
		return fmt.Errorf("dst and src must have the same type")
	}
	dstElem.Set(srcElem)
	return nil
}