mpool: Make tests pass

This commit is contained in:
Łukasz Magiera 2019-12-03 22:09:39 +01:00
parent 569bcce878
commit d79f1c180d
2 changed files with 8 additions and 4 deletions

View File

@ -341,7 +341,7 @@ func (mp *MessagePool) addLocked(m *types.SignedMessage) error {
func (mp *MessagePool) GetNonce(addr address.Address) (uint64, error) { func (mp *MessagePool) GetNonce(addr address.Address) (uint64, error) {
mp.curTsLk.Lock() mp.curTsLk.Lock()
defer mp.curTsLk.Lock() defer mp.curTsLk.Unlock()
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()
@ -415,7 +415,7 @@ func (mp *MessagePool) getStateBalance(addr address.Address) (types.BigInt, erro
func (mp *MessagePool) PushWithNonce(addr address.Address, cb func(uint64) (*types.SignedMessage, error)) (*types.SignedMessage, error) { func (mp *MessagePool) PushWithNonce(addr address.Address, cb func(uint64) (*types.SignedMessage, error)) (*types.SignedMessage, error) {
mp.curTsLk.Lock() mp.curTsLk.Lock()
defer mp.curTsLk.Lock() defer mp.curTsLk.Unlock()
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()
@ -534,6 +534,8 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet)
return err return err
} }
mp.curTs = pts
for _, msg := range msgs { for _, msg := range msgs {
if err := mp.addTs(msg, pts); err != nil { if err := mp.addTs(msg, pts); err != nil {
log.Error(err) // TODO: probably lots of spam in multi-block tsets log.Error(err) // TODO: probably lots of spam in multi-block tsets

View File

@ -52,8 +52,9 @@ func (tma *testMpoolApi) setBlockMessages(h *types.BlockHeader, msgs ...*types.S
tma.tipsets = append(tma.tipsets, mock.TipSet(h)) tma.tipsets = append(tma.tipsets, mock.TipSet(h))
} }
func (tma *testMpoolApi) SubscribeHeadChanges(cb func(rev, app []*types.TipSet) error) { func (tma *testMpoolApi) SubscribeHeadChanges(cb func(rev, app []*types.TipSet) error) *types.TipSet {
tma.cb = cb tma.cb = cb
return nil
} }
func (tma *testMpoolApi) PutMessage(m store.ChainMsg) (cid.Cid, error) { func (tma *testMpoolApi) PutMessage(m store.ChainMsg) (cid.Cid, error) {
@ -216,7 +217,8 @@ func TestRevertMessages(t *testing.T) {
assertNonce(t, mp, sender, 4) assertNonce(t, mp, sender, 4)
if len(mp.Pending()) != 3 { p, _ := mp.Pending()
if len(p) != 3 {
t.Fatal("expected three messages in mempool") t.Fatal("expected three messages in mempool")
} }