From 432afb274b43e72f87a59d04b5d322ebe31fc0c8 Mon Sep 17 00:00:00 2001 From: Aditya Date: Wed, 4 Nov 2020 12:16:20 +0000 Subject: [PATCH] ibc: minor fixes from audit (#7807) --- x/ibc/core/04-channel/keeper/handshake.go | 22 +++++++++++++--------- x/ibc/core/04-channel/types/msgs.go | 14 +++++++++++++- x/ibc/core/04-channel/types/msgs_test.go | 8 ++++++++ 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/x/ibc/core/04-channel/keeper/handshake.go b/x/ibc/core/04-channel/keeper/handshake.go index 03b100be7a..75b0c1bd3d 100644 --- a/x/ibc/core/04-channel/keeper/handshake.go +++ b/x/ibc/core/04-channel/keeper/handshake.go @@ -14,18 +14,22 @@ import ( // CounterpartyHops returns the connection hops of the counterparty channel. // The counterparty hops are stored in the inverse order as the channel's. +// NOTE: Since connectionHops only supports single connection channels for now, +// this function requires that connection hops only contain a single connection id func (k Keeper) CounterpartyHops(ctx sdk.Context, ch types.Channel) ([]string, bool) { - counterPartyHops := make([]string, len(ch.ConnectionHops)) - - for i, hop := range ch.ConnectionHops { - conn, found := k.connectionKeeper.GetConnection(ctx, hop) - if !found { - return []string{}, false - } - - counterPartyHops[len(counterPartyHops)-1-i] = conn.GetCounterparty().GetConnectionID() + // Return empty array if connection hops is more than one + // ConnectionHops length should be verified earlier + if len(ch.ConnectionHops) != 1 { + return []string{}, false + } + counterPartyHops := make([]string, 1) + hop := ch.ConnectionHops[0] + conn, found := k.connectionKeeper.GetConnection(ctx, hop) + if !found { + return []string{}, false } + counterPartyHops[0] = conn.GetCounterparty().GetConnectionID() return counterPartyHops, true } diff --git a/x/ibc/core/04-channel/types/msgs.go b/x/ibc/core/04-channel/types/msgs.go index ed29cc24ff..96a4283d8f 100644 --- a/x/ibc/core/04-channel/types/msgs.go +++ b/x/ibc/core/04-channel/types/msgs.go @@ -46,6 +46,12 @@ func (msg MsgChannelOpenInit) ValidateBasic() error { if err := host.ChannelIdentifierValidator(msg.ChannelId); err != nil { return sdkerrors.Wrap(err, "invalid channel ID") } + if msg.Channel.State != INIT { + return sdkerrors.Wrapf(ErrInvalidChannelState, + "channel state must be INIT in MsgChannelOpenInit. expected: %s, got: %s", + INIT, msg.Channel.State, + ) + } _, err := sdk.AccAddressFromBech32(msg.Signer) if err != nil { return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "string could not be parsed as address: %v", err) @@ -78,7 +84,7 @@ func NewMsgChannelOpenTry( proofInit []byte, proofHeight clienttypes.Height, signer sdk.AccAddress, ) *MsgChannelOpenTry { counterparty := NewCounterparty(counterpartyPortID, counterpartyChannelID) - channel := NewChannel(INIT, channelOrder, counterparty, connectionHops, version) + channel := NewChannel(TRYOPEN, channelOrder, counterparty, connectionHops, version) return &MsgChannelOpenTry{ PortId: portID, DesiredChannelId: desiredChannelID, @@ -118,6 +124,12 @@ func (msg MsgChannelOpenTry) ValidateBasic() error { if msg.ProofHeight.IsZero() { return sdkerrors.Wrap(sdkerrors.ErrInvalidHeight, "proof height must be non-zero") } + if msg.Channel.State != TRYOPEN { + return sdkerrors.Wrapf(ErrInvalidChannelState, + "channel state must be TRYOPEN in MsgChannelOpenTry. expected: %s, got: %s", + TRYOPEN, msg.Channel.State, + ) + } _, err := sdk.AccAddressFromBech32(msg.Signer) if err != nil { return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "string could not be parsed as address: %v", err) diff --git a/x/ibc/core/04-channel/types/msgs_test.go b/x/ibc/core/04-channel/types/msgs_test.go index c7c1bbb4ba..9cf466840c 100644 --- a/x/ibc/core/04-channel/types/msgs_test.go +++ b/x/ibc/core/04-channel/types/msgs_test.go @@ -105,6 +105,9 @@ func TestTypesTestSuite(t *testing.T) { } func (suite *TypesTestSuite) TestMsgChannelOpenInitValidateBasic() { + counterparty := types.NewCounterparty(cpportid, cpchanid) + tryOpenChannel := types.NewChannel(types.TRYOPEN, types.ORDERED, counterparty, connHops, version) + testCases := []struct { name string msg *types.MsgChannelOpenInit @@ -125,6 +128,7 @@ func (suite *TypesTestSuite) TestMsgChannelOpenInitValidateBasic() { {"", types.NewMsgChannelOpenInit(portid, chanid, "", types.UNORDERED, connHops, cpportid, cpchanid, addr), true}, {"invalid counterparty port id", types.NewMsgChannelOpenInit(portid, chanid, version, types.UNORDERED, connHops, invalidPort, cpchanid, addr), false}, {"invalid counterparty channel id", types.NewMsgChannelOpenInit(portid, chanid, version, types.UNORDERED, connHops, cpportid, invalidChannel, addr), false}, + {"channel not in INIT state", &types.MsgChannelOpenInit{portid, chanid, tryOpenChannel, addr.String()}, false}, } for _, tc := range testCases { @@ -142,6 +146,9 @@ func (suite *TypesTestSuite) TestMsgChannelOpenInitValidateBasic() { } func (suite *TypesTestSuite) TestMsgChannelOpenTryValidateBasic() { + counterparty := types.NewCounterparty(cpportid, cpchanid) + initChannel := types.NewChannel(types.INIT, types.ORDERED, counterparty, connHops, version) + testCases := []struct { name string msg *types.MsgChannelOpenTry @@ -167,6 +174,7 @@ func (suite *TypesTestSuite) TestMsgChannelOpenTryValidateBasic() { {"empty proof", types.NewMsgChannelOpenTry(portid, chanid, chanid, version, types.UNORDERED, connHops, cpportid, cpchanid, version, emptyProof, height, addr), false}, {"valid empty proved channel id", types.NewMsgChannelOpenTry(portid, chanid, "", version, types.ORDERED, connHops, cpportid, cpchanid, version, suite.proof, height, addr), true}, {"invalid proved channel id, doesn't match channel id", types.NewMsgChannelOpenTry(portid, chanid, "differentchannel", version, types.ORDERED, connHops, cpportid, cpchanid, version, suite.proof, height, addr), false}, + {"channel not in TRYOPEN state", &types.MsgChannelOpenTry{portid, chanid, chanid, initChannel, version, suite.proof, height, addr.String()}, false}, } for _, tc := range testCases {