From 765f52e402f5410db9d79914affdf8cda775154c Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Mon, 10 Jul 2017 12:36:30 +0200 Subject: [PATCH] Enforce the expiration height in Chain middleware --- errors/common.go | 8 ++++++++ modules/base/chain.go | 9 ++++++--- modules/base/chain_test.go | 14 ++++++++++++-- modules/base/tx.go | 7 ++++++- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/errors/common.go b/errors/common.go index d38f6b07b6..c04af0984d 100644 --- a/errors/common.go +++ b/errors/common.go @@ -21,6 +21,7 @@ var ( errUnknownTxType = fmt.Errorf("Tx type unknown") errInvalidFormat = fmt.Errorf("Invalid format") errUnknownModule = fmt.Errorf("Unknown module") + errExpired = fmt.Errorf("Tx expired") ) // some crazy reflection to unwrap any generated struct. @@ -130,3 +131,10 @@ func ErrTooLarge() TMError { func IsTooLargeErr(err error) bool { return IsSameError(errTooLarge, err) } + +func ErrExpired() TMError { + return WithCode(errExpired, abci.CodeType_Unauthorized) +} +func IsExpiredErr(err error) bool { + return IsSameError(errExpired, err) +} diff --git a/modules/base/chain.go b/modules/base/chain.go index 4fc7071e63..8161f8802e 100644 --- a/modules/base/chain.go +++ b/modules/base/chain.go @@ -26,7 +26,7 @@ var _ stack.Middleware = Chain{} // CheckTx makes sure we are on the proper chain - fulfills Middlware interface func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) { - stx, err := c.checkChain(ctx.ChainID(), tx) + stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx) if err != nil { return res, err } @@ -35,7 +35,7 @@ func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx // DeliverTx makes sure we are on the proper chain - fulfills Middlware interface func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) { - stx, err := c.checkChain(ctx.ChainID(), tx) + stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx) if err != nil { return res, err } @@ -43,7 +43,7 @@ func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin. } // checkChain makes sure the tx is a Chain Tx and is on the proper chain -func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) { +func (c Chain) checkChain(chainID string, height uint64, tx basecoin.Tx) (basecoin.Tx, error) { // make sure it is a chaintx ctx, ok := tx.Unwrap().(ChainTx) if !ok { @@ -60,5 +60,8 @@ func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) { if ctx.ChainID != chainID { return tx, errors.ErrWrongChain(ctx.ChainID) } + if ctx.ExpiresAt != 0 && ctx.ExpiresAt <= height { + return tx, errors.ErrExpired() + } return ctx.Tx, nil } diff --git a/modules/base/chain_test.go b/modules/base/chain_test.go index 21f48dc6fb..8da281cddd 100644 --- a/modules/base/chain_test.go +++ b/modules/base/chain_test.go @@ -48,6 +48,7 @@ func TestChain(t *testing.T) { assert := assert.New(t) msg := "got it" chainID := "my-chain" + height := uint64(100) raw := stack.NewRawTx([]byte{1, 2, 3, 4}) cases := []struct { @@ -55,13 +56,22 @@ func TestChain(t *testing.T) { valid bool errorMsg string }{ + // check the chain ids are validated {NewChainTx(chainID, 0, raw), true, ""}, - {NewChainTx("someone-else", 0, raw), false, "someone-else"}, + // non-matching chainid, or impossible chain id + {NewChainTx("someone-else", 0, raw), false, "someone-else: Wrong chain"}, + {NewChainTx("Inval$$d:CH%%n", 0, raw), false, "Wrong chain"}, + // Wrong tx type {raw, false, "No chain id provided"}, + // Check different heights - must be 0 or higher than current height + {NewChainTx(chainID, height+1, raw), true, ""}, + {NewChainTx(chainID, height, raw), false, "Tx expired"}, + {NewChainTx(chainID, 1, raw), false, "expired"}, + {NewChainTx(chainID, 0, raw), true, ""}, } // generic args here... - ctx := stack.NewContext(chainID, 100, log.NewNopLogger()) + ctx := stack.NewContext(chainID, height, log.NewNopLogger()) store := state.NewMemKVStore() // build the stack diff --git a/modules/base/tx.go b/modules/base/tx.go index 30320bedcf..819811ef6c 100644 --- a/modules/base/tx.go +++ b/modules/base/tx.go @@ -69,7 +69,12 @@ var ( //nolint - TxInner Functions func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx { - return (ChainTx{Tx: tx, ChainID: chainID}).Wrap() + c := ChainTx{ + ChainID: chainID, + ExpiresAt: expires, + Tx: tx, + } + return c.Wrap() } func (c ChainTx) Wrap() basecoin.Tx { return basecoin.Tx{c}