diff --git a/server/v2/stf/stf.go b/server/v2/stf/stf.go index d6cf27115d..ae57b9f175 100644 --- a/server/v2/stf/stf.go +++ b/server/v2/stf/stf.go @@ -26,8 +26,8 @@ var Identity = []byte("stf") type STF[T transaction.Tx] struct { logger log.Logger - msgRouter Router - queryRouter Router + msgRouter coreRouterImpl + queryRouter coreRouterImpl doPreBlock func(ctx context.Context, txs []T) error doBeginBlock func(ctx context.Context) error @@ -584,8 +584,8 @@ func newExecutionContext( sender transaction.Identity, state store.WriterMap, execMode transaction.ExecMode, - msgRouter Router, - queryRouter Router, + msgRouter coreRouterImpl, + queryRouter coreRouterImpl, ) *executionContext { meter := makeGasMeterFn(gas.NoGasLimit) meteredState := makeGasMeteredStoreFn(meter, state) diff --git a/server/v2/stf/stf_router.go b/server/v2/stf/stf_router.go index b54c537f85..9f08ddcfb4 100644 --- a/server/v2/stf/stf_router.go +++ b/server/v2/stf/stf_router.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strings" gogoproto "github.com/cosmos/gogoproto/proto" @@ -61,7 +62,7 @@ func (b *MsgRouterBuilder) HandlerExists(msgType string) bool { return ok } -func (b *MsgRouterBuilder) Build() (Router, error) { +func (b *MsgRouterBuilder) Build() (coreRouterImpl, error) { handlers := make(map[string]appmodulev2.Handler) globalPreHandler := func(ctx context.Context, msg appmodulev2.Message) error { @@ -93,7 +94,7 @@ func (b *MsgRouterBuilder) Build() (Router, error) { handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler) } - return Router{ + return coreRouterImpl{ handlers: handlers, }, nil } @@ -139,14 +140,18 @@ func msgTypeURL(msg gogoproto.Message) string { return gogoproto.MessageName(msg) } -var _ router.Service = (*Router)(nil) +var _ router.Service = (*coreRouterImpl)(nil) -// Router implements the STF router for msg and query handlers. -type Router struct { +// coreRouterImpl implements the STF router for msg and query handlers. +type coreRouterImpl struct { handlers map[string]appmodulev2.Handler } -func (r Router) CanInvoke(_ context.Context, typeURL string) error { +func (r coreRouterImpl) CanInvoke(_ context.Context, typeURL string) error { + // trimming prefixes is a backwards compatibility strategy that we use + // for baseapp components that did routing through type URL rather + // than protobuf message names. + typeURL = strings.TrimPrefix(typeURL, "/") _, exists := r.handlers[typeURL] if !exists { return fmt.Errorf("%w: %s", ErrNoHandler, typeURL) @@ -154,20 +159,15 @@ func (r Router) CanInvoke(_ context.Context, typeURL string) error { return nil } -func (r Router) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error { +func (r coreRouterImpl) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error { handlerResp, err := r.InvokeUntyped(ctx, req) if err != nil { return err } - merge(handlerResp, resp) - return nil + return merge(handlerResp, resp) } -func merge(src, dst gogoproto.Message) { - reflect.Indirect(reflect.ValueOf(dst)).Set(reflect.Indirect(reflect.ValueOf(src))) -} - -func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) { +func (r coreRouterImpl) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) { typeName := msgTypeURL(req) handler, exists := r.handlers[typeName] if !exists { @@ -175,3 +175,38 @@ func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res g } return handler(ctx, req) } + +// merge merges together two protobuf messages by setting the pointer +// to src in dst. Used internally. +func merge(src, dst gogoproto.Message) error { + if src == nil { + return fmt.Errorf("source message is nil") + } + if dst == nil { + return fmt.Errorf("destination message is nil") + } + + srcVal := reflect.ValueOf(src) + dstVal := reflect.ValueOf(dst) + + if srcVal.Kind() == reflect.Interface { + srcVal = srcVal.Elem() + } + if dstVal.Kind() == reflect.Interface { + dstVal = dstVal.Elem() + } + + if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr { + return fmt.Errorf("both source and destination must be pointers") + } + + srcElem := srcVal.Elem() + dstElem := dstVal.Elem() + + if !srcElem.Type().AssignableTo(dstElem.Type()) { + return fmt.Errorf("incompatible types: cannot merge %v into %v", srcElem.Type(), dstElem.Type()) + } + + dstElem.Set(srcElem) + return nil +} diff --git a/server/v2/stf/stf_router_test.go b/server/v2/stf/stf_router_test.go new file mode 100644 index 0000000000..70d25e3334 --- /dev/null +++ b/server/v2/stf/stf_router_test.go @@ -0,0 +1,107 @@ +package stf + +import ( + "context" + "testing" + + gogoproto "github.com/cosmos/gogoproto/proto" + gogotypes "github.com/cosmos/gogoproto/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cosmossdk.io/core/appmodule/v2" +) + +func TestRouter(t *testing.T) { + expectedMsg := &gogotypes.BoolValue{Value: true} + expectedMsgName := gogoproto.MessageName(expectedMsg) + + expectedResp := &gogotypes.StringValue{Value: "test"} + + router := coreRouterImpl{handlers: map[string]appmodule.Handler{ + gogoproto.MessageName(expectedMsg): func(ctx context.Context, gotMsg appmodule.Message) (msgResp appmodule.Message, err error) { + require.Equal(t, expectedMsg, gotMsg) + return expectedResp, nil + }, + }} + + t.Run("can invoke message by name", func(t *testing.T) { + err := router.CanInvoke(context.Background(), expectedMsgName) + require.NoError(t, err, "must be invokable") + }) + + t.Run("can invoke message by type URL", func(t *testing.T) { + err := router.CanInvoke(context.Background(), "/"+expectedMsgName) + require.NoError(t, err) + }) + + t.Run("cannot invoke unknown message", func(t *testing.T) { + err := router.CanInvoke(context.Background(), "not exist") + require.Error(t, err) + }) + + t.Run("invoke untyped", func(t *testing.T) { + gotResp, err := router.InvokeUntyped(context.Background(), expectedMsg) + require.NoError(t, err) + require.Equal(t, expectedResp, gotResp) + }) + + t.Run("invoked typed", func(t *testing.T) { + gotResp := new(gogotypes.StringValue) + err := router.InvokeTyped(context.Background(), expectedMsg, gotResp) + require.NoError(t, err) + require.Equal(t, expectedResp, gotResp) + }) +} + +func TestMerge(t *testing.T) { + tests := []struct { + name string + src gogoproto.Message + dst gogoproto.Message + expected gogoproto.Message + wantErr bool + }{ + { + name: "success", + src: &gogotypes.BoolValue{Value: true}, + dst: &gogotypes.BoolValue{}, + expected: &gogotypes.BoolValue{Value: true}, + wantErr: false, + }, + { + name: "nil src", + src: nil, + dst: &gogotypes.StringValue{}, + expected: &gogotypes.StringValue{}, + wantErr: true, + }, + { + name: "nil dst", + src: &gogotypes.StringValue{Value: "hello"}, + dst: nil, + expected: nil, + wantErr: true, + }, + { + name: "incompatible types", + src: &gogotypes.StringValue{Value: "hello"}, + dst: &gogotypes.BoolValue{}, + expected: &gogotypes.BoolValue{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := merge(tt.src, tt.dst) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, tt.dst) + } + }) + } +}