diff --git a/.pending/improvements/sdk/4706-Simplify-context b/.pending/improvements/sdk/4706-Simplify-context new file mode 100644 index 0000000000..d75da4affe --- /dev/null +++ b/.pending/improvements/sdk/4706-Simplify-context @@ -0,0 +1,11 @@ +4706 - Simplify context + +Replace complex Context construct with a simpler immutible struct. +Only breaking change is not to support `Value` and `GetValue` as first class calls. +We do embed ctx.Context() as a raw context.Context instead to be used as you see fit. + +Migration guide: +`ctx = ctx.WithValue(contextKeyBadProposal, false)` -> +`ctx = ctx.WithContext(context.WithValue(ctx.Context(), contextKeyBadProposal, false))` +A bit more verbose, but also allows `context.WithTimeout()`, etc and only used +in one function in this repo, in test code. diff --git a/types/context.go b/types/context.go index 364d56d3dd..ef2a2e7530 100644 --- a/types/context.go +++ b/types/context.go @@ -1,13 +1,10 @@ -// nolint package types import ( "context" - "sync" "time" - "github.com/golang/protobuf/proto" - + "github.com/gogo/protobuf/proto" abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/libs/log" @@ -16,172 +13,95 @@ import ( ) /* -The intent of Context is for it to be an immutable object that can be -cloned and updated cheaply with WithValue() and passed forward to the -next decorator or handler. For example, +Context is an immutable object contains all information needed to +process a request. - func MsgHandler(ctx Context, tx Tx) Result { - ... - ctx = ctx.WithValue(key, value) - ... - } +It contains a context.Context object inside if you want to use that, +but please do not over-use it. We try to keep all data structured +and standard additions here would be better just to add to the Context struct */ type Context struct { - context.Context - pst *thePast - gen int - // Don't add any other fields here, - // it's probably not what you want to do. + ctx context.Context + ms MultiStore + header abci.Header + chainID string + txBytes []byte + logger log.Logger + voteInfo []abci.VoteInfo + gasMeter GasMeter + blockGasMeter GasMeter + checkTx bool + minGasPrice DecCoins + consParams *abci.ConsensusParams + eventManager *EventManager +} + +// Proposed rename, not done to avoid API breakage +type Request = Context + +// Read-only accessors +func (c Context) Context() context.Context { return c.ctx } +func (c Context) MultiStore() MultiStore { return c.ms } +func (c Context) BlockHeight() int64 { return c.header.Height } +func (c Context) BlockTime() time.Time { return c.header.Time } +func (c Context) ChainID() string { return c.chainID } +func (c Context) TxBytes() []byte { return c.txBytes } +func (c Context) Logger() log.Logger { return c.logger } +func (c Context) VoteInfos() []abci.VoteInfo { return c.voteInfo } +func (c Context) GasMeter() GasMeter { return c.gasMeter } +func (c Context) BlockGasMeter() GasMeter { return c.blockGasMeter } +func (c Context) IsCheckTx() bool { return c.checkTx } +func (c Context) MinGasPrices() DecCoins { return c.minGasPrice } +func (c Context) EventManager() *EventManager { return c.eventManager } + +// clone the header before returning +func (c Context) BlockHeader() abci.Header { + var msg = proto.Clone(&c.header).(*abci.Header) + return *msg +} + +func (c Context) ConsensusParams() *abci.ConsensusParams { + return proto.Clone(c.consParams).(*abci.ConsensusParams) } // create a new context func NewContext(ms MultiStore, header abci.Header, isCheckTx bool, logger log.Logger) Context { - c := Context{ - Context: context.Background(), - pst: newThePast(), - gen: 0, + // https://github.com/gogo/protobuf/issues/519 + header.Time = header.Time.UTC() + return Context{ + ctx: context.Background(), + ms: ms, + header: header, + chainID: header.ChainID, + checkTx: isCheckTx, + logger: logger, + gasMeter: stypes.NewInfiniteGasMeter(), + minGasPrice: DecCoins{}, + eventManager: NewEventManager(), } +} - c = c.WithMultiStore(ms) - c = c.WithBlockHeader(header) - c = c.WithChainID(header.ChainID) - c = c.WithIsCheckTx(isCheckTx) - c = c.WithTxBytes(nil) - c = c.WithLogger(logger) - c = c.WithVoteInfos(nil) - c = c.WithGasMeter(stypes.NewInfiniteGasMeter()) - c = c.WithMinGasPrices(DecCoins{}) - c = c.WithConsensusParams(nil) - c = c.WithEventManager(NewEventManager()) - +func (c Context) WithContext(ctx context.Context) Context { + c.ctx = ctx return c } -// is context nil -func (c Context) IsZero() bool { - return c.Context == nil -} - -// ---------------------------------------------------------------------------- -// Getters -// ---------------------------------------------------------------------------- - -type contextKey int // local to the context module - -const ( - contextKeyMultiStore contextKey = iota - contextKeyBlockHeader - contextKeyChainID - contextKeyIsCheckTx - contextKeyTxBytes - contextKeyLogger - contextKeyVoteInfos - contextKeyGasMeter - contextKeyBlockGasMeter - contextKeyMinGasPrices - contextKeyConsensusParams - contextKeyEventManager -) - -// context value for the provided key -func (c Context) Value(key interface{}) interface{} { - value := c.Context.Value(key) - if cloner, ok := value.(cloner); ok { - return cloner.Clone() - } - if message, ok := value.(proto.Message); ok { - return proto.Clone(message) - } - return value -} - -func (c Context) MultiStore() MultiStore { - return c.Value(contextKeyMultiStore).(MultiStore) -} - -func (c Context) BlockHeader() abci.Header { return c.Value(contextKeyBlockHeader).(abci.Header) } - -func (c Context) BlockHeight() int64 { return c.BlockHeader().Height } - -func (c Context) ChainID() string { return c.Value(contextKeyChainID).(string) } - -func (c Context) TxBytes() []byte { return c.Value(contextKeyTxBytes).([]byte) } - -func (c Context) Logger() log.Logger { return c.Value(contextKeyLogger).(log.Logger) } - -func (c Context) VoteInfos() []abci.VoteInfo { - return c.Value(contextKeyVoteInfos).([]abci.VoteInfo) -} - -func (c Context) GasMeter() GasMeter { return c.Value(contextKeyGasMeter).(GasMeter) } - -func (c Context) BlockGasMeter() GasMeter { return c.Value(contextKeyBlockGasMeter).(GasMeter) } - -func (c Context) IsCheckTx() bool { return c.Value(contextKeyIsCheckTx).(bool) } - -func (c Context) MinGasPrices() DecCoins { return c.Value(contextKeyMinGasPrices).(DecCoins) } - -func (c Context) ConsensusParams() *abci.ConsensusParams { - return c.Value(contextKeyConsensusParams).(*abci.ConsensusParams) -} - -func (c Context) EventManager() *EventManager { return c.Value(contextKeyEventManager).(*EventManager) } - -// ---------------------------------------------------------------------------- -// Setters -// ---------------------------------------------------------------------------- - -func (c Context) WithValue(key interface{}, value interface{}) Context { - return c.withValue(key, value) -} -func (c Context) WithCloner(key interface{}, value cloner) Context { - return c.withValue(key, value) -} -func (c Context) WithCacheWrapper(key interface{}, value CacheWrapper) Context { - return c.withValue(key, value) -} -func (c Context) WithProtoMsg(key interface{}, value proto.Message) Context { - return c.withValue(key, value) -} -func (c Context) WithString(key interface{}, value string) Context { - return c.withValue(key, value) -} -func (c Context) WithInt32(key interface{}, value int32) Context { - return c.withValue(key, value) -} -func (c Context) WithUint32(key interface{}, value uint32) Context { - return c.withValue(key, value) -} -func (c Context) WithUint64(key interface{}, value uint64) Context { - return c.withValue(key, value) -} - -func (c Context) withValue(key interface{}, value interface{}) Context { - c.pst.bump(Op{ - gen: c.gen + 1, - key: key, - value: value, - }) // increment version for all relatives. - - return Context{ - Context: context.WithValue(c.Context, key, value), - pst: c.pst, - gen: c.gen + 1, - } -} - func (c Context) WithMultiStore(ms MultiStore) Context { - return c.withValue(contextKeyMultiStore, ms) + c.ms = ms + return c } func (c Context) WithBlockHeader(header abci.Header) Context { - var _ proto.Message = &header // for cloning. - return c.withValue(contextKeyBlockHeader, header) + // https://github.com/gogo/protobuf/issues/519 + header.Time = header.Time.UTC() + c.header = header + return c } func (c Context) WithBlockTime(newTime time.Time) Context { newHeader := c.BlockHeader() - newHeader.Time = newTime + // https://github.com/gogo/protobuf/issues/519 + newHeader.Time = newTime.UTC() return c.WithBlockHeader(newHeader) } @@ -197,36 +117,78 @@ func (c Context) WithBlockHeight(height int64) Context { return c.WithBlockHeader(newHeader) } -func (c Context) WithChainID(chainID string) Context { return c.withValue(contextKeyChainID, chainID) } - -func (c Context) WithTxBytes(txBytes []byte) Context { return c.withValue(contextKeyTxBytes, txBytes) } - -func (c Context) WithLogger(logger log.Logger) Context { return c.withValue(contextKeyLogger, logger) } - -func (c Context) WithVoteInfos(VoteInfos []abci.VoteInfo) Context { - return c.withValue(contextKeyVoteInfos, VoteInfos) +func (c Context) WithChainID(chainID string) Context { + c.chainID = chainID + return c } -func (c Context) WithGasMeter(meter GasMeter) Context { return c.withValue(contextKeyGasMeter, meter) } +func (c Context) WithTxBytes(txBytes []byte) Context { + c.txBytes = txBytes + return c +} + +func (c Context) WithLogger(logger log.Logger) Context { + c.logger = logger + return c +} + +func (c Context) WithVoteInfos(voteInfo []abci.VoteInfo) Context { + c.voteInfo = voteInfo + return c +} + +func (c Context) WithGasMeter(meter GasMeter) Context { + c.gasMeter = meter + return c +} func (c Context) WithBlockGasMeter(meter GasMeter) Context { - return c.withValue(contextKeyBlockGasMeter, meter) + c.blockGasMeter = meter + return c } func (c Context) WithIsCheckTx(isCheckTx bool) Context { - return c.withValue(contextKeyIsCheckTx, isCheckTx) + c.checkTx = isCheckTx + return c } func (c Context) WithMinGasPrices(gasPrices DecCoins) Context { - return c.withValue(contextKeyMinGasPrices, gasPrices) + c.minGasPrice = gasPrices + return c } func (c Context) WithConsensusParams(params *abci.ConsensusParams) Context { - return c.withValue(contextKeyConsensusParams, params) + c.consParams = params + return c } func (c Context) WithEventManager(em *EventManager) Context { - return c.WithValue(contextKeyEventManager, em) + c.eventManager = em + return c +} + +// TODO: remove??? +func (c Context) IsZero() bool { + return c.ms == nil +} + +// WithValue is deprecated, provided for backwards compatibility +// Please use +// ctx = ctx.WithContext(context.WithValue(ctx.Context(), key, false)) +// instead of +// ctx = ctx.WithValue(key, false) +func (c Context) WithValue(key, value interface{}) Context { + c.ctx = context.WithValue(c.ctx, key, value) + return c +} + +// Value is deprecated, provided for backwards compatibility +// Please use +// ctx.Context().Value(key) +// instead of +// ctx.Value(key) +func (c Context) Value(key interface{}) interface{} { + return c.ctx.Value(key) } // ---------------------------------------------------------------------------- @@ -250,65 +212,3 @@ func (c Context) CacheContext() (cc Context, writeCache func()) { cc = c.WithMultiStore(cms) return cc, cms.Write } - -//---------------------------------------- -// thePast - -// Returns false if ver <= 0 || ver > len(c.pst.ops). -// The first operation is version 1. -func (c Context) GetOp(ver int64) (Op, bool) { - return c.pst.getOp(ver) -} - -//---------------------------------------- -// Misc. - -type cloner interface { - Clone() interface{} // deep copy -} - -// TODO add description -type Op struct { - // type is always 'with' - gen int - key interface{} - value interface{} -} - -type thePast struct { - mtx sync.RWMutex - ver int - ops []Op -} - -func newThePast() *thePast { - return &thePast{ - ver: 0, - ops: nil, - } -} - -func (pst *thePast) bump(op Op) { - pst.mtx.Lock() - pst.ver++ - pst.ops = append(pst.ops, op) - pst.mtx.Unlock() -} - -func (pst *thePast) version() int { - pst.mtx.RLock() - defer pst.mtx.RUnlock() - return pst.ver -} - -// Returns false if ver <= 0 || ver > len(pst.ops). -// The first operation is version 1. -func (pst *thePast) getOp(ver int64) (Op, bool) { - pst.mtx.RLock() - defer pst.mtx.RUnlock() - l := int64(len(pst.ops)) - if l < ver || ver <= 0 { - return Op{}, false - } - return pst.ops[ver-1], true -} diff --git a/types/context_test.go b/types/context_test.go index bfa66e0745..0c34f061f8 100644 --- a/types/context_test.go +++ b/types/context_test.go @@ -44,18 +44,6 @@ func (l MockLogger) With(kvs ...interface{}) log.Logger { panic("not implemented") } -func TestContextGetOpShouldNeverPanic(t *testing.T) { - var ms types.MultiStore - ctx := types.NewContext(ms, abci.Header{}, false, log.NewNopLogger()) - indices := []int64{ - -10, 1, 0, 10, 20, - } - - for _, index := range indices { - _, _ = ctx.GetOp(index) - } -} - func defaultContext(key types.StoreKey) types.Context { db := dbm.NewMemDB() cms := store.NewCommitMultiStore(db) @@ -109,55 +97,11 @@ func (d dummy) Clone() interface{} { return d } -// Testing saving/loading primitive values to/from the context -func TestContextWithPrimitive(t *testing.T) { - ctx := types.NewContext(nil, abci.Header{}, false, log.NewNopLogger()) - - clonerkey := "cloner" - stringkey := "string" - int32key := "int32" - uint32key := "uint32" - uint64key := "uint64" - - keys := []string{clonerkey, stringkey, int32key, uint32key, uint64key} - - for _, key := range keys { - require.Nil(t, ctx.Value(key)) - } - - clonerval := dummy(1) - stringval := "string" - int32val := int32(1) - uint32val := uint32(2) - uint64val := uint64(3) - - ctx = ctx. - WithCloner(clonerkey, clonerval). - WithString(stringkey, stringval). - WithInt32(int32key, int32val). - WithUint32(uint32key, uint32val). - WithUint64(uint64key, uint64val) - - require.Equal(t, clonerval, ctx.Value(clonerkey)) - require.Equal(t, stringval, ctx.Value(stringkey)) - require.Equal(t, int32val, ctx.Value(int32key)) - require.Equal(t, uint32val, ctx.Value(uint32key)) - require.Equal(t, uint64val, ctx.Value(uint64key)) -} - // Testing saving/loading sdk type values to/from the context func TestContextWithCustom(t *testing.T) { var ctx types.Context require.True(t, ctx.IsZero()) - require.Panics(t, func() { ctx.BlockHeader() }) - require.Panics(t, func() { ctx.BlockHeight() }) - require.Panics(t, func() { ctx.ChainID() }) - require.Panics(t, func() { ctx.TxBytes() }) - require.Panics(t, func() { ctx.Logger() }) - require.Panics(t, func() { ctx.VoteInfos() }) - require.Panics(t, func() { ctx.GasMeter() }) - header := abci.Header{} height := int64(1) chainid := "chainid" @@ -192,9 +136,6 @@ func TestContextWithCustom(t *testing.T) { func TestContextHeader(t *testing.T) { var ctx types.Context - require.Panics(t, func() { ctx.BlockHeader() }) - require.Panics(t, func() { ctx.BlockHeight() }) - height := int64(5) time := time.Now() addr := secp256k1.GenPrivKey().PubKey().Address() @@ -208,6 +149,61 @@ func TestContextHeader(t *testing.T) { WithProposer(proposer) require.Equal(t, height, ctx.BlockHeight()) require.Equal(t, height, ctx.BlockHeader().Height) - require.Equal(t, time, ctx.BlockHeader().Time) + require.Equal(t, time.UTC(), ctx.BlockHeader().Time) require.Equal(t, proposer.Bytes(), ctx.BlockHeader().ProposerAddress) } + +func TestContextHeaderClone(t *testing.T) { + cases := map[string]struct { + h abci.Header + }{ + "empty": { + h: abci.Header{}, + }, + "height": { + h: abci.Header{ + Height: 77, + }, + }, + "time": { + h: abci.Header{ + Time: time.Unix(12345677, 12345), + }, + }, + "zero time": { + h: abci.Header{ + Time: time.Unix(0, 0), + }, + }, + "many items": { + h: abci.Header{ + Height: 823, + Time: time.Unix(9999999999, 0), + ChainID: "silly-demo", + }, + }, + "many items with hash": { + h: abci.Header{ + Height: 823, + Time: time.Unix(9999999999, 0), + ChainID: "silly-demo", + AppHash: []byte{5, 34, 11, 3, 23}, + ConsensusHash: []byte{11, 3, 23, 87, 3, 1}, + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx := types.NewContext(nil, tc.h, false, nil) + require.Equal(t, tc.h.Height, ctx.BlockHeight()) + require.Equal(t, tc.h.Time.UTC(), ctx.BlockTime()) + + // update only changes one field + var newHeight int64 = 17 + ctx = ctx.WithBlockHeight(newHeight) + require.Equal(t, newHeight, ctx.BlockHeight()) + require.Equal(t, tc.h.Time.UTC(), ctx.BlockTime()) + }) + } +}