From 229eb30354932e07ce82745226c5ab82b1f10a31 Mon Sep 17 00:00:00 2001 From: Prathamesh Musale Date: Mon, 4 Mar 2024 19:45:24 +0530 Subject: [PATCH] Add msg validations for bond and registry module --- x/auction/keeper/msg_server.go | 6 +- x/bond/keeper/msg_server.go | 18 ++++++ x/bond/msgs.go | 48 ++++++++++++++++ x/registry/keeper/msg_server.go | 40 ++++++++++++++ x/registry/msgs.go | 97 +++++++++++++++++++++++++++++++++ 5 files changed, 205 insertions(+), 4 deletions(-) diff --git a/x/auction/keeper/msg_server.go b/x/auction/keeper/msg_server.go index 382b1691..b2654d8b 100644 --- a/x/auction/keeper/msg_server.go +++ b/x/auction/keeper/msg_server.go @@ -52,8 +52,7 @@ func (ms msgServer) CreateAuction(c context.Context, msg *auctiontypes.MsgCreate // CommitBid is the command for committing a bid // nolint: all func (ms msgServer) CommitBid(c context.Context, msg *auctiontypes.MsgCommitBid) (*auctiontypes.MsgCommitBidResponse, error) { - err := msg.ValidateBasic() - if err != nil { + if err := msg.ValidateBasic(); err != nil { return nil, err } @@ -88,8 +87,7 @@ func (ms msgServer) CommitBid(c context.Context, msg *auctiontypes.MsgCommitBid) // RevealBid is the command for revealing a bid // nolint: all func (ms msgServer) RevealBid(c context.Context, msg *auctiontypes.MsgRevealBid) (*auctiontypes.MsgRevealBidResponse, error) { - err := msg.ValidateBasic() - if err != nil { + if err := msg.ValidateBasic(); err != nil { return nil, err } diff --git a/x/bond/keeper/msg_server.go b/x/bond/keeper/msg_server.go index 719f17b4..21b8630f 100644 --- a/x/bond/keeper/msg_server.go +++ b/x/bond/keeper/msg_server.go @@ -20,6 +20,10 @@ func NewMsgServerImpl(keeper *Keeper) bond.MsgServer { } func (ms msgServer) CreateBond(c context.Context, msg *bond.MsgCreateBond) (*bond.MsgCreateBondResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) signerAddress, err := sdk.AccAddressFromBech32(msg.Signer) @@ -50,6 +54,10 @@ func (ms msgServer) CreateBond(c context.Context, msg *bond.MsgCreateBond) (*bon // RefillBond implements bond.MsgServer. func (ms msgServer) RefillBond(c context.Context, msg *bond.MsgRefillBond) (*bond.MsgRefillBondResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) signerAddress, err := sdk.AccAddressFromBech32(msg.Signer) @@ -81,6 +89,10 @@ func (ms msgServer) RefillBond(c context.Context, msg *bond.MsgRefillBond) (*bon // WithdrawBond implements bond.MsgServer. func (ms msgServer) WithdrawBond(c context.Context, msg *bond.MsgWithdrawBond) (*bond.MsgWithdrawBondResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) signerAddress, err := sdk.AccAddressFromBech32(msg.Signer) @@ -112,11 +124,17 @@ func (ms msgServer) WithdrawBond(c context.Context, msg *bond.MsgWithdrawBond) ( // CancelBond implements bond.MsgServer. func (ms msgServer) CancelBond(c context.Context, msg *bond.MsgCancelBond) (*bond.MsgCancelBondResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) + signerAddress, err := sdk.AccAddressFromBech32(msg.Signer) if err != nil { return nil, err } + _, err = ms.k.CancelBond(ctx, msg.Id, signerAddress) if err != nil { return nil, err diff --git a/x/bond/msgs.go b/x/bond/msgs.go index c473c612..135127f6 100644 --- a/x/bond/msgs.go +++ b/x/bond/msgs.go @@ -1,7 +1,9 @@ package bond import ( + errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" ) var ( @@ -15,3 +17,49 @@ func NewMsgCreateBond(coins sdk.Coins, signer sdk.AccAddress) MsgCreateBond { Signer: signer.String(), } } + +func (msg MsgCreateBond) ValidateBasic() error { + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, msg.Signer) + } + if len(msg.Coins) == 0 || !msg.Coins.IsValid() { + return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, "Invalid amount.") + } + return nil +} + +func (msg MsgRefillBond) ValidateBasic() error { + if len(msg.Id) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, msg.Id) + } + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, msg.Signer) + } + if len(msg.Coins) == 0 || !msg.Coins.IsValid() { + return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, "Invalid amount.") + } + return nil +} + +func (msg MsgWithdrawBond) ValidateBasic() error { + if len(msg.Id) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, msg.Id) + } + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, msg.Signer) + } + if len(msg.Coins) == 0 || !msg.Coins.IsValid() { + return errorsmod.Wrap(sdkerrors.ErrInvalidCoins, "Invalid amount.") + } + return nil +} + +func (msg MsgCancelBond) ValidateBasic() error { + if len(msg.Id) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, msg.Id) + } + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, msg.Signer) + } + return nil +} diff --git a/x/registry/keeper/msg_server.go b/x/registry/keeper/msg_server.go index ba48d633..fc9e47a7 100644 --- a/x/registry/keeper/msg_server.go +++ b/x/registry/keeper/msg_server.go @@ -20,6 +20,10 @@ func NewMsgServerImpl(keeper Keeper) registrytypes.MsgServer { } func (ms msgServer) SetRecord(c context.Context, msg *registrytypes.MsgSetRecord) (*registrytypes.MsgSetRecordResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -51,6 +55,10 @@ func (ms msgServer) SetRecord(c context.Context, msg *registrytypes.MsgSetRecord // nolint: all func (ms msgServer) SetName(c context.Context, msg *registrytypes.MsgSetName) (*registrytypes.MsgSetNameResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -80,6 +88,10 @@ func (ms msgServer) SetName(c context.Context, msg *registrytypes.MsgSetName) (* } func (ms msgServer) ReserveName(c context.Context, msg *registrytypes.MsgReserveAuthority) (*registrytypes.MsgReserveAuthorityResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -114,6 +126,10 @@ func (ms msgServer) ReserveName(c context.Context, msg *registrytypes.MsgReserve // nolint: all func (ms msgServer) SetAuthorityBond(c context.Context, msg *registrytypes.MsgSetAuthorityBond) (*registrytypes.MsgSetAuthorityBondResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -144,6 +160,10 @@ func (ms msgServer) SetAuthorityBond(c context.Context, msg *registrytypes.MsgSe } func (ms msgServer) DeleteName(c context.Context, msg *registrytypes.MsgDeleteNameAuthority) (*registrytypes.MsgDeleteNameAuthorityResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -172,6 +192,10 @@ func (ms msgServer) DeleteName(c context.Context, msg *registrytypes.MsgDeleteNa } func (ms msgServer) RenewRecord(c context.Context, msg *registrytypes.MsgRenewRecord) (*registrytypes.MsgRenewRecordResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -201,6 +225,10 @@ func (ms msgServer) RenewRecord(c context.Context, msg *registrytypes.MsgRenewRe // nolint: all func (ms msgServer) AssociateBond(c context.Context, msg *registrytypes.MsgAssociateBond) (*registrytypes.MsgAssociateBondResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -231,6 +259,10 @@ func (ms msgServer) AssociateBond(c context.Context, msg *registrytypes.MsgAssoc } func (ms msgServer) DissociateBond(c context.Context, msg *registrytypes.MsgDissociateBond) (*registrytypes.MsgDissociateBondResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -260,6 +292,10 @@ func (ms msgServer) DissociateBond(c context.Context, msg *registrytypes.MsgDiss } func (ms msgServer) DissociateRecords(c context.Context, msg *registrytypes.MsgDissociateRecords) (*registrytypes.MsgDissociateRecordsResponse, error) { + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -289,6 +325,10 @@ func (ms msgServer) DissociateRecords(c context.Context, msg *registrytypes.MsgD } func (ms msgServer) ReassociateRecords(c context.Context, msg *registrytypes.MsgReassociateRecords) (*registrytypes.MsgReassociateRecordsResponse, error) { //nolint: all + if err := msg.ValidateBasic(); err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) _, err := sdk.AccAddressFromBech32(msg.Signer) diff --git a/x/registry/msgs.go b/x/registry/msgs.go index 93f65743..c091441a 100644 --- a/x/registry/msgs.go +++ b/x/registry/msgs.go @@ -1,6 +1,8 @@ package registry import ( + "net/url" + errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -34,6 +36,18 @@ func (msg MsgSetRecord) ValidateBasic() error { return nil } +func (msg MsgRenewRecord) ValidateBasic() error { + if len(msg.RecordId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "record id is required.") + } + + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "invalid signer.") + } + + return nil +} + // NewMsgReserveAuthority is the constructor function for MsgReserveName. func NewMsgReserveAuthority(name string, signer sdk.AccAddress, owner sdk.AccAddress) MsgReserveAuthority { return MsgReserveAuthority{ @@ -65,6 +79,22 @@ func NewMsgSetAuthorityBond(name string, bondID string, signer sdk.AccAddress) M } } +func (msg MsgSetAuthorityBond) ValidateBasic() error { + if len(msg.Name) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "name is required.") + } + + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "invalid signer.") + } + + if len(msg.BondId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "bond id is required.") + } + + return nil +} + // NewMsgSetName is the constructor function for MsgSetName. func NewMsgSetName(lrn string, cid string, signer sdk.AccAddress) *MsgSetName { return &MsgSetName{ @@ -90,3 +120,70 @@ func (msg MsgSetName) ValidateBasic() error { return nil } + +func (msg MsgDeleteNameAuthority) ValidateBasic() error { + if len(msg.Lrn) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "lrn is required.") + } + + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "invalid signer.") + } + + _, err := url.Parse(msg.Lrn) + if err != nil { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "invalid lrn.") + } + + return nil +} + +func (msg MsgAssociateBond) ValidateBasic() error { + if len(msg.RecordId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "record id is required.") + } + if len(msg.BondId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "bond id is required.") + } + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "invalid signer.") + } + + return nil +} + +func (msg MsgDissociateBond) ValidateBasic() error { + if len(msg.RecordId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "record id is required.") + } + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "invalid signer.") + } + + return nil +} + +func (msg MsgDissociateRecords) ValidateBasic() error { + if len(msg.BondId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "bond id is required.") + } + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "invalid signer.") + } + + return nil +} + +func (msg MsgReassociateRecords) ValidateBasic() error { + if len(msg.OldBondId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "old-bond-id is required.") + } + if len(msg.NewBondId) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidRequest, "new-bond-id is required.") + } + if len(msg.Signer) == 0 { + return errorsmod.Wrap(sdkerrors.ErrInvalidAddress, "invalid signer.") + } + + return nil +}