From af9ce5b5536b16f4f8ccb32d214b8cd55eb66708 Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Mon, 10 Jul 2017 11:44:40 +0200 Subject: [PATCH 1/5] Add expiration field to ChainTx --- modules/base/chain_test.go | 35 +++++++++++++++++++++++++++++++++-- modules/base/tx.go | 27 +++++++++++++++++++++++---- modules/base/tx_test.go | 2 +- 3 files changed, 57 insertions(+), 7 deletions(-) diff --git a/modules/base/chain_test.go b/modules/base/chain_test.go index 967c0e91b9..9d69d154fd 100644 --- a/modules/base/chain_test.go +++ b/modules/base/chain_test.go @@ -13,6 +13,37 @@ import ( "github.com/tendermint/basecoin/state" ) +func TestChainValidate(t *testing.T) { + assert := assert.New(t) + raw := stack.NewRawTx([]byte{1, 2, 3, 4}) + + cases := []struct { + name string + expires uint64 + valid bool + }{ + {"hello", 0, true}, + {"one-2-three", 123, true}, + {"super!@#$%@", 0, false}, + {"WISH_2_be", 14, true}, + {"öhhh", 54, false}, + } + + for _, tc := range cases { + tx := NewChainTx(tc.name, tc.expires, raw) + err := tx.ValidateBasic() + if tc.valid { + assert.Nil(err, "%s: %+v", tc.name, err) + } else { + assert.NotNil(err, tc.name) + } + } + + empty := NewChainTx("okay", 0, basecoin.Tx{}) + err := empty.ValidateBasic() + assert.NotNil(err) +} + func TestChain(t *testing.T) { assert := assert.New(t) msg := "got it" @@ -24,8 +55,8 @@ func TestChain(t *testing.T) { valid bool errorMsg string }{ - {NewChainTx(chainID, raw), true, ""}, - {NewChainTx("someone-else", raw), false, "someone-else"}, + {NewChainTx(chainID, 0, raw), true, ""}, + {NewChainTx("someone-else", 0, raw), false, "someone-else"}, {raw, false, "No chain id provided"}, } diff --git a/modules/base/tx.go b/modules/base/tx.go index b4f2cd2b78..30320bedcf 100644 --- a/modules/base/tx.go +++ b/modules/base/tx.go @@ -1,6 +1,11 @@ package base -import "github.com/tendermint/basecoin" +import ( + "regexp" + + "github.com/tendermint/basecoin" + "github.com/tendermint/basecoin/errors" +) // nolint const ( @@ -51,20 +56,34 @@ func (mt MultiTx) ValidateBasic() error { // ChainTx locks this tx to one chainTx, wrap with this before signing type ChainTx struct { - Tx basecoin.Tx `json:"tx"` - ChainID string `json:"chain_id"` + ChainID string `json:"chain_id"` // name of chain, must be [A-Za-z0-9_-]+ + ExpiresAt uint64 `json:"expires_at"` // block height at which it is no longer valid + Tx basecoin.Tx `json:"tx"` } var _ basecoin.TxInner = &ChainTx{} +var ( + chainPattern = regexp.MustCompile("^[A-Za-z0-9_-]+$") +) + //nolint - TxInner Functions -func NewChainTx(chainID string, tx basecoin.Tx) basecoin.Tx { +func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx { return (ChainTx{Tx: tx, ChainID: chainID}).Wrap() } func (c ChainTx) Wrap() basecoin.Tx { return basecoin.Tx{c} } func (c ChainTx) ValidateBasic() error { + if c.ChainID == "" { + return errors.ErrNoChain() + } + if !chainPattern.MatchString(c.ChainID) { + return errors.ErrWrongChain(c.ChainID) + } + if c.Tx.Empty() { + return errors.ErrUnknownTxType(c.Tx) + } // TODO: more checks? chainID? return c.Tx.ValidateBasic() } diff --git a/modules/base/tx_test.go b/modules/base/tx_test.go index 215256674d..00bd4ea79c 100644 --- a/modules/base/tx_test.go +++ b/modules/base/tx_test.go @@ -25,7 +25,7 @@ func TestEncoding(t *testing.T) { }{ {raw}, {NewMultiTx(raw, raw2)}, - {NewChainTx("foobar", raw)}, + {NewChainTx("foobar", 0, raw)}, } for idx, tc := range cases { From 100d88d7dd482c420a8c41677c0c4b3bcc3a5bbf Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Mon, 10 Jul 2017 11:57:37 +0200 Subject: [PATCH 2/5] Fix up all tests to handle NewChainTx change --- app/app_test.go | 2 +- cmd/basecli/commands/cmds.go | 21 +++++++++++++++++-- .../cmd/countercli/commands/counter.go | 10 +++++---- .../counter/plugins/counter/counter_test.go | 2 +- modules/base/chain.go | 9 ++++++++ 5 files changed, 36 insertions(+), 8 deletions(-) diff --git a/app/app_test.go b/app/app_test.go index e98b8a5a25..7889a31b0e 100644 --- a/app/app_test.go +++ b/app/app_test.go @@ -44,7 +44,7 @@ func (at *appTest) getTx(seq int, coins coin.Coins) basecoin.Tx { in := []coin.TxInput{{Address: at.acctIn.Actor(), Coins: coins, Sequence: seq}} out := []coin.TxOutput{{Address: at.acctOut.Actor(), Coins: coins}} tx := coin.NewSendTx(in, out) - tx = base.NewChainTx(at.chainID, tx) + tx = base.NewChainTx(at.chainID, 0, tx) stx := auth.NewMulti(tx) auth.Sign(stx, at.acctIn.Key) return stx.Wrap() diff --git a/cmd/basecli/commands/cmds.go b/cmd/basecli/commands/cmds.go index b3c9ede914..e5fff36342 100644 --- a/cmd/basecli/commands/cmds.go +++ b/cmd/basecli/commands/cmds.go @@ -34,6 +34,7 @@ const ( FlagAmount = "amount" FlagFee = "fee" FlagGas = "gas" + FlagExpires = "expires" FlagSequence = "sequence" ) @@ -42,7 +43,8 @@ func init() { flags.String(FlagTo, "", "Destination address for the bits") flags.String(FlagAmount, "", "Coins to send in the format ,...") flags.String(FlagFee, "0mycoin", "Coins for the transaction fee of the format ") - flags.Int64(FlagGas, 0, "Amount of gas for this transaction") + flags.Uint64(FlagGas, 0, "Amount of gas for this transaction") + flags.Int64(FlagExpires, 0, "Block height at which this tx expires") flags.Int(FlagSequence, -1, "Sequence number for this transaction") } @@ -63,7 +65,12 @@ func doSendTx(cmd *cobra.Command, args []string) error { // TODO: make this more flexible for middleware // add the chain info - tx = base.NewChainTx(commands.GetChainID(), tx) + tx, err = WrapChainTx(tx) + if err != nil { + return err + } + + // Note: this is single sig (no multi sig yet) stx := auth.NewSig(tx) // Sign if needed and post. This it the work-horse @@ -76,6 +83,16 @@ func doSendTx(cmd *cobra.Command, args []string) error { return txcmd.OutputTx(bres) } +// WrapChainTx will wrap the tx with a ChainTx from the standard flags +func WrapChainTx(tx basecoin.Tx) (res basecoin.Tx, err error) { + expires := viper.GetInt64(FlagExpires) + if expires < 0 { + return res, errors.New("expires must be >= 0") + } + res = base.NewChainTx(commands.GetChainID(), uint64(expires), tx) + return res, nil +} + func readSendTxFlags() (tx basecoin.Tx, err error) { // parse to address chain, to, err := parseChainAddress(viper.GetString(FlagTo)) diff --git a/docs/guide/counter/cmd/countercli/commands/counter.go b/docs/guide/counter/cmd/countercli/commands/counter.go index 255a2f65e2..326b62b19c 100644 --- a/docs/guide/counter/cmd/countercli/commands/counter.go +++ b/docs/guide/counter/cmd/countercli/commands/counter.go @@ -4,13 +4,12 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/tendermint/basecoin" - "github.com/tendermint/light-client/commands" txcmd "github.com/tendermint/light-client/commands/txs" + "github.com/tendermint/basecoin" + bcmd "github.com/tendermint/basecoin/cmd/basecli/commands" "github.com/tendermint/basecoin/docs/guide/counter/plugins/counter" "github.com/tendermint/basecoin/modules/auth" - "github.com/tendermint/basecoin/modules/base" "github.com/tendermint/basecoin/modules/coin" ) @@ -57,7 +56,10 @@ func counterTx(cmd *cobra.Command, args []string) error { // TODO: make this more flexible for middleware // add the chain info - tx = base.NewChainTx(commands.GetChainID(), tx) + tx, err = bcmd.WrapChainTx(tx) + if err != nil { + return err + } stx := auth.NewSig(tx) // Sign if needed and post. This it the work-horse diff --git a/docs/guide/counter/plugins/counter/counter_test.go b/docs/guide/counter/plugins/counter/counter_test.go index 522d6a7215..2de91d62a7 100644 --- a/docs/guide/counter/plugins/counter/counter_test.go +++ b/docs/guide/counter/plugins/counter/counter_test.go @@ -42,7 +42,7 @@ func TestCounterPlugin(t *testing.T) { // Deliver a CounterTx DeliverCounterTx := func(valid bool, counterFee coin.Coins, inputSequence int) abci.Result { tx := NewTx(valid, counterFee, inputSequence) - tx = base.NewChainTx(chainID, tx) + tx = base.NewChainTx(chainID, 0, tx) stx := auth.NewSig(tx) auth.Sign(stx, acct.Key) txBytes := wire.BinaryBytes(stx.Wrap()) diff --git a/modules/base/chain.go b/modules/base/chain.go index a362267668..4fc7071e63 100644 --- a/modules/base/chain.go +++ b/modules/base/chain.go @@ -44,10 +44,19 @@ func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin. // checkChain makes sure the tx is a Chain Tx and is on the proper chain func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) { + // make sure it is a chaintx ctx, ok := tx.Unwrap().(ChainTx) if !ok { return tx, errors.ErrNoChain() } + + // basic validation + err := ctx.ValidateBasic() + if err != nil { + return tx, err + } + + // compare against state if ctx.ChainID != chainID { return tx, errors.ErrWrongChain(ctx.ChainID) } From b6197a1c12f6250cbcd47468cadbfbe01c69dc88 Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Mon, 10 Jul 2017 12:19:42 +0200 Subject: [PATCH 3/5] Add height to context, cleanup, add to app and stack --- app/app.go | 6 +++ context.go | 1 + modules/auth/bench_test.go | 4 +- modules/auth/signature_test.go | 2 +- modules/base/chain_test.go | 2 +- modules/coin/bench_test.go | 2 +- modules/coin/handler_test.go | 4 +- stack/context.go | 72 ++++++++-------------------------- stack/helpers_test.go | 6 +-- stack/middleware_test.go | 2 +- stack/mock.go | 53 +++++++++++++++++-------- stack/recovery_test.go | 2 +- 12 files changed, 72 insertions(+), 84 deletions(-) diff --git a/app/app.go b/app/app.go index fa99458e6d..e053408a48 100644 --- a/app/app.go +++ b/app/app.go @@ -31,6 +31,7 @@ type Basecoin struct { state *sm.State cacheState *sm.State handler basecoin.Handler + height uint64 logger log.Logger } @@ -45,6 +46,7 @@ func NewBasecoin(handler basecoin.Handler, eyesCli *eyes.Client, logger log.Logg eyesCli: eyesCli, state: state, cacheState: nil, + height: 0, logger: logger, } } @@ -73,6 +75,7 @@ func (app *Basecoin) Info() abci.ResponseInfo { if err != nil { cmn.PanicCrisis(err) } + app.height = resp.LastBlockHeight return abci.ResponseInfo{ Data: fmt.Sprintf("Basecoin v%v", version.Version), LastBlockHeight: resp.LastBlockHeight, @@ -111,6 +114,7 @@ func (app *Basecoin) DeliverTx(txBytes []byte) abci.Result { cache := app.state.CacheWrap() ctx := stack.NewContext( app.state.GetChainID(), + app.height, app.logger.With("call", "delivertx"), ) res, err := app.handler.DeliverTx(ctx, cache, tx) @@ -134,6 +138,7 @@ func (app *Basecoin) CheckTx(txBytes []byte) abci.Result { // TODO: can we abstract this setup and commit logic?? ctx := stack.NewContext( app.state.GetChainID(), + app.height, app.logger.With("call", "checktx"), ) // checktx generally shouldn't touch the state, but we don't care @@ -187,6 +192,7 @@ func (app *Basecoin) InitChain(validators []*abci.Validator) { // BeginBlock - ABCI func (app *Basecoin) BeginBlock(hash []byte, header *abci.Header) { + app.height++ // for _, plugin := range app.plugins.GetList() { // plugin.BeginBlock(app.state, hash, header) // } diff --git a/context.go b/context.go index 66ba9c5d3e..b82aa042dc 100644 --- a/context.go +++ b/context.go @@ -36,4 +36,5 @@ type Context interface { IsParent(ctx Context) bool Reset() Context ChainID() string + BlockHeight() uint64 } diff --git a/modules/auth/bench_test.go b/modules/auth/bench_test.go index b79449a7dd..f2434f1eee 100644 --- a/modules/auth/bench_test.go +++ b/modules/auth/bench_test.go @@ -40,7 +40,7 @@ func BenchmarkCheckOneSig(b *testing.B) { h := makeHandler() store := state.NewMemKVStore() for i := 1; i <= b.N; i++ { - ctx := stack.NewContext("foo", log.NewNopLogger()) + ctx := stack.NewContext("foo", 100, log.NewNopLogger()) _, err := h.DeliverTx(ctx, store, tx) // never should error if err != nil { @@ -64,7 +64,7 @@ func benchmarkCheckMultiSig(b *testing.B, cnt int) { h := makeHandler() store := state.NewMemKVStore() for i := 1; i <= b.N; i++ { - ctx := stack.NewContext("foo", log.NewNopLogger()) + ctx := stack.NewContext("foo", 100, log.NewNopLogger()) _, err := h.DeliverTx(ctx, store, tx) // never should error if err != nil { diff --git a/modules/auth/signature_test.go b/modules/auth/signature_test.go index 3229e01dca..5afbee1a60 100644 --- a/modules/auth/signature_test.go +++ b/modules/auth/signature_test.go @@ -18,7 +18,7 @@ func TestSignatureChecks(t *testing.T) { assert := assert.New(t) // generic args - ctx := stack.NewContext("test-chain", log.NewNopLogger()) + ctx := stack.NewContext("test-chain", 100, log.NewNopLogger()) store := state.NewMemKVStore() raw := stack.NewRawTx([]byte{1, 2, 3, 4}) diff --git a/modules/base/chain_test.go b/modules/base/chain_test.go index 9d69d154fd..21f48dc6fb 100644 --- a/modules/base/chain_test.go +++ b/modules/base/chain_test.go @@ -61,7 +61,7 @@ func TestChain(t *testing.T) { } // generic args here... - ctx := stack.NewContext(chainID, log.NewNopLogger()) + ctx := stack.NewContext(chainID, 100, log.NewNopLogger()) store := state.NewMemKVStore() // build the stack diff --git a/modules/coin/bench_test.go b/modules/coin/bench_test.go index 52c819a85e..83d3da6b78 100644 --- a/modules/coin/bench_test.go +++ b/modules/coin/bench_test.go @@ -34,7 +34,7 @@ func BenchmarkSimpleTransfer(b *testing.B) { // now, loop... for i := 1; i <= b.N; i++ { - ctx := stack.MockContext("foo").WithPermissions(sender) + ctx := stack.MockContext("foo", 100).WithPermissions(sender) tx := makeSimpleTx(sender, receiver, Coins{{"mycoin", 2}}, i) _, err := h.DeliverTx(ctx, store, tx) // never should error diff --git a/modules/coin/handler_test.go b/modules/coin/handler_test.go index 74692bbc6b..50a4bbed75 100644 --- a/modules/coin/handler_test.go +++ b/modules/coin/handler_test.go @@ -74,7 +74,7 @@ func TestHandlerValidation(t *testing.T) { } for i, tc := range cases { - ctx := stack.MockContext("base-chain").WithPermissions(tc.perms...) + ctx := stack.MockContext("base-chain", 100).WithPermissions(tc.perms...) _, err := checkTx(ctx, tc.tx) if tc.valid { assert.Nil(err, "%d: %+v", i, err) @@ -148,7 +148,7 @@ func TestDeliverTx(t *testing.T) { require.Nil(err, "%d: %+v", i, err) } - ctx := stack.MockContext("base-chain").WithPermissions(tc.perms...) + ctx := stack.MockContext("base-chain", 100).WithPermissions(tc.perms...) _, err := h.DeliverTx(ctx, store, tc.tx) if len(tc.final) > 0 { // valid assert.Nil(err, "%d: %+v", i, err) diff --git a/stack/context.go b/stack/context.go index 7ff0fa4f22..0bcc0e4ecf 100644 --- a/stack/context.go +++ b/stack/context.go @@ -1,9 +1,6 @@ package stack import ( - "bytes" - "math/rand" - "github.com/pkg/errors" "github.com/tendermint/tmlibs/log" @@ -16,28 +13,20 @@ import ( type nonce int64 type secureContext struct { - id nonce - chain string - app string - perms []basecoin.Actor - log.Logger + app string + // this exposes the log.Logger and all other methods we don't override + naiveContext } // NewContext - create a new secureContext -func NewContext(chain string, logger log.Logger) basecoin.Context { +func NewContext(chain string, height uint64, logger log.Logger) basecoin.Context { return secureContext{ - id: nonce(rand.Int63()), - chain: chain, - Logger: logger, + naiveContext: MockContext(chain, height).(naiveContext), } } var _ basecoin.Context = secureContext{} -func (c secureContext) ChainID() string { - return c.chain -} - // WithPermissions will panic if they try to set permission without the proper app func (c secureContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context { // the guard makes sure you only set permissions for the app you are inside @@ -50,32 +39,18 @@ func (c secureContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context } return secureContext{ - id: c.id, - chain: c.chain, - app: c.app, - perms: append(c.perms, perms...), - Logger: c.Logger, + app: c.app, + naiveContext: c.naiveContext.WithPermissions(perms...).(naiveContext), } } -func (c secureContext) HasPermission(perm basecoin.Actor) bool { - for _, p := range c.perms { - if perm.App == p.App && bytes.Equal(perm.Address, p.Address) { - return true - } +// Reset should clear out all permissions, +// but carry on knowledge that this is a child +func (c secureContext) Reset() basecoin.Context { + return secureContext{ + app: c.app, + naiveContext: c.naiveContext.Reset().(naiveContext), } - return false -} - -func (c secureContext) GetPermissions(chain, app string) (res []basecoin.Actor) { - for _, p := range c.perms { - if chain == p.ChainID { - if app == "" || app == p.App { - res = append(res, p) - } - } - } - return res } // IsParent ensures that this is derived from the given secureClient @@ -84,19 +59,7 @@ func (c secureContext) IsParent(other basecoin.Context) bool { if !ok { return false } - return c.id == so.id -} - -// Reset should clear out all permissions, -// but carry on knowledge that this is a child -func (c secureContext) Reset() basecoin.Context { - return secureContext{ - id: c.id, - chain: c.chain, - app: c.app, - perms: nil, - Logger: c.Logger, - } + return c.naiveContext.IsParent(so.naiveContext) } // withApp is a private method that we can use to properly set the @@ -107,11 +70,8 @@ func withApp(ctx basecoin.Context, app string) basecoin.Context { return ctx } return secureContext{ - id: sc.id, - chain: sc.chain, - app: app, - perms: sc.perms, - Logger: sc.Logger, + app: app, + naiveContext: sc.naiveContext, } } diff --git a/stack/helpers_test.go b/stack/helpers_test.go index 1105988f47..58748e054d 100644 --- a/stack/helpers_test.go +++ b/stack/helpers_test.go @@ -15,7 +15,7 @@ import ( func TestOK(t *testing.T) { assert := assert.New(t) - ctx := NewContext("test-chain", log.NewNopLogger()) + ctx := NewContext("test-chain", 20, log.NewNopLogger()) store := state.NewMemKVStore() data := "this looks okay" tx := basecoin.Tx{} @@ -33,7 +33,7 @@ func TestOK(t *testing.T) { func TestFail(t *testing.T) { assert := assert.New(t) - ctx := NewContext("test-chain", log.NewNopLogger()) + ctx := NewContext("test-chain", 20, log.NewNopLogger()) store := state.NewMemKVStore() msg := "big problem" tx := basecoin.Tx{} @@ -53,7 +53,7 @@ func TestFail(t *testing.T) { func TestPanic(t *testing.T) { assert := assert.New(t) - ctx := NewContext("test-chain", log.NewNopLogger()) + ctx := NewContext("test-chain", 20, log.NewNopLogger()) store := state.NewMemKVStore() msg := "system crash!" tx := basecoin.Tx{} diff --git a/stack/middleware_test.go b/stack/middleware_test.go index bcd5f9b0d1..1b4ed8df3e 100644 --- a/stack/middleware_test.go +++ b/stack/middleware_test.go @@ -22,7 +22,7 @@ func TestPermissionSandbox(t *testing.T) { require := require.New(t) // generic args - ctx := NewContext("test-chain", log.NewNopLogger()) + ctx := NewContext("test-chain", 20, log.NewNopLogger()) store := state.NewMemKVStore() raw := NewRawTx([]byte{1, 2, 3, 4}) rawBytes, err := data.ToWire(raw) diff --git a/stack/mock.go b/stack/mock.go index 781f13d4fb..2f6ef9a5de 100644 --- a/stack/mock.go +++ b/stack/mock.go @@ -2,40 +2,55 @@ package stack import ( "bytes" + "math/rand" "github.com/tendermint/tmlibs/log" "github.com/tendermint/basecoin" ) -type mockContext struct { - perms []basecoin.Actor - chain string +type naiveContext struct { + id nonce + chain string + height uint64 + perms []basecoin.Actor log.Logger } -func MockContext(chain string) basecoin.Context { - return mockContext{ +// MockContext returns a simple, non-checking context for test cases. +// +// Always use NewContext() for production code to sandbox malicious code better +func MockContext(chain string, height uint64) basecoin.Context { + return naiveContext{ + id: nonce(rand.Int63()), chain: chain, + height: height, Logger: log.NewNopLogger(), } } -var _ basecoin.Context = mockContext{} +var _ basecoin.Context = naiveContext{} -func (c mockContext) ChainID() string { +func (c naiveContext) ChainID() string { return c.chain } +func (c naiveContext) BlockHeight() uint64 { + return c.height +} + // WithPermissions will panic if they try to set permission without the proper app -func (c mockContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context { - return mockContext{ +func (c naiveContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context { + return naiveContext{ + id: c.id, + chain: c.chain, + height: c.height, perms: append(c.perms, perms...), Logger: c.Logger, } } -func (c mockContext) HasPermission(perm basecoin.Actor) bool { +func (c naiveContext) HasPermission(perm basecoin.Actor) bool { for _, p := range c.perms { if perm.App == p.App && bytes.Equal(perm.Address, p.Address) { return true @@ -44,7 +59,7 @@ func (c mockContext) HasPermission(perm basecoin.Actor) bool { return false } -func (c mockContext) GetPermissions(chain, app string) (res []basecoin.Actor) { +func (c naiveContext) GetPermissions(chain, app string) (res []basecoin.Actor) { for _, p := range c.perms { if chain == p.ChainID { if app == "" || app == p.App { @@ -56,15 +71,21 @@ func (c mockContext) GetPermissions(chain, app string) (res []basecoin.Actor) { } // IsParent ensures that this is derived from the given secureClient -func (c mockContext) IsParent(other basecoin.Context) bool { - _, ok := other.(mockContext) - return ok +func (c naiveContext) IsParent(other basecoin.Context) bool { + nc, ok := other.(naiveContext) + if !ok { + return false + } + return c.id == nc.id } // Reset should clear out all permissions, // but carry on knowledge that this is a child -func (c mockContext) Reset() basecoin.Context { - return mockContext{ +func (c naiveContext) Reset() basecoin.Context { + return naiveContext{ + id: c.id, + chain: c.chain, + height: c.height, Logger: c.Logger, } } diff --git a/stack/recovery_test.go b/stack/recovery_test.go index 15eafba7b9..944d0db315 100644 --- a/stack/recovery_test.go +++ b/stack/recovery_test.go @@ -17,7 +17,7 @@ func TestRecovery(t *testing.T) { assert := assert.New(t) // generic args here... - ctx := NewContext("test-chain", log.NewNopLogger()) + ctx := NewContext("test-chain", 20, log.NewNopLogger()) store := state.NewMemKVStore() tx := basecoin.Tx{} From 765f52e402f5410db9d79914affdf8cda775154c Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Mon, 10 Jul 2017 12:36:30 +0200 Subject: [PATCH 4/5] Enforce the expiration height in Chain middleware --- errors/common.go | 8 ++++++++ modules/base/chain.go | 9 ++++++--- modules/base/chain_test.go | 14 ++++++++++++-- modules/base/tx.go | 7 ++++++- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/errors/common.go b/errors/common.go index d38f6b07b6..c04af0984d 100644 --- a/errors/common.go +++ b/errors/common.go @@ -21,6 +21,7 @@ var ( errUnknownTxType = fmt.Errorf("Tx type unknown") errInvalidFormat = fmt.Errorf("Invalid format") errUnknownModule = fmt.Errorf("Unknown module") + errExpired = fmt.Errorf("Tx expired") ) // some crazy reflection to unwrap any generated struct. @@ -130,3 +131,10 @@ func ErrTooLarge() TMError { func IsTooLargeErr(err error) bool { return IsSameError(errTooLarge, err) } + +func ErrExpired() TMError { + return WithCode(errExpired, abci.CodeType_Unauthorized) +} +func IsExpiredErr(err error) bool { + return IsSameError(errExpired, err) +} diff --git a/modules/base/chain.go b/modules/base/chain.go index 4fc7071e63..8161f8802e 100644 --- a/modules/base/chain.go +++ b/modules/base/chain.go @@ -26,7 +26,7 @@ var _ stack.Middleware = Chain{} // CheckTx makes sure we are on the proper chain - fulfills Middlware interface func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) { - stx, err := c.checkChain(ctx.ChainID(), tx) + stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx) if err != nil { return res, err } @@ -35,7 +35,7 @@ func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx // DeliverTx makes sure we are on the proper chain - fulfills Middlware interface func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) { - stx, err := c.checkChain(ctx.ChainID(), tx) + stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx) if err != nil { return res, err } @@ -43,7 +43,7 @@ func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin. } // checkChain makes sure the tx is a Chain Tx and is on the proper chain -func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) { +func (c Chain) checkChain(chainID string, height uint64, tx basecoin.Tx) (basecoin.Tx, error) { // make sure it is a chaintx ctx, ok := tx.Unwrap().(ChainTx) if !ok { @@ -60,5 +60,8 @@ func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) { if ctx.ChainID != chainID { return tx, errors.ErrWrongChain(ctx.ChainID) } + if ctx.ExpiresAt != 0 && ctx.ExpiresAt <= height { + return tx, errors.ErrExpired() + } return ctx.Tx, nil } diff --git a/modules/base/chain_test.go b/modules/base/chain_test.go index 21f48dc6fb..8da281cddd 100644 --- a/modules/base/chain_test.go +++ b/modules/base/chain_test.go @@ -48,6 +48,7 @@ func TestChain(t *testing.T) { assert := assert.New(t) msg := "got it" chainID := "my-chain" + height := uint64(100) raw := stack.NewRawTx([]byte{1, 2, 3, 4}) cases := []struct { @@ -55,13 +56,22 @@ func TestChain(t *testing.T) { valid bool errorMsg string }{ + // check the chain ids are validated {NewChainTx(chainID, 0, raw), true, ""}, - {NewChainTx("someone-else", 0, raw), false, "someone-else"}, + // non-matching chainid, or impossible chain id + {NewChainTx("someone-else", 0, raw), false, "someone-else: Wrong chain"}, + {NewChainTx("Inval$$d:CH%%n", 0, raw), false, "Wrong chain"}, + // Wrong tx type {raw, false, "No chain id provided"}, + // Check different heights - must be 0 or higher than current height + {NewChainTx(chainID, height+1, raw), true, ""}, + {NewChainTx(chainID, height, raw), false, "Tx expired"}, + {NewChainTx(chainID, 1, raw), false, "expired"}, + {NewChainTx(chainID, 0, raw), true, ""}, } // generic args here... - ctx := stack.NewContext(chainID, 100, log.NewNopLogger()) + ctx := stack.NewContext(chainID, height, log.NewNopLogger()) store := state.NewMemKVStore() // build the stack diff --git a/modules/base/tx.go b/modules/base/tx.go index 30320bedcf..819811ef6c 100644 --- a/modules/base/tx.go +++ b/modules/base/tx.go @@ -69,7 +69,12 @@ var ( //nolint - TxInner Functions func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx { - return (ChainTx{Tx: tx, ChainID: chainID}).Wrap() + c := ChainTx{ + ChainID: chainID, + ExpiresAt: expires, + Tx: tx, + } + return c.Wrap() } func (c ChainTx) Wrap() basecoin.Tx { return basecoin.Tx{c} From 64f2c63e21f48683ab03ee538924bd03e057c0c8 Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Tue, 11 Jul 2017 13:44:44 +0200 Subject: [PATCH 5/5] Fixes as per Rigels comments on PR --- cmd/basecli/commands/cmds.go | 9 +++++---- modules/base/chain.go | 9 +++++---- modules/base/tx.go | 11 ++++++++--- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/cmd/basecli/commands/cmds.go b/cmd/basecli/commands/cmds.go index e5fff36342..4c5a90640e 100644 --- a/cmd/basecli/commands/cmds.go +++ b/cmd/basecli/commands/cmds.go @@ -44,7 +44,7 @@ func init() { flags.String(FlagAmount, "", "Coins to send in the format ,...") flags.String(FlagFee, "0mycoin", "Coins for the transaction fee of the format ") flags.Uint64(FlagGas, 0, "Amount of gas for this transaction") - flags.Int64(FlagExpires, 0, "Block height at which this tx expires") + flags.Uint64(FlagExpires, 0, "Block height at which this tx expires") flags.Int(FlagSequence, -1, "Sequence number for this transaction") } @@ -86,10 +86,11 @@ func doSendTx(cmd *cobra.Command, args []string) error { // WrapChainTx will wrap the tx with a ChainTx from the standard flags func WrapChainTx(tx basecoin.Tx) (res basecoin.Tx, err error) { expires := viper.GetInt64(FlagExpires) - if expires < 0 { - return res, errors.New("expires must be >= 0") + chain := commands.GetChainID() + if chain == "" { + return res, errors.New("No chain-id provided") } - res = base.NewChainTx(commands.GetChainID(), uint64(expires), tx) + res = base.NewChainTx(chain, uint64(expires), tx) return res, nil } diff --git a/modules/base/chain.go b/modules/base/chain.go index 8161f8802e..65513679de 100644 --- a/modules/base/chain.go +++ b/modules/base/chain.go @@ -26,7 +26,7 @@ var _ stack.Middleware = Chain{} // CheckTx makes sure we are on the proper chain - fulfills Middlware interface func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) { - stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx) + stx, err := c.checkChainTx(ctx.ChainID(), ctx.BlockHeight(), tx) if err != nil { return res, err } @@ -35,15 +35,16 @@ func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx // DeliverTx makes sure we are on the proper chain - fulfills Middlware interface func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) { - stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx) + stx, err := c.checkChainTx(ctx.ChainID(), ctx.BlockHeight(), tx) if err != nil { return res, err } return next.DeliverTx(ctx, store, stx) } -// checkChain makes sure the tx is a Chain Tx and is on the proper chain -func (c Chain) checkChain(chainID string, height uint64, tx basecoin.Tx) (basecoin.Tx, error) { +// checkChainTx makes sure the tx is a Chain Tx, it is on the proper chain, +// and it has not expired. +func (c Chain) checkChainTx(chainID string, height uint64, tx basecoin.Tx) (basecoin.Tx, error) { // make sure it is a chaintx ctx, ok := tx.Unwrap().(ChainTx) if !ok { diff --git a/modules/base/tx.go b/modules/base/tx.go index 819811ef6c..fd66fad884 100644 --- a/modules/base/tx.go +++ b/modules/base/tx.go @@ -56,8 +56,10 @@ func (mt MultiTx) ValidateBasic() error { // ChainTx locks this tx to one chainTx, wrap with this before signing type ChainTx struct { - ChainID string `json:"chain_id"` // name of chain, must be [A-Za-z0-9_-]+ - ExpiresAt uint64 `json:"expires_at"` // block height at which it is no longer valid + // name of chain, must be [A-Za-z0-9_-]+ + ChainID string `json:"chain_id"` + // block height at which it is no longer valid, 0 means no expiration + ExpiresAt uint64 `json:"expires_at"` Tx basecoin.Tx `json:"tx"` } @@ -67,7 +69,8 @@ var ( chainPattern = regexp.MustCompile("^[A-Za-z0-9_-]+$") ) -//nolint - TxInner Functions +// NewChainTx wraps a particular tx with the ChainTx wrapper, +// to enforce chain and height func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx { c := ChainTx{ ChainID: chainID, @@ -76,6 +79,8 @@ func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx { } return c.Wrap() } + +//nolint - TxInner Functions func (c ChainTx) Wrap() basecoin.Tx { return basecoin.Tx{c} }