cosmos-sdk/server/v2/stf/stf_router.go
testinginprod 90f362d782
refactor(core/handlers): be more devx friendly (#21984)
Co-authored-by: Julien Robert <julien@rbrt.fr>
2024-10-02 05:38:22 +00:00

171 lines
4.7 KiB
Go

package stf
import (
"context"
"errors"
"fmt"
"strings"
gogoproto "github.com/cosmos/gogoproto/proto"
appmodulev2 "cosmossdk.io/core/appmodule/v2"
"cosmossdk.io/core/router"
"cosmossdk.io/core/transaction"
)
// ErrNoHandler is the sentinel error wrapped by the router when no handler is
// registered for the requested message name. Match it with errors.Is.
var ErrNoHandler = errors.New("no handler")
// NewMsgRouterBuilder returns an empty MsgRouterBuilder whose handler,
// pre-handler and post-handler maps are initialized and ready for
// registration. The global pre/post handler slices start out nil.
func NewMsgRouterBuilder() *MsgRouterBuilder {
	builder := new(MsgRouterBuilder)
	builder.handlers = map[string]appmodulev2.HandlerFunc{}
	builder.preHandlers = map[string][]appmodulev2.PreMsgHandler{}
	builder.postHandlers = map[string][]appmodulev2.PostMsgHandler{}
	return builder
}
// MsgRouterBuilder collects message handlers together with their pre and post
// hooks, and assembles them into a router via build.
type MsgRouterBuilder struct {
	// handlers holds exactly one handler per message type.
	handlers map[string]appmodulev2.HandlerFunc
	// globalPreHandlers run before the handler of every message type.
	globalPreHandlers []appmodulev2.PreMsgHandler
	// preHandlers run before the handler of the message type they are keyed by.
	preHandlers map[string][]appmodulev2.PreMsgHandler
	// postHandlers run after the handler of the message type they are keyed by.
	postHandlers map[string][]appmodulev2.PostMsgHandler
	// globalPostHandlers run after the handler of every message type.
	globalPostHandlers []appmodulev2.PostMsgHandler
}
// RegisterHandler binds handler to msgType. Only one handler may exist per
// message type: registering a second one for the same msgType returns an error
// and leaves the previously registered handler in place.
func (b *MsgRouterBuilder) RegisterHandler(msgType string, handler appmodulev2.HandlerFunc) error {
	_, taken := b.handlers[msgType]
	if taken {
		// Overrides are reported to the caller as an error, not a panic.
		return fmt.Errorf("handler already registered: %s", msgType)
	}

	b.handlers[msgType] = handler
	return nil
}
// RegisterGlobalPreMsgHandler appends handler to the list of pre-handlers that
// are run for every message type, preserving registration order.
func (b *MsgRouterBuilder) RegisterGlobalPreMsgHandler(handler appmodulev2.PreMsgHandler) {
	registered := b.globalPreHandlers
	b.globalPreHandlers = append(registered, handler)
}
// RegisterPreMsgHandler appends handler to the pre-handlers of msgType,
// preserving registration order. Several pre-handlers may be registered for
// the same message type.
func (b *MsgRouterBuilder) RegisterPreMsgHandler(msgType string, handler appmodulev2.PreMsgHandler) {
	registered := b.preHandlers[msgType]
	b.preHandlers[msgType] = append(registered, handler)
}
// RegisterPostMsgHandler appends handler to the post-handlers of msgType,
// preserving registration order. Several post-handlers may be registered for
// the same message type.
func (b *MsgRouterBuilder) RegisterPostMsgHandler(msgType string, handler appmodulev2.PostMsgHandler) {
	registered := b.postHandlers[msgType]
	b.postHandlers[msgType] = append(registered, handler)
}
// RegisterGlobalPostMsgHandler appends handler to the list of post-handlers
// that are run for every message type, preserving registration order.
func (b *MsgRouterBuilder) RegisterGlobalPostMsgHandler(handler appmodulev2.PostMsgHandler) {
	registered := b.globalPostHandlers
	b.globalPostHandlers = append(registered, handler)
}
// HandlerExists reports whether a handler has been registered for msgType.
// Pre- and post-handlers are not taken into account.
func (b *MsgRouterBuilder) HandlerExists(msgType string) bool {
	_, registered := b.handlers[msgType]
	return registered
}
// build assembles a coreRouterImpl from everything registered on the builder.
//
// Each registered handler is wrapped by buildHandler together with the pre-
// and post-handlers registered for its message type and the global pre/post
// hooks. The per-message hook slices are looked up here, at build time; the
// global hook lists are read from the builder every time a wrapped handler
// runs. The returned error is currently always nil.
func (b *MsgRouterBuilder) build() (coreRouterImpl, error) {
	// runGlobalPre runs every global pre-handler in registration order and
	// stops at the first failure.
	runGlobalPre := func(ctx context.Context, msg transaction.Msg) error {
		for _, pre := range b.globalPreHandlers {
			if err := pre(ctx, msg); err != nil {
				return err
			}
		}
		return nil
	}

	// runGlobalPost runs every global post-handler in registration order and
	// stops at the first failure.
	runGlobalPost := func(ctx context.Context, msg, msgResp transaction.Msg) error {
		for _, post := range b.globalPostHandlers {
			if err := post(ctx, msg, msgResp); err != nil {
				return err
			}
		}
		return nil
	}

	wrapped := make(map[string]appmodulev2.HandlerFunc)
	for name, h := range b.handlers {
		wrapped[name] = buildHandler(
			h,
			b.preHandlers[name],
			runGlobalPre,
			b.postHandlers[name],
			runGlobalPost,
		)
	}

	return coreRouterImpl{handlers: wrapped}, nil
}
// buildHandler composes handler with its hooks into a single
// appmodulev2.HandlerFunc. The returned function runs, in order:
//
//  1. preHandlers, in slice order
//  2. globalPreHandler
//  3. handler
//  4. postHandlers, in slice order
//  5. globalPostHandler
//
// The first step that fails aborts the chain. On any failure the returned
// response is nil and the error is passed through unchanged; the handler's
// response is returned only when every step succeeds.
func buildHandler(
	handler appmodulev2.HandlerFunc,
	preHandlers []appmodulev2.PreMsgHandler,
	globalPreHandler appmodulev2.PreMsgHandler,
	postHandlers []appmodulev2.PostMsgHandler,
	globalPostHandler appmodulev2.PostMsgHandler,
) appmodulev2.HandlerFunc {
	return func(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) {
		// Ranging over a nil or empty slice is a no-op, so no length guard is
		// needed.
		for _, preHandler := range preHandlers {
			if err := preHandler(ctx, msg); err != nil {
				return nil, err
			}
		}
		if err := globalPreHandler(ctx, msg); err != nil {
			return nil, err
		}

		msgResp, err := handler(ctx, msg)
		if err != nil {
			return nil, err
		}

		for _, postHandler := range postHandlers {
			if err := postHandler(ctx, msg, msgResp); err != nil {
				return nil, err
			}
		}
		// A failing global post-handler discards the response, exactly like a
		// failing per-message post-handler does.
		if err := globalPostHandler(ctx, msg, msgResp); err != nil {
			return nil, err
		}

		return msgResp, nil
	}
}
// msgTypeURL returns the name gogoproto reports for msg via
// gogoproto.MessageName. Despite the function's name, the result is the
// protobuf message name used as the routing key, not a "/"-prefixed type URL.
func msgTypeURL(msg gogoproto.Message) string {
	name := gogoproto.MessageName(msg)
	return name
}
// Compile-time assertion that *coreRouterImpl satisfies router.Service.
var _ router.Service = (*coreRouterImpl)(nil)
// coreRouterImpl implements the STF router for msg and query handlers.
// Requests are dispatched by looking up their message name in handlers.
type coreRouterImpl struct {
	// handlers maps a message name to the fully wrapped handler (pre-handlers,
	// handler, post-handlers) produced by MsgRouterBuilder.build.
	handlers map[string]appmodulev2.HandlerFunc
}
// CanInvoke reports whether a handler is registered for typeURL. It returns
// nil when one exists, and otherwise an error wrapping ErrNoHandler that names
// the (trimmed) message name. The context argument is unused.
func (r coreRouterImpl) CanInvoke(_ context.Context, typeURL string) error {
	// Handlers are keyed by protobuf message name. Dropping a leading "/" is a
	// backwards compatibility measure for baseapp components that routed by
	// type URL instead.
	name := strings.TrimPrefix(typeURL, "/")
	if _, ok := r.handlers[name]; ok {
		return nil
	}
	return fmt.Errorf("%w: %s", ErrNoHandler, name)
}
// Invoke routes req to the handler registered under its message name and
// returns that handler's response and error unchanged. When no handler is
// registered it returns a nil response and an error wrapping ErrNoHandler.
func (r coreRouterImpl) Invoke(ctx context.Context, req transaction.Msg) (res transaction.Msg, err error) {
	name := msgTypeURL(req)
	if h, ok := r.handlers[name]; ok {
		return h(ctx, req)
	}
	return nil, fmt.Errorf("%w: %s", ErrNoHandler, name)
}