feat(stf/router): support backwards compat type URL in router (#21177)
This commit is contained in:
parent
90fd6320a6
commit
4dc9469320
@ -26,8 +26,8 @@ var Identity = []byte("stf")
|
||||
type STF[T transaction.Tx] struct {
|
||||
logger log.Logger
|
||||
|
||||
msgRouter Router
|
||||
queryRouter Router
|
||||
msgRouter coreRouterImpl
|
||||
queryRouter coreRouterImpl
|
||||
|
||||
doPreBlock func(ctx context.Context, txs []T) error
|
||||
doBeginBlock func(ctx context.Context) error
|
||||
@ -584,8 +584,8 @@ func newExecutionContext(
|
||||
sender transaction.Identity,
|
||||
state store.WriterMap,
|
||||
execMode transaction.ExecMode,
|
||||
msgRouter Router,
|
||||
queryRouter Router,
|
||||
msgRouter coreRouterImpl,
|
||||
queryRouter coreRouterImpl,
|
||||
) *executionContext {
|
||||
meter := makeGasMeterFn(gas.NoGasLimit)
|
||||
meteredState := makeGasMeteredStoreFn(meter, state)
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
gogoproto "github.com/cosmos/gogoproto/proto"
|
||||
|
||||
@ -61,7 +62,7 @@ func (b *MsgRouterBuilder) HandlerExists(msgType string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func (b *MsgRouterBuilder) Build() (Router, error) {
|
||||
func (b *MsgRouterBuilder) Build() (coreRouterImpl, error) {
|
||||
handlers := make(map[string]appmodulev2.Handler)
|
||||
|
||||
globalPreHandler := func(ctx context.Context, msg appmodulev2.Message) error {
|
||||
@ -93,7 +94,7 @@ func (b *MsgRouterBuilder) Build() (Router, error) {
|
||||
handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler)
|
||||
}
|
||||
|
||||
return Router{
|
||||
return coreRouterImpl{
|
||||
handlers: handlers,
|
||||
}, nil
|
||||
}
|
||||
@ -139,14 +140,18 @@ func msgTypeURL(msg gogoproto.Message) string {
|
||||
return gogoproto.MessageName(msg)
|
||||
}
|
||||
|
||||
var _ router.Service = (*Router)(nil)
|
||||
var _ router.Service = (*coreRouterImpl)(nil)
|
||||
|
||||
// Router implements the STF router for msg and query handlers.
|
||||
type Router struct {
|
||||
// coreRouterImpl implements the STF router for msg and query handlers.
|
||||
type coreRouterImpl struct {
|
||||
handlers map[string]appmodulev2.Handler
|
||||
}
|
||||
|
||||
func (r Router) CanInvoke(_ context.Context, typeURL string) error {
|
||||
func (r coreRouterImpl) CanInvoke(_ context.Context, typeURL string) error {
|
||||
// trimming prefixes is a backwards compatibility strategy that we use
|
||||
// for baseapp components that did routing through type URL rather
|
||||
// than protobuf message names.
|
||||
typeURL = strings.TrimPrefix(typeURL, "/")
|
||||
_, exists := r.handlers[typeURL]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s", ErrNoHandler, typeURL)
|
||||
@ -154,20 +159,15 @@ func (r Router) CanInvoke(_ context.Context, typeURL string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Router) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
|
||||
func (r coreRouterImpl) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
|
||||
handlerResp, err := r.InvokeUntyped(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
merge(handlerResp, resp)
|
||||
return nil
|
||||
return merge(handlerResp, resp)
|
||||
}
|
||||
|
||||
func merge(src, dst gogoproto.Message) {
|
||||
reflect.Indirect(reflect.ValueOf(dst)).Set(reflect.Indirect(reflect.ValueOf(src)))
|
||||
}
|
||||
|
||||
func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
|
||||
func (r coreRouterImpl) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
|
||||
typeName := msgTypeURL(req)
|
||||
handler, exists := r.handlers[typeName]
|
||||
if !exists {
|
||||
@ -175,3 +175,38 @@ func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res g
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// merge merges together two protobuf messages by setting the pointer
|
||||
// to src in dst. Used internally.
|
||||
func merge(src, dst gogoproto.Message) error {
|
||||
if src == nil {
|
||||
return fmt.Errorf("source message is nil")
|
||||
}
|
||||
if dst == nil {
|
||||
return fmt.Errorf("destination message is nil")
|
||||
}
|
||||
|
||||
srcVal := reflect.ValueOf(src)
|
||||
dstVal := reflect.ValueOf(dst)
|
||||
|
||||
if srcVal.Kind() == reflect.Interface {
|
||||
srcVal = srcVal.Elem()
|
||||
}
|
||||
if dstVal.Kind() == reflect.Interface {
|
||||
dstVal = dstVal.Elem()
|
||||
}
|
||||
|
||||
if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("both source and destination must be pointers")
|
||||
}
|
||||
|
||||
srcElem := srcVal.Elem()
|
||||
dstElem := dstVal.Elem()
|
||||
|
||||
if !srcElem.Type().AssignableTo(dstElem.Type()) {
|
||||
return fmt.Errorf("incompatible types: cannot merge %v into %v", srcElem.Type(), dstElem.Type())
|
||||
}
|
||||
|
||||
dstElem.Set(srcElem)
|
||||
return nil
|
||||
}
|
||||
|
||||
107
server/v2/stf/stf_router_test.go
Normal file
107
server/v2/stf/stf_router_test.go
Normal file
@ -0,0 +1,107 @@
|
||||
package stf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
gogoproto "github.com/cosmos/gogoproto/proto"
|
||||
gogotypes "github.com/cosmos/gogoproto/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cosmossdk.io/core/appmodule/v2"
|
||||
)
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
expectedMsg := &gogotypes.BoolValue{Value: true}
|
||||
expectedMsgName := gogoproto.MessageName(expectedMsg)
|
||||
|
||||
expectedResp := &gogotypes.StringValue{Value: "test"}
|
||||
|
||||
router := coreRouterImpl{handlers: map[string]appmodule.Handler{
|
||||
gogoproto.MessageName(expectedMsg): func(ctx context.Context, gotMsg appmodule.Message) (msgResp appmodule.Message, err error) {
|
||||
require.Equal(t, expectedMsg, gotMsg)
|
||||
return expectedResp, nil
|
||||
},
|
||||
}}
|
||||
|
||||
t.Run("can invoke message by name", func(t *testing.T) {
|
||||
err := router.CanInvoke(context.Background(), expectedMsgName)
|
||||
require.NoError(t, err, "must be invokable")
|
||||
})
|
||||
|
||||
t.Run("can invoke message by type URL", func(t *testing.T) {
|
||||
err := router.CanInvoke(context.Background(), "/"+expectedMsgName)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("cannot invoke unknown message", func(t *testing.T) {
|
||||
err := router.CanInvoke(context.Background(), "not exist")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invoke untyped", func(t *testing.T) {
|
||||
gotResp, err := router.InvokeUntyped(context.Background(), expectedMsg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedResp, gotResp)
|
||||
})
|
||||
|
||||
t.Run("invoked typed", func(t *testing.T) {
|
||||
gotResp := new(gogotypes.StringValue)
|
||||
err := router.InvokeTyped(context.Background(), expectedMsg, gotResp)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedResp, gotResp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src gogoproto.Message
|
||||
dst gogoproto.Message
|
||||
expected gogoproto.Message
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
src: &gogotypes.BoolValue{Value: true},
|
||||
dst: &gogotypes.BoolValue{},
|
||||
expected: &gogotypes.BoolValue{Value: true},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil src",
|
||||
src: nil,
|
||||
dst: &gogotypes.StringValue{},
|
||||
expected: &gogotypes.StringValue{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil dst",
|
||||
src: &gogotypes.StringValue{Value: "hello"},
|
||||
dst: nil,
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "incompatible types",
|
||||
src: &gogotypes.StringValue{Value: "hello"},
|
||||
dst: &gogotypes.BoolValue{},
|
||||
expected: &gogotypes.BoolValue{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := merge(tt.src, tt.dst)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, tt.dst)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user