fix(server/v2/stf): include safety checks to the execution context (#21359)
This commit is contained in:
parent
0aa9eeb533
commit
8ddea56bb2
@ -14,7 +14,12 @@ var _ branch.Service = (*BranchService)(nil)
|
||||
type BranchService struct{}
|
||||
|
||||
func (bs BranchService) Execute(ctx context.Context, f func(ctx context.Context) error) error {
|
||||
return bs.execute(ctx.(*executionContext), f)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bs.execute(exCtx, f)
|
||||
}
|
||||
|
||||
func (bs BranchService) ExecuteWithGasLimit(
|
||||
@ -22,18 +27,21 @@ func (bs BranchService) ExecuteWithGasLimit(
|
||||
gasLimit uint64,
|
||||
f func(ctx context.Context) error,
|
||||
) (gasUsed uint64, err error) {
|
||||
stfCtx := ctx.(*executionContext)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
originalGasMeter := stfCtx.meter
|
||||
originalGasMeter := exCtx.meter
|
||||
|
||||
stfCtx.setGasLimit(gasLimit)
|
||||
exCtx.setGasLimit(gasLimit)
|
||||
|
||||
// execute branched, with predefined gas limit.
|
||||
err = bs.execute(stfCtx, f)
|
||||
err = bs.execute(exCtx, f)
|
||||
// restore original context
|
||||
gasUsed = stfCtx.meter.Limit() - stfCtx.meter.Remaining()
|
||||
gasUsed = exCtx.meter.Limit() - exCtx.meter.Remaining()
|
||||
_ = originalGasMeter.Consume(gasUsed, "execute-with-gas-limit")
|
||||
stfCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining())
|
||||
exCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining())
|
||||
|
||||
return gasUsed, err
|
||||
}
|
||||
|
||||
@ -11,7 +11,7 @@ import (
|
||||
gogoproto "github.com/cosmos/gogoproto/proto"
|
||||
|
||||
"cosmossdk.io/core/event"
|
||||
transaction "cosmossdk.io/core/transaction"
|
||||
"cosmossdk.io/core/transaction"
|
||||
)
|
||||
|
||||
func NewEventService() event.Service {
|
||||
@ -22,7 +22,12 @@ type eventService struct{}
|
||||
|
||||
// EventManager implements event.Service.
|
||||
func (eventService) EventManager(ctx context.Context) event.Manager {
|
||||
return &eventManager{ctx.(*executionContext)}
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &eventManager{exCtx}
|
||||
}
|
||||
|
||||
var _ event.Manager = (*eventManager)(nil)
|
||||
|
||||
@ -30,5 +30,10 @@ func (g gasService) GasConfig(ctx context.Context) gas.GasConfig {
|
||||
}
|
||||
|
||||
func (g gasService) GasMeter(ctx context.Context) gas.Meter {
|
||||
return ctx.(*executionContext).meter
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return exCtx.meter
|
||||
}
|
||||
|
||||
@ -12,7 +12,12 @@ var _ header.Service = (*HeaderService)(nil)
|
||||
type HeaderService struct{}
|
||||
|
||||
func (h HeaderService) HeaderInfo(ctx context.Context) header.Info {
|
||||
return ctx.(*executionContext).headerInfo
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return exCtx.headerInfo
|
||||
}
|
||||
|
||||
const headerInfoPrefix = 0x37
|
||||
|
||||
@ -23,19 +23,34 @@ type msgRouterService struct {
|
||||
|
||||
// CanInvoke returns an error if the given message cannot be invoked.
|
||||
func (m msgRouterService) CanInvoke(ctx context.Context, typeURL string) error {
|
||||
return ctx.(*executionContext).msgRouter.CanInvoke(ctx, typeURL)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return exCtx.msgRouter.CanInvoke(ctx, typeURL)
|
||||
}
|
||||
|
||||
// InvokeTyped execute a message and fill-in a response.
|
||||
// The response must be known and passed as a parameter.
|
||||
// Use InvokeUntyped if the response type is not known.
|
||||
func (m msgRouterService) InvokeTyped(ctx context.Context, msg, resp transaction.Msg) error {
|
||||
return ctx.(*executionContext).msgRouter.InvokeTyped(ctx, msg, resp)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return exCtx.msgRouter.InvokeTyped(ctx, msg, resp)
|
||||
}
|
||||
|
||||
// InvokeUntyped execute a message and returns a response.
|
||||
func (m msgRouterService) InvokeUntyped(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) {
|
||||
return ctx.(*executionContext).msgRouter.InvokeUntyped(ctx, msg)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return exCtx.msgRouter.InvokeUntyped(ctx, msg)
|
||||
}
|
||||
|
||||
// NewQueryRouterService implements router.Service.
|
||||
@ -49,7 +64,12 @@ type queryRouterService struct{}
|
||||
|
||||
// CanInvoke returns an error if the given request cannot be invoked.
|
||||
func (m queryRouterService) CanInvoke(ctx context.Context, typeURL string) error {
|
||||
return ctx.(*executionContext).queryRouter.CanInvoke(ctx, typeURL)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return exCtx.queryRouter.CanInvoke(ctx, typeURL)
|
||||
}
|
||||
|
||||
// InvokeTyped execute a message and fill-in a response.
|
||||
@ -59,7 +79,12 @@ func (m queryRouterService) InvokeTyped(
|
||||
ctx context.Context,
|
||||
req, resp transaction.Msg,
|
||||
) error {
|
||||
return ctx.(*executionContext).queryRouter.InvokeTyped(ctx, req, resp)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return exCtx.queryRouter.InvokeTyped(ctx, req, resp)
|
||||
}
|
||||
|
||||
// InvokeUntyped execute a message and returns a response.
|
||||
@ -67,5 +92,10 @@ func (m queryRouterService) InvokeUntyped(
|
||||
ctx context.Context,
|
||||
req transaction.Msg,
|
||||
) (transaction.Msg, error) {
|
||||
return ctx.(*executionContext).queryRouter.InvokeUntyped(ctx, req)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return exCtx.queryRouter.InvokeUntyped(ctx, req)
|
||||
}
|
||||
|
||||
@ -21,7 +21,12 @@ type storeService struct {
|
||||
}
|
||||
|
||||
func (s storeService) OpenKVStore(ctx context.Context) store.KVStore {
|
||||
state, err := ctx.(*executionContext).state.GetWriter(s.actor)
|
||||
exCtx, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
state, err := exCtx.state.GetWriter(s.actor)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
package stf
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func GetExecutionContext(ctx context.Context) *executionContext {
|
||||
executionCtx, ok := ctx.(*executionContext)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return executionCtx
|
||||
}
|
||||
@ -22,6 +22,10 @@ import (
|
||||
// Identity defines STF's bytes identity and it's used by STF to store things in its own state.
|
||||
var Identity = []byte("stf")
|
||||
|
||||
type eContextKey struct{}
|
||||
|
||||
var executionContextKey = eContextKey{}
|
||||
|
||||
// STF is a struct that manages the state transition component of the app.
|
||||
type STF[T transaction.Tx] struct {
|
||||
logger log.Logger
|
||||
@ -529,6 +533,14 @@ func (e *executionContext) setGasLimit(limit uint64) {
|
||||
e.state = meteredState
|
||||
}
|
||||
|
||||
func (e *executionContext) Value(key any) any {
|
||||
if key == executionContextKey {
|
||||
return e
|
||||
}
|
||||
|
||||
return e.Context.Value(key)
|
||||
}
|
||||
|
||||
// TODO: too many calls to makeContext can be expensive
|
||||
// makeContext creates and returns a new execution context for the STF[T] type.
|
||||
// It takes in the following parameters:
|
||||
|
||||
@ -11,7 +11,7 @@ import (
|
||||
|
||||
appmodulev2 "cosmossdk.io/core/appmodule/v2"
|
||||
"cosmossdk.io/core/router"
|
||||
transaction "cosmossdk.io/core/transaction"
|
||||
"cosmossdk.io/core/transaction"
|
||||
)
|
||||
|
||||
var ErrNoHandler = errors.New("no handler")
|
||||
|
||||
@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cosmossdk.io/core/appmodule/v2"
|
||||
transaction "cosmossdk.io/core/transaction"
|
||||
"cosmossdk.io/core/transaction"
|
||||
)
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
|
||||
20
server/v2/stf/util.go
Normal file
20
server/v2/stf/util.go
Normal file
@ -0,0 +1,20 @@
|
||||
package stf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// getExecutionCtxFromContext tries to get the execution context from the given go context.
|
||||
func getExecutionCtxFromContext(ctx context.Context) (*executionContext, error) {
|
||||
if ec, ok := ctx.(*executionContext); ok {
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
value, ok := ctx.Value(executionContextKey).(*executionContext)
|
||||
if ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get executionContext from context")
|
||||
}
|
||||
43
server/v2/stf/util_test.go
Normal file
43
server/v2/stf/util_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package stf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetExecutionCtxFromContext(t *testing.T) {
|
||||
t.Run("direct type *executionContext", func(t *testing.T) {
|
||||
ec := &executionContext{}
|
||||
result, err := getExecutionCtxFromContext(ec)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if result != ec {
|
||||
t.Fatalf("expected %v, got %v", ec, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context value of type *executionContext", func(t *testing.T) {
|
||||
ec := &executionContext{}
|
||||
ctx := context.WithValue(context.Background(), executionContextKey, ec)
|
||||
result, err := getExecutionCtxFromContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if result != ec {
|
||||
t.Fatalf("expected %v, got %v", ec, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid context type or value", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := getExecutionCtxFromContext(ctx)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
expectedErr := "failed to get executionContext from context"
|
||||
if err.Error() != expectedErr {
|
||||
t.Fatalf("expected error message %v, got %v", expectedErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user