171 lines
4.7 KiB
Go
171 lines
4.7 KiB
Go
package stf
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
gogoproto "github.com/cosmos/gogoproto/proto"
|
|
|
|
appmodulev2 "cosmossdk.io/core/appmodule/v2"
|
|
"cosmossdk.io/core/router"
|
|
"cosmossdk.io/core/transaction"
|
|
)
|
|
|
|
// ErrNoHandler is the sentinel error wrapped (via %w) by CanInvoke and Invoke
// when no handler is registered for the requested message. Callers can match
// it with errors.Is.
var ErrNoHandler = errors.New("no handler")
|
|
|
|
// NewMsgRouterBuilder is a router that routes messages to their respective handlers.
|
|
func NewMsgRouterBuilder() *MsgRouterBuilder {
|
|
return &MsgRouterBuilder{
|
|
handlers: make(map[string]appmodulev2.HandlerFunc),
|
|
preHandlers: make(map[string][]appmodulev2.PreMsgHandler),
|
|
postHandlers: make(map[string][]appmodulev2.PostMsgHandler),
|
|
}
|
|
}
|
|
|
|
type MsgRouterBuilder struct {
|
|
handlers map[string]appmodulev2.HandlerFunc
|
|
globalPreHandlers []appmodulev2.PreMsgHandler
|
|
preHandlers map[string][]appmodulev2.PreMsgHandler
|
|
postHandlers map[string][]appmodulev2.PostMsgHandler
|
|
globalPostHandlers []appmodulev2.PostMsgHandler
|
|
}
|
|
|
|
func (b *MsgRouterBuilder) RegisterHandler(msgType string, handler appmodulev2.HandlerFunc) error {
|
|
// panic on override
|
|
if _, ok := b.handlers[msgType]; ok {
|
|
return fmt.Errorf("handler already registered: %s", msgType)
|
|
}
|
|
b.handlers[msgType] = handler
|
|
return nil
|
|
}
|
|
|
|
func (b *MsgRouterBuilder) RegisterGlobalPreMsgHandler(handler appmodulev2.PreMsgHandler) {
|
|
b.globalPreHandlers = append(b.globalPreHandlers, handler)
|
|
}
|
|
|
|
func (b *MsgRouterBuilder) RegisterPreMsgHandler(msgType string, handler appmodulev2.PreMsgHandler) {
|
|
b.preHandlers[msgType] = append(b.preHandlers[msgType], handler)
|
|
}
|
|
|
|
func (b *MsgRouterBuilder) RegisterPostMsgHandler(msgType string, handler appmodulev2.PostMsgHandler) {
|
|
b.postHandlers[msgType] = append(b.postHandlers[msgType], handler)
|
|
}
|
|
|
|
func (b *MsgRouterBuilder) RegisterGlobalPostMsgHandler(handler appmodulev2.PostMsgHandler) {
|
|
b.globalPostHandlers = append(b.globalPostHandlers, handler)
|
|
}
|
|
|
|
func (b *MsgRouterBuilder) HandlerExists(msgType string) bool {
|
|
_, ok := b.handlers[msgType]
|
|
return ok
|
|
}
|
|
|
|
func (b *MsgRouterBuilder) build() (coreRouterImpl, error) {
|
|
handlers := make(map[string]appmodulev2.HandlerFunc)
|
|
|
|
globalPreHandler := func(ctx context.Context, msg transaction.Msg) error {
|
|
for _, h := range b.globalPreHandlers {
|
|
err := h(ctx, msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
globalPostHandler := func(ctx context.Context, msg, msgResp transaction.Msg) error {
|
|
for _, h := range b.globalPostHandlers {
|
|
err := h(ctx, msg, msgResp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
for msgType, handler := range b.handlers {
|
|
// find pre handler
|
|
preHandlers := b.preHandlers[msgType]
|
|
// find post handler
|
|
postHandlers := b.postHandlers[msgType]
|
|
// build the handler
|
|
handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler)
|
|
}
|
|
|
|
return coreRouterImpl{
|
|
handlers: handlers,
|
|
}, nil
|
|
}
|
|
|
|
func buildHandler(
|
|
handler appmodulev2.HandlerFunc,
|
|
preHandlers []appmodulev2.PreMsgHandler,
|
|
globalPreHandler appmodulev2.PreMsgHandler,
|
|
postHandlers []appmodulev2.PostMsgHandler,
|
|
globalPostHandler appmodulev2.PostMsgHandler,
|
|
) appmodulev2.HandlerFunc {
|
|
return func(ctx context.Context, msg transaction.Msg) (msgResp transaction.Msg, err error) {
|
|
if len(preHandlers) != 0 {
|
|
for _, preHandler := range preHandlers {
|
|
if err := preHandler(ctx, msg); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
err = globalPreHandler(ctx, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
msgResp, err = handler(ctx, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(postHandlers) != 0 {
|
|
for _, postHandler := range postHandlers {
|
|
if err := postHandler(ctx, msg, msgResp); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
err = globalPostHandler(ctx, msg, msgResp)
|
|
return msgResp, err
|
|
}
|
|
}
|
|
|
|
// msgTypeURL returns the TypeURL of a proto message.
|
|
func msgTypeURL(msg gogoproto.Message) string {
|
|
return gogoproto.MessageName(msg)
|
|
}
|
|
|
|
// Compile-time assertion that *coreRouterImpl satisfies router.Service.
var _ router.Service = (*coreRouterImpl)(nil)
|
|
|
|
// coreRouterImpl implements the STF router for msg and query handlers.
|
|
type coreRouterImpl struct {
|
|
handlers map[string]appmodulev2.HandlerFunc
|
|
}
|
|
|
|
func (r coreRouterImpl) CanInvoke(_ context.Context, typeURL string) error {
|
|
// trimming prefixes is a backwards compatibility strategy that we use
|
|
// for baseapp components that did routing through type URL rather
|
|
// than protobuf message names.
|
|
typeURL = strings.TrimPrefix(typeURL, "/")
|
|
_, exists := r.handlers[typeURL]
|
|
if !exists {
|
|
return fmt.Errorf("%w: %s", ErrNoHandler, typeURL)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r coreRouterImpl) Invoke(ctx context.Context, req transaction.Msg) (res transaction.Msg, err error) {
|
|
typeName := msgTypeURL(req)
|
|
handler, exists := r.handlers[typeName]
|
|
if !exists {
|
|
return nil, fmt.Errorf("%w: %s", ErrNoHandler, typeName)
|
|
}
|
|
|
|
return handler(ctx, req)
|
|
}
|