From 502661ba936ef4394f9ad91c9010ae72b8860720 Mon Sep 17 00:00:00 2001 From: Ezequiel Raynaudo Date: Wed, 4 Sep 2024 12:13:49 -0300 Subject: [PATCH] refactor(staking): check for nil ptrs after GetCachedValue() (#21300) --- x/staking/keeper/cons_pubkey.go | 16 ++++-- x/staking/keeper/grpc_query.go | 9 +++- x/staking/keeper/msg_server.go | 78 +++++++++++++++++----------- x/staking/keeper/val_state_change.go | 16 ++++-- 4 files changed, 79 insertions(+), 40 deletions(-) diff --git a/x/staking/keeper/cons_pubkey.go b/x/staking/keeper/cons_pubkey.go index 0f91710525..e1938597fa 100644 --- a/x/staking/keeper/cons_pubkey.go +++ b/x/staking/keeper/cons_pubkey.go @@ -94,14 +94,22 @@ func (k Keeper) updateToNewPubkey(ctx context.Context, val types.Validator, oldP return err } - oldPk, ok := oldPubKey.GetCachedValue().(cryptotypes.PubKey) + oldPkCached := oldPubKey.GetCachedValue() + if oldPkCached == nil { + return errorsmod.Wrap(sdkerrors.ErrInvalidType, "OldPubKey cached value is nil") + } + oldPk, ok := oldPkCached.(cryptotypes.PubKey) if !ok { - return errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", oldPk) + return errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", oldPkCached) } - newPk, ok := newPubKey.GetCachedValue().(cryptotypes.PubKey) + newPkCached := newPubKey.GetCachedValue() + if newPkCached == nil { + return errorsmod.Wrap(sdkerrors.ErrInvalidType, "NewPubKey cached value is nil") + } + newPk, ok := newPkCached.(cryptotypes.PubKey) if !ok { - return errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", newPk) + return errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", newPkCached) } // sets a map: oldConsKey -> newConsKey diff --git a/x/staking/keeper/grpc_query.go b/x/staking/keeper/grpc_query.go index 11f5cb4d0c..d6866b8b64 100644 --- a/x/staking/keeper/grpc_query.go +++ b/x/staking/keeper/grpc_query.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc/status" "cosmossdk.io/collections" + errorsmod "cosmossdk.io/errors" "cosmossdk.io/store/prefix" storetypes "cosmossdk.io/store/types" "cosmossdk.io/x/staking/types" @@ -16,6 +17,7 @@ import ( cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" "github.com/cosmos/cosmos-sdk/runtime" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/query" ) @@ -63,7 +65,12 @@ func (k Querier) Validators(ctx context.Context, req *types.QueryValidatorsReque vals.Validators = append(vals.Validators, *val) valInfo := types.ValidatorInfo{} - cpk, ok := val.ConsensusPubkey.GetCachedValue().(cryptotypes.PubKey) + cv := val.ConsensusPubkey.GetCachedValue() + if cv == nil { + return nil, errorsmod.Wrap(sdkerrors.ErrInvalidType, "public key cached value is nil") + } + + cpk, ok := cv.(cryptotypes.PubKey) if ok { consAddr, err := k.consensusAddressCodec.BytesToString(cpk.Address()) if err == nil { diff --git a/x/staking/keeper/msg_server.go b/x/staking/keeper/msg_server.go index cd8fc499c0..06ac78acff 100644 --- a/x/staking/keeper/msg_server.go +++ b/x/staking/keeper/msg_server.go @@ -39,7 +39,8 @@ func NewMsgServerImpl(keeper *Keeper) types.MsgServer { var _ types.MsgServer = msgServer{} -// CreateValidator defines a method for creating a new validator +// CreateValidator defines a method for creating a new validator. +// The validator's params should not be nil for this function to execute successfully. func (k msgServer) CreateValidator(ctx context.Context, msg *types.MsgCreateValidator) (*types.MsgCreateValidatorResponse, error) { valAddr, err := k.validatorAddressCodec.StringToBytes(msg.ValidatorAddress) if err != nil { @@ -64,9 +65,14 @@ func (k msgServer) CreateValidator(ctx context.Context, msg *types.MsgCreateVali return nil, types.ErrValidatorOwnerExists } - pk, ok := msg.Pubkey.GetCachedValue().(cryptotypes.PubKey) + cv := msg.Pubkey.GetCachedValue() + if cv == nil { + return nil, errorsmod.Wrap(sdkerrors.ErrInvalidType, "Pubkey cached value is nil") + } + + pk, ok := cv.(cryptotypes.PubKey) if !ok { - return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", msg.Pubkey.GetCachedValue()) + return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", cv) } resp, err := k.QueryRouterService.Invoke(ctx, &consensusv1.QueryParamsRequest{}) @@ -78,21 +84,12 @@ func (k msgServer) CreateValidator(ctx context.Context, msg *types.MsgCreateVali return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "unexpected response type: %T", resp) } - if res.Params.Validator != nil { - pkType := pk.Type() - if !slices.Contains(res.Params.Validator.PubKeyTypes, pkType) { - return nil, errorsmod.Wrapf( - types.ErrValidatorPubKeyTypeNotSupported, - "got: %s, expected: %s", pk.Type(), res.Params.Validator.PubKeyTypes, - ) - } + if res.Params.Validator == nil { + return nil, errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "validator params are not set") + } - if pkType == sdk.PubKeyEd25519Type && len(pk.Bytes()) != ed25519.PubKeySize { - return nil, errorsmod.Wrapf( - types.ErrConsensusPubKeyLenInvalid, - "got: %d, expected: %d", len(pk.Bytes()), ed25519.PubKeySize, - ) - } + if err = validatePubKey(pk, res.Params.Validator.PubKeyTypes); err != nil { + return nil, err } err = k.checkConsKeyAlreadyUsed(ctx, pk) @@ -649,8 +646,15 @@ func (k msgServer) UpdateParams(ctx context.Context, msg *types.MsgUpdateParams) return &types.MsgUpdateParamsResponse{}, nil } +// RotateConsPubKey handles the rotation of a validator's consensus public key. +// It validates the new key, checks for conflicts, and updates the necessary state. +// The function requires that the validator params are not nil for successful execution. func (k msgServer) RotateConsPubKey(ctx context.Context, msg *types.MsgRotateConsPubKey) (res *types.MsgRotateConsPubKeyResponse, err error) { cv := msg.NewPubkey.GetCachedValue() + if cv == nil { + return nil, errorsmod.Wrap(sdkerrors.ErrInvalidType, "new public key is nil") + } + pk, ok := cv.(cryptotypes.PubKey) if !ok { return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "expecting cryptotypes.PubKey, got %T", cv) @@ -666,21 +670,12 @@ func (k msgServer) RotateConsPubKey(ctx context.Context, msg *types.MsgRotateCon return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidRequest, "unexpected response type: %T", resp) } - if paramsRes.Params.Validator != nil { - pkType := pk.Type() - if !slices.Contains(paramsRes.Params.Validator.PubKeyTypes, pkType) { - return nil, errorsmod.Wrapf( - types.ErrValidatorPubKeyTypeNotSupported, - "got: %s, expected: %s", pk.Type(), paramsRes.Params.Validator.PubKeyTypes, - ) - } + if paramsRes.Params.Validator == nil { + return nil, errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "validator params are not set") + } - if pkType == sdk.PubKeyEd25519Type && len(pk.Bytes()) != ed25519.PubKeySize { - return nil, errorsmod.Wrapf( - types.ErrConsensusPubKeyLenInvalid, - "got: %d, expected: %d", len(pk.Bytes()), ed25519.PubKeySize, - ) - } + if err = validatePubKey(pk, paramsRes.Params.Validator.PubKeyTypes); err != nil { + return nil, err } err = k.checkConsKeyAlreadyUsed(ctx, pk) @@ -778,3 +773,24 @@ func (k msgServer) checkConsKeyAlreadyUsed(ctx context.Context, newConsPubKey cr return nil } + +func validatePubKey(pk cryptotypes.PubKey, knownPubKeyTypes []string) error { + pkType := pk.Type() + if !slices.Contains(knownPubKeyTypes, pkType) { + return errorsmod.Wrapf( + types.ErrValidatorPubKeyTypeNotSupported, + "got: %s, expected: %s", pk.Type(), knownPubKeyTypes, + ) + } + + if pkType == sdk.PubKeyEd25519Type { + if len(pk.Bytes()) != ed25519.PubKeySize { + return errorsmod.Wrapf( + types.ErrConsensusPubKeyLenInvalid, + "invalid Ed25519 pubkey size: got %d, expected %d", len(pk.Bytes()), ed25519.PubKeySize, + ) + } + } + + return nil +} diff --git a/x/staking/keeper/val_state_change.go b/x/staking/keeper/val_state_change.go index 8d71114c15..867dc89dd5 100644 --- a/x/staking/keeper/val_state_change.go +++ b/x/staking/keeper/val_state_change.go @@ -266,14 +266,22 @@ func (k Keeper) ApplyAndReturnValidatorSetUpdates(ctx context.Context) ([]appmod return nil, err } - oldPk, ok := history.OldConsPubkey.GetCachedValue().(cryptotypes.PubKey) + oldPkCached := history.OldConsPubkey.GetCachedValue() + if oldPkCached == nil { + return nil, errorsmod.Wrap(sdkerrors.ErrInvalidType, "OldConsPubkey cached value is nil") + } + oldPk, ok := oldPkCached.(cryptotypes.PubKey) if !ok { - return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", oldPk) + return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", oldPkCached) } - newPk, ok := history.NewConsPubkey.GetCachedValue().(cryptotypes.PubKey) + newPkCached := history.NewConsPubkey.GetCachedValue() + if newPkCached == nil { + return nil, errorsmod.Wrap(sdkerrors.ErrInvalidType, "NewConsPubkey cached value is nil") + } + newPk, ok := newPkCached.(cryptotypes.PubKey) if !ok { - return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", newPk) + return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting cryptotypes.PubKey, got %T", newPkCached) } // a validator cannot rotate keys if it's not bonded or if it's jailed