fix(server/v2/stf): include safety checks to the execution context (#21359)

This commit is contained in:
Randy Grok 2024-08-23 12:22:17 +02:00 committed by GitHub
parent 0aa9eeb533
commit 8ddea56bb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 153 additions and 33 deletions

View File

@ -14,7 +14,12 @@ var _ branch.Service = (*BranchService)(nil)
type BranchService struct{}
func (bs BranchService) Execute(ctx context.Context, f func(ctx context.Context) error) error {
return bs.execute(ctx.(*executionContext), f)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}
return bs.execute(exCtx, f)
}
func (bs BranchService) ExecuteWithGasLimit(
@ -22,18 +27,21 @@ func (bs BranchService) ExecuteWithGasLimit(
gasLimit uint64,
f func(ctx context.Context) error,
) (gasUsed uint64, err error) {
stfCtx := ctx.(*executionContext)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return 0, err
}
originalGasMeter := stfCtx.meter
originalGasMeter := exCtx.meter
stfCtx.setGasLimit(gasLimit)
exCtx.setGasLimit(gasLimit)
// execute branched, with predefined gas limit.
err = bs.execute(stfCtx, f)
err = bs.execute(exCtx, f)
// restore original context
gasUsed = stfCtx.meter.Limit() - stfCtx.meter.Remaining()
gasUsed = exCtx.meter.Limit() - exCtx.meter.Remaining()
_ = originalGasMeter.Consume(gasUsed, "execute-with-gas-limit")
stfCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining())
exCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining())
return gasUsed, err
}

View File

@ -11,7 +11,7 @@ import (
gogoproto "github.com/cosmos/gogoproto/proto"
"cosmossdk.io/core/event"
transaction "cosmossdk.io/core/transaction"
"cosmossdk.io/core/transaction"
)
func NewEventService() event.Service {
@ -22,7 +22,12 @@ type eventService struct{}
// EventManager implements event.Service.
func (eventService) EventManager(ctx context.Context) event.Manager {
return &eventManager{ctx.(*executionContext)}
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}
return &eventManager{exCtx}
}
var _ event.Manager = (*eventManager)(nil)

View File

@ -30,5 +30,10 @@ func (g gasService) GasConfig(ctx context.Context) gas.GasConfig {
}
func (g gasService) GasMeter(ctx context.Context) gas.Meter {
return ctx.(*executionContext).meter
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}
return exCtx.meter
}

View File

@ -12,7 +12,12 @@ var _ header.Service = (*HeaderService)(nil)
type HeaderService struct{}
func (h HeaderService) HeaderInfo(ctx context.Context) header.Info {
return ctx.(*executionContext).headerInfo
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}
return exCtx.headerInfo
}
const headerInfoPrefix = 0x37

View File

@ -23,19 +23,34 @@ type msgRouterService struct {
// CanInvoke returns an error if the given message cannot be invoked.
func (m msgRouterService) CanInvoke(ctx context.Context, typeURL string) error {
return ctx.(*executionContext).msgRouter.CanInvoke(ctx, typeURL)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}
return exCtx.msgRouter.CanInvoke(ctx, typeURL)
}
// InvokeTyped execute a message and fill-in a response.
// The response must be known and passed as a parameter.
// Use InvokeUntyped if the response type is not known.
func (m msgRouterService) InvokeTyped(ctx context.Context, msg, resp transaction.Msg) error {
return ctx.(*executionContext).msgRouter.InvokeTyped(ctx, msg, resp)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}
return exCtx.msgRouter.InvokeTyped(ctx, msg, resp)
}
// InvokeUntyped execute a message and returns a response.
func (m msgRouterService) InvokeUntyped(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) {
return ctx.(*executionContext).msgRouter.InvokeUntyped(ctx, msg)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return nil, err
}
return exCtx.msgRouter.InvokeUntyped(ctx, msg)
}
// NewQueryRouterService implements router.Service.
@ -49,7 +64,12 @@ type queryRouterService struct{}
// CanInvoke returns an error if the given request cannot be invoked.
func (m queryRouterService) CanInvoke(ctx context.Context, typeURL string) error {
return ctx.(*executionContext).queryRouter.CanInvoke(ctx, typeURL)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}
return exCtx.queryRouter.CanInvoke(ctx, typeURL)
}
// InvokeTyped execute a message and fill-in a response.
@ -59,7 +79,12 @@ func (m queryRouterService) InvokeTyped(
ctx context.Context,
req, resp transaction.Msg,
) error {
return ctx.(*executionContext).queryRouter.InvokeTyped(ctx, req, resp)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}
return exCtx.queryRouter.InvokeTyped(ctx, req, resp)
}
// InvokeUntyped execute a message and returns a response.
@ -67,5 +92,10 @@ func (m queryRouterService) InvokeUntyped(
ctx context.Context,
req transaction.Msg,
) (transaction.Msg, error) {
return ctx.(*executionContext).queryRouter.InvokeUntyped(ctx, req)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return nil, err
}
return exCtx.queryRouter.InvokeUntyped(ctx, req)
}

