cosmos-sdk/server/v2/api/grpc/server.go

258 lines
6.7 KiB
Go

package grpc
import (
"context"
"errors"
"fmt"
"io"
"maps"
"net"
"slices"
"strconv"
"strings"
"sync"
gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/spf13/pflag"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect"
appmodulev2 "cosmossdk.io/core/appmodule/v2"
"cosmossdk.io/core/server"
"cosmossdk.io/core/transaction"
"cosmossdk.io/log"
serverv2 "cosmossdk.io/server/v2"
"cosmossdk.io/server/v2/api/grpc/gogoreflection"
)
// ServerName is the name of this server. Name returns it, New uses it as the
// key of the config sub-section to unmarshal, and StartCmdFlags uses it as the
// flag set name.
const ServerName = "grpc"

// BlockHeightHeader is the incoming gRPC metadata key from which
// getHeightFromCtx reads the height passed to query handlers.
const BlockHeightHeader = "x-cosmos-block-height"
// Server is the server/v2 gRPC API server, parameterized by the transaction
// type T. Build a working instance with New; NewWithConfigOptions only yields
// a value usable for reading/writing configuration.
type Server[T transaction.Tx] struct {
	// logger is set by New to the supplied logger tagged with the server name.
	logger log.Logger
	// config is the effective configuration set by New; it stays nil for
	// servers created through NewWithConfigOptions.
	config *Config
	// cfgOptions are applied on top of DefaultConfig() by Config when no
	// usable config has been set yet.
	cfgOptions []CfgOption
	// grpcSrv is the underlying gRPC server built by New.
	grpcSrv *grpc.Server
	// extraGRPCHandlers are callbacks supplied via WithExtraGRPCHandlers that
	// New invokes with grpcSrv after registering the built-in services.
	extraGRPCHandlers []func(*grpc.Server) error
}
// New creates a new grpc server.
//
// The options in opts are applied first. The configuration is then taken from
// Config() (defaults plus any CfgOptions) and, when cfg is non-empty, the
// ServerName sub-section of cfg is unmarshaled over it. The resulting gRPC
// server:
//   - uses the proto codec built from interfaceRegistry and the configured
//     send/receive message size limits,
//   - routes otherwise unknown services to queryHandlers via queryable,
//   - registers the v2 query service and gogo reflection,
//   - runs every handler supplied through WithExtraGRPCHandlers.
//
// An error is returned if the config cannot be unmarshaled or if any extra
// handler fails; the handler errors are joined together.
func New[T transaction.Tx](
	logger log.Logger,
	interfaceRegistry server.InterfaceRegistry,
	queryHandlers map[string]appmodulev2.Handler,
	queryable func(ctx context.Context, version uint64, msg transaction.Msg) (transaction.Msg, error),
	cfg server.ConfigMap,
	opts ...OptionFunc[T],
) (*Server[T], error) {
	s := new(Server[T])
	for _, apply := range opts {
		apply(s)
	}

	// Start from the defaults (with CfgOptions applied) and overlay the
	// provided config map, if there is one.
	conf := s.Config().(*Config)
	if len(cfg) != 0 {
		if err := serverv2.UnmarshalSubConfig(cfg, s.Name(), &conf); err != nil {
			return nil, fmt.Errorf("failed to unmarshal config: %w", err)
		}
	}

	codec := newProtoCodec(interfaceRegistry).GRPCCodec()
	gs := grpc.NewServer(
		grpc.ForceServerCodec(codec),
		grpc.MaxSendMsgSize(conf.MaxSendMsgSize),
		grpc.MaxRecvMsgSize(conf.MaxRecvMsgSize),
		grpc.UnknownServiceHandler(makeUnknownServiceHandler(queryHandlers, queryable)),
	)

	// v2 query service.
	RegisterServiceServer(gs, &v2Service{queryHandlers, queryable})

	// Reflection lets external clients discover the exposed services/methods.
	reflectionNames := slices.Collect(maps.Keys(queryHandlers))
	gogoreflection.Register(gs, reflectionNames, logger.With("sub-module", "grpc-reflection"))

	// Run every extra handler, collecting all failures instead of stopping at the first.
	var regErr error
	for _, register := range s.extraGRPCHandlers {
		regErr = errors.Join(regErr, register(gs))
	}
	if regErr != nil {
		return nil, fmt.Errorf("failed to register extra gRPC handlers: %w", regErr)
	}

	s.grpcSrv, s.config, s.logger = gs, conf, logger.With(log.ModuleKey, s.Name())
	return s, nil
}
// OptionFunc customizes a Server under construction. New applies each one, in
// the order given, before computing the server configuration.
type OptionFunc[T transaction.Tx] func(*Server[T])
// WithCfgOptions registers CfgOptions that modify the default server
// configuration. Config applies them, in order, to DefaultConfig() whenever no
// usable config has been set yet. New relies on this, so they run before the
// config map passed to New is unmarshaled over the result.
//
// Using this option more than once accumulates the options rather than
// replacing earlier ones.
func WithCfgOptions[T transaction.Tx](cfgOptions ...CfgOption) OptionFunc[T] {
	return func(srv *Server[T]) {
		// Append (instead of assigning) so repeated use does not silently drop
		// earlier options; appending also detaches from the caller's slice.
		srv.cfgOptions = append(srv.cfgOptions, cfgOptions...)
	}
}
// WithExtraGRPCHandlers registers callbacks that New invokes with the
// underlying *grpc.Server after the built-in services have been registered.
// New fails if any of them returns an error; all such errors are joined.
//
// Using this option more than once accumulates the handlers rather than
// replacing earlier ones.
func WithExtraGRPCHandlers[T transaction.Tx](handlers ...func(*grpc.Server) error) OptionFunc[T] {
	return func(srv *Server[T]) {
		// Append (instead of assigning) so repeated use does not silently drop
		// previously registered handlers and the services they would add.
		srv.extraGRPCHandlers = append(srv.extraGRPCHandlers, handlers...)
	}
}
// NewWithConfigOptions returns a Server that holds only the given config
// options. It has no logger, gRPC server or config, so it cannot be started;
// use it solely to obtain (via Config) or adjust configuration.
func NewWithConfigOptions[T transaction.Tx](opts ...CfgOption) *Server[T] {
	srv := new(Server[T])
	srv.cfgOptions = opts
	return srv
}
// StartCmdFlags returns the command-line flags contributed by the gRPC server
// (presumably attached to the node's start command; confirm at the call site).
// It defines only the listen address flag, FlagAddress, defaulting to
// "localhost:9090".
func (s *Server[T]) StartCmdFlags() *pflag.FlagSet {
	const defaultAddr = "localhost:9090"
	fs := pflag.NewFlagSet(s.Name(), pflag.ExitOnError)
	fs.String(FlagAddress, defaultAddr, "Listen address")
	return fs
}
// makeUnknownServiceHandler returns the stream handler that serves methods
// without an explicitly registered gRPC service.
//
// For each stream it:
//  1. resolves the invoked method in the merged gogoproto registry,
//  2. picks the query handler keyed by the method's input message full name,
//  3. reads the height from the BlockHeightHeader metadata, and
//  4. for every request message received, calls queryable at that height and
//     sends the response back.
//
// It returns nil when the client finishes sending (io.EOF). Lookup failures
// are reported with proper gRPC status codes: Unimplemented for unknown or
// unhandled methods, InvalidArgument for a bad height header, and Internal if
// the registry cannot be built.
func makeUnknownServiceHandler(
	handlers map[string]appmodulev2.Handler,
	queryable func(ctx context.Context, version uint64, msg transaction.Msg) (transaction.Msg, error),
) grpc.StreamHandler {
	// The merged registry is built lazily, once, and shared by all streams.
	getRegistry := sync.OnceValues(gogoproto.MergedRegistry)
	return func(srv any, stream grpc.ServerStream) error {
		method, ok := grpc.MethodFromServerStream(stream)
		if !ok {
			return status.Error(codes.InvalidArgument, "unable to get method")
		}
		// if this fails we cannot serve queries anymore...
		registry, err := getRegistry()
		if err != nil {
			return status.Errorf(codes.Internal, "failed to get registry: %v", err)
		}

		// "/pkg.Service/Method" -> "pkg.Service.Method"
		method = strings.TrimPrefix(method, "/")
		fullName := protoreflect.FullName(strings.ReplaceAll(method, "/", "."))
		desc, err := registry.FindDescriptorByName(fullName)
		if err != nil {
			// An unknown method must surface as Unimplemented, not codes.Unknown
			// (which is what a plain error would be mapped to).
			return status.Errorf(codes.Unimplemented, "failed to find descriptor %s: %v", method, err)
		}
		md, ok := desc.(protoreflect.MethodDescriptor)
		if !ok {
			return status.Errorf(codes.Unimplemented, "%s is not a method", method)
		}
		handler, exists := handlers[string(md.Input().FullName())]
		if !exists {
			return status.Errorf(codes.Unimplemented, "gRPC method %s is not handled", method)
		}

		// The stream context (and hence its metadata) is fixed for the whole
		// stream, so the requested height is parsed once rather than per message.
		ctx := stream.Context()
		height, err := getHeightFromCtx(ctx)
		if err != nil {
			return status.Errorf(codes.InvalidArgument, "invalid get height from context: %v", err)
		}

		for {
			req := handler.MakeMsg()
			if err := stream.RecvMsg(req); err != nil {
				if errors.Is(err, io.EOF) {
					return nil
				}
				return err
			}
			resp, err := queryable(ctx, height, req)
			if err != nil {
				return err
			}
			if err := stream.SendMsg(resp); err != nil {
				return err
			}
		}
	}
}
// getHeightFromCtx extracts the requested height from the BlockHeightHeader
// entry of the incoming gRPC metadata carried by ctx.
//
// It returns 0 and a nil error when ctx has no incoming metadata or the header
// is absent. It returns an error when the header has more than one value or
// its value is not a base-10 unsigned 64-bit integer.
func getHeightFromCtx(ctx context.Context) (uint64, error) {
	incoming, found := metadata.FromIncomingContext(ctx)
	if !found {
		return 0, nil
	}

	switch vals := incoming.Get(BlockHeightHeader); len(vals) {
	case 0:
		return 0, nil
	case 1:
		raw := vals[0]
		h, err := strconv.ParseUint(raw, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("unable to parse height string from gRPC metadata %s: %w", raw, err)
		}
		return h, nil
	default:
		return 0, fmt.Errorf("gRPC height metadata must be of length 1, got: %d", len(vals))
	}
}
// Name returns ServerName, the identifier of the gRPC server.
func (*Server[T]) Name() string {
	return ServerName
}
// Config returns the server configuration as a *Config.
//
// If a config has been set and its Address is non-empty, that config is
// returned. Otherwise a fresh DefaultConfig() is built, every registered
// CfgOption is applied to it in order, and it is returned without being stored
// on the server.
func (s *Server[T]) Config() any {
	if s.config != nil && s.config.Address != "" {
		return s.config
	}

	defaults := DefaultConfig()
	for _, apply := range s.cfgOptions {
		apply(defaults)
	}
	return defaults
}
// Start listens on the configured TCP address and serves gRPC on it, blocking
// until serving ends. ctx is used only to open the listener. If the server is
// disabled in its config, Start just logs that and returns nil.
//
// Serve closes the listener when it returns. A nil error is returned when
// serving ends because the server was stopped. This includes Stop having run
// before Serve began, which gRPC reports as grpc.ErrServerStopped; that is a
// requested shutdown, not a start failure.
func (s *Server[T]) Start(ctx context.Context) error {
	if !s.config.Enable {
		s.logger.Info(fmt.Sprintf("%s server is disabled via config", s.Name()))
		return nil
	}

	listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", s.config.Address)
	if err != nil {
		return fmt.Errorf("failed to listen on address %s: %w", s.config.Address, err)
	}

	s.logger.Info("starting gRPC server...", "address", s.config.Address)
	if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
		return fmt.Errorf("failed to start gRPC server: %w", err)
	}

	return nil
}
// Stop gracefully shuts down the gRPC server when it is enabled in its config,
// and is a no-op otherwise. It always returns nil.
//
// GracefulStop stops accepting new connections/RPCs and blocks until pending
// RPCs finish.
// NOTE(review): ctx is currently unused, so a caller deadline does not bound
// how long GracefulStop may block on long-lived streams. Confirm whether
// callers pass a live ctx before adding a forced-Stop fallback.
func (s *Server[T]) Stop(ctx context.Context) error {
	if s.config.Enable {
		s.logger.Info("stopping gRPC server...", "address", s.config.Address)
		s.grpcSrv.GracefulStop()
	}
	return nil
}