feat(stf/router): support backwards compat type URL in router (#21177)

This commit is contained in:
testinginprod 2024-08-07 15:17:42 +02:00 committed by GitHub
parent 90fd6320a6
commit 4dc9469320
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 160 additions and 18 deletions

View File

@ -26,8 +26,8 @@ var Identity = []byte("stf")
type STF[T transaction.Tx] struct {
logger log.Logger
msgRouter Router
queryRouter Router
msgRouter coreRouterImpl
queryRouter coreRouterImpl
doPreBlock func(ctx context.Context, txs []T) error
doBeginBlock func(ctx context.Context) error
@ -584,8 +584,8 @@ func newExecutionContext(
sender transaction.Identity,
state store.WriterMap,
execMode transaction.ExecMode,
msgRouter Router,
queryRouter Router,
msgRouter coreRouterImpl,
queryRouter coreRouterImpl,
) *executionContext {
meter := makeGasMeterFn(gas.NoGasLimit)
meteredState := makeGasMeteredStoreFn(meter, state)

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"strings"
gogoproto "github.com/cosmos/gogoproto/proto"
@ -61,7 +62,7 @@ func (b *MsgRouterBuilder) HandlerExists(msgType string) bool {
return ok
}
func (b *MsgRouterBuilder) Build() (Router, error) {
func (b *MsgRouterBuilder) Build() (coreRouterImpl, error) {
handlers := make(map[string]appmodulev2.Handler)
globalPreHandler := func(ctx context.Context, msg appmodulev2.Message) error {
@ -93,7 +94,7 @@ func (b *MsgRouterBuilder) Build() (Router, error) {
handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler)
}
return Router{
return coreRouterImpl{
handlers: handlers,
}, nil
}
@ -139,14 +140,18 @@ func msgTypeURL(msg gogoproto.Message) string {
return gogoproto.MessageName(msg)
}
var _ router.Service = (*Router)(nil)
var _ router.Service = (*coreRouterImpl)(nil)
// Router implements the STF router for msg and query handlers.
type Router struct {
// coreRouterImpl implements the STF router for msg and query handlers.
type coreRouterImpl struct {
handlers map[string]appmodulev2.Handler
}
func (r Router) CanInvoke(_ context.Context, typeURL string) error {
func (r coreRouterImpl) CanInvoke(_ context.Context, typeURL string) error {
// trimming prefixes is a backwards compatibility strategy that we use
// for baseapp components that did routing through type URL rather
// than protobuf message names.
typeURL = strings.TrimPrefix(typeURL, "/")
_, exists := r.handlers[typeURL]
if !exists {
return fmt.Errorf("%w: %s", ErrNoHandler, typeURL)
@ -154,20 +159,15 @@ func (r Router) CanInvoke(_ context.Context, typeURL string) error {
return nil
}
func (r Router) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
func (r coreRouterImpl) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
handlerResp, err := r.InvokeUntyped(ctx, req)
if err != nil {
return err
}
merge(handlerResp, resp)
return nil
return merge(handlerResp, resp)
}
func merge(src, dst gogoproto.Message) {
reflect.Indirect(reflect.ValueOf(dst)).Set(reflect.Indirect(reflect.ValueOf(src)))
}
func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
func (r coreRouterImpl) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res gogoproto.Message, err error) {
typeName := msgTypeURL(req)
handler, exists := r.handlers[typeName]
if !exists {
@ -175,3 +175,38 @@ func (r Router) InvokeUntyped(ctx context.Context, req gogoproto.Message) (res g
}
return handler(ctx, req)
}
// merge merges together two protobuf messages by setting the pointer
// to src in dst. Used internally.
func merge(src, dst gogoproto.Message) error {
if src == nil {
return fmt.Errorf("source message is nil")
}
if dst == nil {
return fmt.Errorf("destination message is nil")
}
srcVal := reflect.ValueOf(src)
dstVal := reflect.ValueOf(dst)
if srcVal.Kind() == reflect.Interface {
srcVal = srcVal.Elem()
}
if dstVal.Kind() == reflect.Interface {
dstVal = dstVal.Elem()
}
if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr {
return fmt.Errorf("both source and destination must be pointers")
}
srcElem := srcVal.Elem()
dstElem := dstVal.Elem()
if !srcElem.Type().AssignableTo(dstElem.Type()) {
return fmt.Errorf("incompatible types: cannot merge %v into %v", srcElem.Type(), dstElem.Type())
}
dstElem.Set(srcElem)
return nil
}

View File

@ -0,0 +1,107 @@
package stf
import (
"context"
"testing"
gogoproto "github.com/cosmos/gogoproto/proto"
gogotypes "github.com/cosmos/gogoproto/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cosmossdk.io/core/appmodule/v2"
)
func TestRouter(t *testing.T) {
expectedMsg := &gogotypes.BoolValue{Value: true}
expectedMsgName := gogoproto.MessageName(expectedMsg)
expectedResp := &gogotypes.StringValue{Value: "test"}
router := coreRouterImpl{handlers: map[string]appmodule.Handler{
gogoproto.MessageName(expectedMsg): func(ctx context.Context, gotMsg appmodule.Message) (msgResp appmodule.Message, err error) {
require.Equal(t, expectedMsg, gotMsg)
return expectedResp, nil
},
}}
t.Run("can invoke message by name", func(t *testing.T) {
err := router.CanInvoke(context.Background(), expectedMsgName)
require.NoError(t, err, "must be invokable")
})
t.Run("can invoke message by type URL", func(t *testing.T) {
err := router.CanInvoke(context.Background(), "/"+expectedMsgName)
require.NoError(t, err)
})
t.Run("cannot invoke unknown message", func(t *testing.T) {
err := router.CanInvoke(context.Background(), "not exist")
require.Error(t, err)
})
t.Run("invoke untyped", func(t *testing.T) {
gotResp, err := router.InvokeUntyped(context.Background(), expectedMsg)
require.NoError(t, err)
require.Equal(t, expectedResp, gotResp)
})
t.Run("invoked typed", func(t *testing.T) {
gotResp := new(gogotypes.StringValue)
err := router.InvokeTyped(context.Background(), expectedMsg, gotResp)
require.NoError(t, err)
require.Equal(t, expectedResp, gotResp)
})
}
func TestMerge(t *testing.T) {
tests := []struct {
name string
src gogoproto.Message
dst gogoproto.Message
expected gogoproto.Message
wantErr bool
}{
{
name: "success",
src: &gogotypes.BoolValue{Value: true},
dst: &gogotypes.BoolValue{},
expected: &gogotypes.BoolValue{Value: true},
wantErr: false,
},
{
name: "nil src",
src: nil,
dst: &gogotypes.StringValue{},
expected: &gogotypes.StringValue{},
wantErr: true,
},
{
name: "nil dst",
src: &gogotypes.StringValue{Value: "hello"},
dst: nil,
expected: nil,
wantErr: true,
},
{
name: "incompatible types",
src: &gogotypes.StringValue{Value: "hello"},
dst: &gogotypes.BoolValue{},
expected: &gogotypes.BoolValue{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := merge(tt.src, tt.dst)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, tt.dst)
}
})
}
}