View File

@ -21,7 +21,12 @@ type storeService struct {
}
func (s storeService) OpenKVStore(ctx context.Context) store.KVStore {
state, err := ctx.(*executionContext).state.GetWriter(s.actor)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}
state, err := exCtx.state.GetWriter(s.actor)
if err != nil {
panic(err)
}

View File

@ -1,13 +0,0 @@
package stf
import (
"context"
)
func GetExecutionContext(ctx context.Context) *executionContext {
executionCtx, ok := ctx.(*executionContext)
if !ok {
return nil
}
return executionCtx
}

View File

@ -22,6 +22,10 @@ import (
// Identity defines STF's bytes identity and it's used by STF to store things in its own state.
var Identity = []byte("stf")
type eContextKey struct{}
var executionContextKey = eContextKey{}
// STF is a struct that manages the state transition component of the app.
type STF[T transaction.Tx] struct {
logger log.Logger
@ -529,6 +533,14 @@ func (e *executionContext) setGasLimit(limit uint64) {
e.state = meteredState
}
func (e *executionContext) Value(key any) any {
if key == executionContextKey {
return e
}
return e.Context.Value(key)
}
// TODO: too many calls to makeContext can be expensive
// makeContext creates and returns a new execution context for the STF[T] type.
// It takes in the following parameters:

View File

@ -11,7 +11,7 @@ import (
appmodulev2 "cosmossdk.io/core/appmodule/v2"
"cosmossdk.io/core/router"
transaction "cosmossdk.io/core/transaction"
"cosmossdk.io/core/transaction"
)
var ErrNoHandler = errors.New("no handler")

View File

@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
"cosmossdk.io/core/appmodule/v2"
transaction "cosmossdk.io/core/transaction"
"cosmossdk.io/core/transaction"
)
func TestRouter(t *testing.T) {

20
server/v2/stf/util.go Normal file
View File

@ -0,0 +1,20 @@
package stf
import (
"context"
"fmt"
)
// getExecutionCtxFromContext tries to get the execution context from the given go context.
func getExecutionCtxFromContext(ctx context.Context) (*executionContext, error) {
if ec, ok := ctx.(*executionContext); ok {
return ec, nil
}
value, ok := ctx.Value(executionContextKey).(*executionContext)
if ok {
return value, nil
}
return nil, fmt.Errorf("failed to get executionContext from context")
}

View File

@ -0,0 +1,43 @@
package stf
import (
"context"
"testing"
)
func TestGetExecutionCtxFromContext(t *testing.T) {
t.Run("direct type *executionContext", func(t *testing.T) {
ec := &executionContext{}
result, err := getExecutionCtxFromContext(ec)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if result != ec {
t.Fatalf("expected %v, got %v", ec, result)
}
})
t.Run("context value of type *executionContext", func(t *testing.T) {
ec := &executionContext{}
ctx := context.WithValue(context.Background(), executionContextKey, ec)
result, err := getExecutionCtxFromContext(ctx)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if result != ec {
t.Fatalf("expected %v, got %v", ec, result)
}
})
t.Run("invalid context type or value", func(t *testing.T) {
ctx := context.Background()
_, err := getExecutionCtxFromContext(ctx)
if err == nil {
t.Fatalf("expected error, got nil")
}
expectedErr := "failed to get executionContext from context"
if err.Error() != expectedErr {
t.Fatalf("expected error message %v, got %v", expectedErr, err.Error())
}
})
}