From 1f03a618f9abfcb107575e929d4a467f40391d5f Mon Sep 17 00:00:00 2001 From: Aayush Rajasekaran Date: Tue, 18 May 2021 14:56:42 -0400 Subject: [PATCH] Plumb contexts through --- chain/messagepool/messagepool.go | 103 ++++++++++------------ chain/messagepool/messagepool_test.go | 46 +++++----- chain/messagepool/pruning.go | 2 +- chain/messagepool/repub_test.go | 4 +- chain/messagepool/selection_test.go | 4 +- chain/messagesigner/messagesigner.go | 8 +- chain/messagesigner/messagesigner_test.go | 4 +- chain/sub/incoming.go | 2 +- node/impl/full/gas.go | 2 +- node/impl/full/mpool.go | 12 +-- node/modules/chain.go | 3 +- node/modules/mpoolnonceapi.go | 2 +- 12 files changed, 94 insertions(+), 98 deletions(-) diff --git a/chain/messagepool/messagepool.go b/chain/messagepool/messagepool.go index 299634c6f..0c8569a1f 100644 --- a/chain/messagepool/messagepool.go +++ b/chain/messagepool/messagepool.go @@ -331,7 +331,7 @@ func (ms *msgSet) getRequiredFunds(nonce uint64) types.BigInt { return types.BigInt{Int: requiredFunds} } -func New(api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) { +func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) { cache, _ := lru.New2Q(build.BlsSignatureCacheSize) verifcache, _ := lru.New2Q(build.VerifSigCacheSize) @@ -375,7 +375,7 @@ func New(api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journ // load the current tipset and subscribe to head changes _before_ loading local messages mp.curTs = api.SubscribeHeadChanges(func(rev, app []*types.TipSet) error { - err := mp.HeadChange(rev, app) + err := mp.HeadChange(ctx, rev, app) if err != nil { log.Errorf("mpool head notif handler error: %+v", err) } @@ -386,7 +386,7 @@ func New(api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journ mp.lk.Lock() go func() { - err := mp.loadLocal() + err := mp.loadLocal(ctx) mp.lk.Unlock() mp.curTsLk.Unlock() @@ -444,9 +444,8 @@ func (mp *MessagePool) runLoop() { } } -func (mp *MessagePool) addLocal(m *types.SignedMessage) error { - // TODO: Is context.TODO() safe here? Idk how Go works. - sk, err := mp.api.StateAccountKey(context.TODO(), m.Message.From, mp.curTs) +func (mp *MessagePool) addLocal(ctx context.Context, m *types.SignedMessage) error { + sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs) if err != nil { log.Debugf("mpooladdlocal failed to resolve sender: %s", err) return err @@ -519,7 +518,7 @@ func (mp *MessagePool) verifyMsgBeforeAdd(m *types.SignedMessage, curTs *types.T return publish, nil } -func (mp *MessagePool) Push(m *types.SignedMessage) (cid.Cid, error) { +func (mp *MessagePool) Push(ctx context.Context, m *types.SignedMessage) (cid.Cid, error) { err := mp.checkMessage(m) if err != nil { return cid.Undef, err @@ -532,7 +531,7 @@ func (mp *MessagePool) Push(m *types.SignedMessage) (cid.Cid, error) { }() mp.curTsLk.Lock() - publish, err := mp.addTs(m, mp.curTs, true, false) + publish, err := mp.addTs(ctx, m, mp.curTs, true, false) if err != nil { mp.curTsLk.Unlock() return cid.Undef, err @@ -585,7 +584,7 @@ func (mp *MessagePool) checkMessage(m *types.SignedMessage) error { return nil } -func (mp *MessagePool) Add(m *types.SignedMessage) error { +func (mp *MessagePool) Add(ctx context.Context, m *types.SignedMessage) error { err := mp.checkMessage(m) if err != nil { return err @@ -600,7 +599,7 @@ func (mp *MessagePool) Add(m *types.SignedMessage) error { mp.curTsLk.Lock() defer mp.curTsLk.Unlock() - _, err = mp.addTs(m, mp.curTs, false, false) + _, err = mp.addTs(ctx, m, mp.curTs, false, false) return err } @@ -640,7 +639,7 @@ func (mp *MessagePool) VerifyMsgSig(m *types.SignedMessage) error { return nil } -func (mp *MessagePool) checkBalance(m *types.SignedMessage, curTs *types.TipSet) error { +func (mp *MessagePool) checkBalance(ctx context.Context, m *types.SignedMessage, curTs *types.TipSet) error { balance, err := mp.getStateBalance(m.Message.From, curTs) if err != nil { return xerrors.Errorf("failed to check sender balance: %s: %w", err, ErrSoftValidationFailure) @@ -654,8 +653,7 @@ func (mp *MessagePool) checkBalance(m *types.SignedMessage, curTs *types.TipSet) // add Value for soft failure check //requiredFunds = types.BigAdd(requiredFunds, m.Message.Value) - // TODO: Is context.TODO() safe here? Idk how Go works. - sk, err := mp.api.StateAccountKey(context.TODO(), m.Message.From, mp.curTs) + sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs) if err != nil { log.Debugf("mpoolcheckbalance failed to resolve sender: %s", err) return err @@ -675,7 +673,7 @@ func (mp *MessagePool) checkBalance(m *types.SignedMessage, curTs *types.TipSet) return nil } -func (mp *MessagePool) addTs(m *types.SignedMessage, curTs *types.TipSet, local, untrusted bool) (bool, error) { +func (mp *MessagePool) addTs(ctx context.Context, m *types.SignedMessage, curTs *types.TipSet, local, untrusted bool) (bool, error) { snonce, err := mp.getStateNonce(m.Message.From, curTs) if err != nil { return false, xerrors.Errorf("failed to look up actor state nonce: %s: %w", err, ErrSoftValidationFailure) @@ -693,17 +691,17 @@ func (mp *MessagePool) addTs(m *types.SignedMessage, curTs *types.TipSet, local, return false, err } - if err := mp.checkBalance(m, curTs); err != nil { + if err := mp.checkBalance(ctx, m, curTs); err != nil { return false, err } - err = mp.addLocked(m, !local, untrusted) + err = mp.addLocked(ctx, m, !local, untrusted) if err != nil { return false, err } if local { - err = mp.addLocal(m) + err = mp.addLocal(ctx, m) if err != nil { return false, xerrors.Errorf("error persisting local message: %w", err) } @@ -712,7 +710,7 @@ func (mp *MessagePool) addTs(m *types.SignedMessage, curTs *types.TipSet, local, return publish, nil } -func (mp *MessagePool) addLoaded(m *types.SignedMessage) error { +func (mp *MessagePool) addLoaded(ctx context.Context, m *types.SignedMessage) error { err := mp.checkMessage(m) if err != nil { return err @@ -738,21 +736,21 @@ func (mp *MessagePool) addLoaded(m *types.SignedMessage) error { return err } - if err := mp.checkBalance(m, curTs); err != nil { + if err := mp.checkBalance(ctx, m, curTs); err != nil { return err } - return mp.addLocked(m, false, false) + return mp.addLocked(ctx, m, false, false) } -func (mp *MessagePool) addSkipChecks(m *types.SignedMessage) error { +func (mp *MessagePool) addSkipChecks(ctx context.Context, m *types.SignedMessage) error { mp.lk.Lock() defer mp.lk.Unlock() - return mp.addLocked(m, false, false) + return mp.addLocked(ctx, m, false, false) } -func (mp *MessagePool) addLocked(m *types.SignedMessage, strict, untrusted bool) error { +func (mp *MessagePool) addLocked(ctx context.Context, m *types.SignedMessage, strict, untrusted bool) error { log.Debugf("mpooladd: %s %d", m.Message.From, m.Message.Nonce) if m.Signature.Type == crypto.SigTypeBLS { mp.blsSigCache.Add(m.Cid(), m.Signature) @@ -768,8 +766,7 @@ func (mp *MessagePool) addLocked(m *types.SignedMessage, strict, untrusted bool) return err } - // TODO: Is context.TODO() safe here? Idk how Go works. - sk, err := mp.api.StateAccountKey(context.TODO(), m.Message.From, mp.curTs) + sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs) if err != nil { log.Debugf("mpooladd failed to resolve sender: %s", err) return err @@ -818,24 +815,23 @@ func (mp *MessagePool) addLocked(m *types.SignedMessage, strict, untrusted bool) return nil } -func (mp *MessagePool) GetNonce(addr address.Address) (uint64, error) { +func (mp *MessagePool) GetNonce(ctx context.Context, addr address.Address) (uint64, error) { mp.curTsLk.Lock() defer mp.curTsLk.Unlock() mp.lk.Lock() defer mp.lk.Unlock() - return mp.getNonceLocked(addr, mp.curTs) + return mp.getNonceLocked(ctx, addr, mp.curTs) } -func (mp *MessagePool) getNonceLocked(addr address.Address, curTs *types.TipSet) (uint64, error) { +func (mp *MessagePool) getNonceLocked(ctx context.Context, addr address.Address, curTs *types.TipSet) (uint64, error) { stateNonce, err := mp.getStateNonce(addr, curTs) // sanity check if err != nil { return 0, err } - // TODO: Is context.TODO() safe here? Idk how Go works. - sk, err := mp.api.StateAccountKey(context.TODO(), addr, mp.curTs) + sk, err := mp.api.StateAccountKey(ctx, addr, mp.curTs) if err != nil { log.Debugf("mpoolgetnonce failed to resolve sender: %s", err) return 0, err @@ -878,7 +874,7 @@ func (mp *MessagePool) getStateBalance(addr address.Address, ts *types.TipSet) ( // - strict checks are enabled // - extra strict add checks are used when adding the messages to the msgSet // that means: no nonce gaps, at most 10 pending messages for the actor -func (mp *MessagePool) PushUntrusted(m *types.SignedMessage) (cid.Cid, error) { +func (mp *MessagePool) PushUntrusted(ctx context.Context, m *types.SignedMessage) (cid.Cid, error) { err := mp.checkMessage(m) if err != nil { return cid.Undef, err @@ -891,7 +887,7 @@ func (mp *MessagePool) PushUntrusted(m *types.SignedMessage) (cid.Cid, error) { }() mp.curTsLk.Lock() - publish, err := mp.addTs(m, mp.curTs, true, true) + publish, err := mp.addTs(ctx, m, mp.curTs, true, true) if err != nil { mp.curTsLk.Unlock() return cid.Undef, err @@ -913,16 +909,15 @@ func (mp *MessagePool) PushUntrusted(m *types.SignedMessage) (cid.Cid, error) { return m.Cid(), nil } -func (mp *MessagePool) Remove(from address.Address, nonce uint64, applied bool) { +func (mp *MessagePool) Remove(ctx context.Context, from address.Address, nonce uint64, applied bool) { mp.lk.Lock() defer mp.lk.Unlock() - mp.remove(from, nonce, applied) + mp.remove(ctx, from, nonce, applied) } -func (mp *MessagePool) remove(from address.Address, nonce uint64, applied bool) { - // TODO: Is context.TODO() safe here? Idk how Go works. - sk, err := mp.api.StateAccountKey(context.TODO(), from, mp.curTs) +func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce uint64, applied bool) { + sk, err := mp.api.StateAccountKey(ctx, from, mp.curTs) if err != nil { log.Debugf("mpoolremove failed to resolve sender: %s", err) return @@ -957,37 +952,36 @@ func (mp *MessagePool) remove(from address.Address, nonce uint64, applied bool) } } -func (mp *MessagePool) Pending() ([]*types.SignedMessage, *types.TipSet) { +func (mp *MessagePool) Pending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) { mp.curTsLk.Lock() defer mp.curTsLk.Unlock() mp.lk.Lock() defer mp.lk.Unlock() - return mp.allPending() + return mp.allPending(ctx) } -func (mp *MessagePool) allPending() ([]*types.SignedMessage, *types.TipSet) { +func (mp *MessagePool) allPending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) { out := make([]*types.SignedMessage, 0) for a := range mp.pending { - out = append(out, mp.pendingFor(a)...) + out = append(out, mp.pendingFor(ctx, a)...) } return out, mp.curTs } -func (mp *MessagePool) PendingFor(a address.Address) ([]*types.SignedMessage, *types.TipSet) { +func (mp *MessagePool) PendingFor(ctx context.Context, a address.Address) ([]*types.SignedMessage, *types.TipSet) { mp.curTsLk.Lock() defer mp.curTsLk.Unlock() mp.lk.Lock() defer mp.lk.Unlock() - return mp.pendingFor(a), mp.curTs + return mp.pendingFor(ctx, a), mp.curTs } -func (mp *MessagePool) pendingFor(a address.Address) []*types.SignedMessage { - // TODO: Is context.TODO() safe here? Idk how Go works. - sk, err := mp.api.StateAccountKey(context.TODO(), a, mp.curTs) +func (mp *MessagePool) pendingFor(ctx context.Context, a address.Address) []*types.SignedMessage { + sk, err := mp.api.StateAccountKey(ctx, a, mp.curTs) if err != nil { log.Debugf("mpoolpendingfor failed to resolve sender: %s", err) return nil @@ -1011,7 +1005,7 @@ func (mp *MessagePool) pendingFor(a address.Address) []*types.SignedMessage { return set } -func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet) error { +func (mp *MessagePool) HeadChange(ctx context.Context, revert []*types.TipSet, apply []*types.TipSet) error { mp.curTsLk.Lock() defer mp.curTsLk.Unlock() @@ -1028,7 +1022,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet) rm := func(from address.Address, nonce uint64) { s, ok := rmsgs[from] if !ok { - mp.Remove(from, nonce, true) + mp.Remove(ctx, from, nonce, true) return } @@ -1037,7 +1031,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet) return } - mp.Remove(from, nonce, true) + mp.Remove(ctx, from, nonce, true) } maybeRepub := func(cid cid.Cid) { @@ -1108,7 +1102,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet) for _, s := range rmsgs { for _, msg := range s { - if err := mp.addSkipChecks(msg); err != nil { + if err := mp.addSkipChecks(ctx, msg); err != nil { log.Errorf("Failed to readd message from reorg to mpool: %s", err) } } @@ -1116,7 +1110,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet) if len(revert) > 0 && futureDebug { mp.lk.Lock() - msgs, ts := mp.allPending() + msgs, ts := mp.allPending(ctx) mp.lk.Unlock() buckets := map[address.Address]*statBucket{} @@ -1323,7 +1317,7 @@ func (mp *MessagePool) Updates(ctx context.Context) (<-chan api.MpoolUpdate, err return out, nil } -func (mp *MessagePool) loadLocal() error { +func (mp *MessagePool) loadLocal(ctx context.Context) error { res, err := mp.localMsgs.Query(query.Query{}) if err != nil { return xerrors.Errorf("query local messages: %w", err) @@ -1339,7 +1333,7 @@ func (mp *MessagePool) loadLocal() error { return xerrors.Errorf("unmarshaling local message: %w", err) } - if err := mp.addLoaded(&sm); err != nil { + if err := mp.addLoaded(ctx, &sm); err != nil { if xerrors.Is(err, ErrNonceTooLow) { continue // todo: drop the message from local cache (if above certain confidence threshold) } @@ -1347,8 +1341,7 @@ func (mp *MessagePool) loadLocal() error { log.Errorf("adding local message: %+v", err) } - // TODO: Is context.TODO() safe here? Idk how Go works. - sk, err := mp.api.StateAccountKey(context.TODO(), sm.Message.From, mp.curTs) + sk, err := mp.api.StateAccountKey(ctx, sm.Message.From, mp.curTs) if err != nil { log.Debugf("mpoolloadLocal failed to resolve sender: %s", err) return err diff --git a/chain/messagepool/messagepool_test.go b/chain/messagepool/messagepool_test.go index e31df936c..3e5bad81f 100644 --- a/chain/messagepool/messagepool_test.go +++ b/chain/messagepool/messagepool_test.go @@ -199,7 +199,7 @@ func (tma *testMpoolAPI) ChainComputeBaseFee(ctx context.Context, ts *types.TipS func assertNonce(t *testing.T, mp *MessagePool, addr address.Address, val uint64) { t.Helper() - n, err := mp.GetNonce(addr) + n, err := mp.GetNonce(context.TODO(), addr) if err != nil { t.Fatal(err) } @@ -211,7 +211,7 @@ func assertNonce(t *testing.T, mp *MessagePool, addr address.Address, val uint64 func mustAdd(t *testing.T, mp *MessagePool, msg *types.SignedMessage) { t.Helper() - if err := mp.Add(msg); err != nil { + if err := mp.Add(context.TODO(), msg); err != nil { t.Fatal(err) } } @@ -226,7 +226,7 @@ func TestMessagePool(t *testing.T) { ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -267,7 +267,7 @@ func TestMessagePoolMessagesInEachBlock(t *testing.T) { ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -293,7 +293,7 @@ func TestMessagePoolMessagesInEachBlock(t *testing.T) { tma.applyBlock(t, a) tsa := mock.TipSet(a) - _, _ = mp.Pending() + _, _ = mp.Pending(context.TODO()) selm, _ := mp.SelectMessages(tsa, 1) if len(selm) == 0 { @@ -316,7 +316,7 @@ func TestRevertMessages(t *testing.T) { ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -355,7 +355,7 @@ func TestRevertMessages(t *testing.T) { assertNonce(t, mp, sender, 4) - p, _ := mp.Pending() + p, _ := mp.Pending(context.TODO()) fmt.Printf("%+v\n", p) if len(p) != 3 { t.Fatal("expected three messages in mempool") @@ -379,7 +379,7 @@ func TestPruningSimple(t *testing.T) { ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -396,14 +396,14 @@ func TestPruningSimple(t *testing.T) { for i := 0; i < 5; i++ { smsg := mock.MkMessage(sender, target, uint64(i), w) - if err := mp.Add(smsg); err != nil { + if err := mp.Add(context.TODO(), smsg); err != nil { t.Fatal(err) } } for i := 10; i < 50; i++ { smsg := mock.MkMessage(sender, target, uint64(i), w) - if err := mp.Add(smsg); err != nil { + if err := mp.Add(context.TODO(), smsg); err != nil { t.Fatal(err) } } @@ -413,7 +413,7 @@ func TestPruningSimple(t *testing.T) { mp.Prune() - msgs, _ := mp.Pending() + msgs, _ := mp.Pending(context.TODO()) if len(msgs) != 5 { t.Fatal("expected only 5 messages in pool, got: ", len(msgs)) } @@ -423,7 +423,7 @@ func TestLoadLocal(t *testing.T) { tma := newTestMpoolAPI() ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -455,7 +455,7 @@ func TestLoadLocal(t *testing.T) { msgs := make(map[cid.Cid]struct{}) for i := 0; i < 10; i++ { m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) - cid, err := mp.Push(m) + cid, err := mp.Push(context.TODO(), m) if err != nil { t.Fatal(err) } @@ -466,12 +466,12 @@ func TestLoadLocal(t *testing.T) { t.Fatal(err) } - mp, err = New(tma, ds, "mptest", nil) + mp, err = New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } - pmsgs, _ := mp.Pending() + pmsgs, _ := mp.Pending(context.TODO()) if len(msgs) != len(pmsgs) { t.Fatalf("expected %d messages, but got %d", len(msgs), len(pmsgs)) } @@ -495,7 +495,7 @@ func TestClearAll(t *testing.T) { tma := newTestMpoolAPI() ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -526,7 +526,7 @@ func TestClearAll(t *testing.T) { gasLimit := gasguess.Costs[gasguess.CostKey{Code: builtin2.StorageMarketActorCodeID, M: 2}] for i := 0; i < 10; i++ { m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) - _, err := mp.Push(m) + _, err := mp.Push(context.TODO(), m) if err != nil { t.Fatal(err) } @@ -539,7 +539,7 @@ func TestClearAll(t *testing.T) { mp.Clear(true) - pending, _ := mp.Pending() + pending, _ := mp.Pending(context.TODO()) if len(pending) > 0 { t.Fatalf("cleared the mpool, but got %d pending messages", len(pending)) } @@ -549,7 +549,7 @@ func TestClearNonLocal(t *testing.T) { tma := newTestMpoolAPI() ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -581,7 +581,7 @@ func TestClearNonLocal(t *testing.T) { gasLimit := gasguess.Costs[gasguess.CostKey{Code: builtin2.StorageMarketActorCodeID, M: 2}] for i := 0; i < 10; i++ { m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) - _, err := mp.Push(m) + _, err := mp.Push(context.TODO(), m) if err != nil { t.Fatal(err) } @@ -594,7 +594,7 @@ func TestClearNonLocal(t *testing.T) { mp.Clear(false) - pending, _ := mp.Pending() + pending, _ := mp.Pending(context.TODO()) if len(pending) != 10 { t.Fatalf("expected 10 pending messages, but got %d instead", len(pending)) } @@ -610,7 +610,7 @@ func TestUpdates(t *testing.T) { tma := newTestMpoolAPI() ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -651,7 +651,7 @@ func TestUpdates(t *testing.T) { for i := 0; i < 10; i++ { m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) - _, err := mp.Push(m) + _, err := mp.Push(context.TODO(), m) if err != nil { t.Fatal(err) } diff --git a/chain/messagepool/pruning.go b/chain/messagepool/pruning.go index dc1c69417..ad8f38c50 100644 --- a/chain/messagepool/pruning.go +++ b/chain/messagepool/pruning.go @@ -108,7 +108,7 @@ keepLoop: // and remove all messages that are still in pruneMsgs after processing the chains log.Infof("Pruning %d messages", len(pruneMsgs)) for _, m := range pruneMsgs { - mp.remove(m.Message.From, m.Message.Nonce, false) + mp.remove(ctx, m.Message.From, m.Message.Nonce, false) } return nil diff --git a/chain/messagepool/repub_test.go b/chain/messagepool/repub_test.go index 8da64f974..70e457aaa 100644 --- a/chain/messagepool/repub_test.go +++ b/chain/messagepool/repub_test.go @@ -24,7 +24,7 @@ func TestRepubMessages(t *testing.T) { tma := newTestMpoolAPI() ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "mptest", nil) + mp, err := New(context.TODO(), tma, ds, "mptest", nil) if err != nil { t.Fatal(err) } @@ -56,7 +56,7 @@ func TestRepubMessages(t *testing.T) { for i := 0; i < 10; i++ { m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) - _, err := mp.Push(m) + _, err := mp.Push(context.TODO(), m) if err != nil { t.Fatal(err) } diff --git a/chain/messagepool/selection_test.go b/chain/messagepool/selection_test.go index e32d897c4..f254c6706 100644 --- a/chain/messagepool/selection_test.go +++ b/chain/messagepool/selection_test.go @@ -60,7 +60,7 @@ func makeTestMessage(w *wallet.LocalWallet, from, to address.Address, nonce uint func makeTestMpool() (*MessagePool, *testMpoolAPI) { tma := newTestMpoolAPI() ds := datastore.NewMapDatastore() - mp, err := New(tma, ds, "test", nil) + mp, err := New(context.TODO(), tma, ds, "test", nil) if err != nil { panic(err) } @@ -464,7 +464,7 @@ func TestBasicMessageSelection(t *testing.T) { tma.applyBlock(t, block2) // we should have no pending messages in the mpool - pend, _ := mp.Pending() + pend, _ := mp.Pending(context.TODO()) if len(pend) != 0 { t.Fatalf("expected no pending messages, but got %d", len(pend)) } diff --git a/chain/messagesigner/messagesigner.go b/chain/messagesigner/messagesigner.go index ce9d01b3a..c64d00003 100644 --- a/chain/messagesigner/messagesigner.go +++ b/chain/messagesigner/messagesigner.go @@ -23,7 +23,7 @@ const dsKeyActorNonce = "ActorNextNonce" var log = logging.Logger("messagesigner") type MpoolNonceAPI interface { - GetNonce(address.Address) (uint64, error) + GetNonce(context.Context, address.Address) (uint64, error) } // MessageSigner keeps track of nonces per address, and increments the nonce @@ -51,7 +51,7 @@ func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message, cb defer ms.lk.Unlock() // Get the next message nonce - nonce, err := ms.nextNonce(msg.From) + nonce, err := ms.nextNonce(ctx, msg.From) if err != nil { return nil, xerrors.Errorf("failed to create nonce: %w", err) } @@ -92,12 +92,12 @@ func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message, cb // nextNonce gets the next nonce for the given address. // If there is no nonce in the datastore, gets the nonce from the message pool. -func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { +func (ms *MessageSigner) nextNonce(ctx context.Context, addr address.Address) (uint64, error) { // Nonces used to be created by the mempool and we need to support nodes // that have mempool nonces, so first check the mempool for a nonce for // this address. Note that the mempool returns the actor state's nonce // by default. - nonce, err := ms.mpool.GetNonce(addr) + nonce, err := ms.mpool.GetNonce(ctx, addr) if err != nil { return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err) } diff --git a/chain/messagesigner/messagesigner_test.go b/chain/messagesigner/messagesigner_test.go index 5eebd36da..8206b11c0 100644 --- a/chain/messagesigner/messagesigner_test.go +++ b/chain/messagesigner/messagesigner_test.go @@ -24,6 +24,8 @@ type mockMpool struct { nonces map[address.Address]uint64 } +var _ MpoolNonceAPI = (*mockMpool)(nil) + func newMockMpool() *mockMpool { return &mockMpool{nonces: make(map[address.Address]uint64)} } @@ -35,7 +37,7 @@ func (mp *mockMpool) setNonce(addr address.Address, nonce uint64) { mp.nonces[addr] = nonce } -func (mp *mockMpool) GetNonce(addr address.Address) (uint64, error) { +func (mp *mockMpool) GetNonce(ctx context.Context, addr address.Address) (uint64, error) { mp.lk.RLock() defer mp.lk.RUnlock() diff --git a/chain/sub/incoming.go b/chain/sub/incoming.go index d1c6414a1..e262fe271 100644 --- a/chain/sub/incoming.go +++ b/chain/sub/incoming.go @@ -516,7 +516,7 @@ func (mv *MessageValidator) Validate(ctx context.Context, pid peer.ID, msg *pubs return pubsub.ValidationReject } - if err := mv.mpool.Add(m); err != nil { + if err := mv.mpool.Add(ctx, m); err != nil { log.Debugf("failed to add message from network to message pool (From: %s, To: %s, Nonce: %d, Value: %s): %s", m.Message.From, m.Message.To, m.Message.Nonce, types.FIL(m.Message.Value), err) ctx, _ = tag.New( ctx, diff --git a/node/impl/full/gas.go b/node/impl/full/gas.go index 3d9889c10..acd2eccfe 100644 --- a/node/impl/full/gas.go +++ b/node/impl/full/gas.go @@ -265,7 +265,7 @@ func gasEstimateGasLimit( return -1, xerrors.Errorf("getting key address: %w", err) } - pending, ts := mpool.PendingFor(fromA) + pending, ts := mpool.PendingFor(ctx, fromA) priorMsgs := make([]types.ChainMsg, 0, len(pending)) for _, m := range pending { if m.Message.Nonce == msg.Nonce { diff --git a/node/impl/full/mpool.go b/node/impl/full/mpool.go index b1e9f94f9..9aa5371e9 100644 --- a/node/impl/full/mpool.go +++ b/node/impl/full/mpool.go @@ -66,7 +66,7 @@ func (a *MpoolAPI) MpoolPending(ctx context.Context, tsk types.TipSetKey) ([]*ty if err != nil { return nil, xerrors.Errorf("loading tipset %s: %w", tsk, err) } - pending, mpts := a.Mpool.Pending() + pending, mpts := a.Mpool.Pending(ctx) haveCids := map[cid.Cid]struct{}{} for _, m := range pending { @@ -125,11 +125,11 @@ func (a *MpoolAPI) MpoolClear(ctx context.Context, local bool) error { } func (m *MpoolModule) MpoolPush(ctx context.Context, smsg *types.SignedMessage) (cid.Cid, error) { - return m.Mpool.Push(smsg) + return m.Mpool.Push(ctx, smsg) } func (a *MpoolAPI) MpoolPushUntrusted(ctx context.Context, smsg *types.SignedMessage) (cid.Cid, error) { - return a.Mpool.PushUntrusted(smsg) + return a.Mpool.PushUntrusted(ctx, smsg) } func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec) (*types.SignedMessage, error) { @@ -190,7 +190,7 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe func (a *MpoolAPI) MpoolBatchPush(ctx context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) { var messageCids []cid.Cid for _, smsg := range smsgs { - smsgCid, err := a.Mpool.Push(smsg) + smsgCid, err := a.Mpool.Push(ctx, smsg) if err != nil { return messageCids, err } @@ -202,7 +202,7 @@ func (a *MpoolAPI) MpoolBatchPush(ctx context.Context, smsgs []*types.SignedMess func (a *MpoolAPI) MpoolBatchPushUntrusted(ctx context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) { var messageCids []cid.Cid for _, smsg := range smsgs { - smsgCid, err := a.Mpool.PushUntrusted(smsg) + smsgCid, err := a.Mpool.PushUntrusted(ctx, smsg) if err != nil { return messageCids, err } @@ -224,7 +224,7 @@ func (a *MpoolAPI) MpoolBatchPushMessage(ctx context.Context, msgs []*types.Mess } func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) { - return a.Mpool.GetNonce(addr) + return a.Mpool.GetNonce(ctx, addr) } func (a *MpoolAPI) MpoolSub(ctx context.Context) (<-chan api.MpoolUpdate, error) { diff --git a/node/modules/chain.go b/node/modules/chain.go index ffdf3aa3a..b0f0543c6 100644 --- a/node/modules/chain.go +++ b/node/modules/chain.go @@ -61,7 +61,8 @@ func ChainBlockService(bs dtypes.ExposedBlockstore, rem dtypes.ChainBitswap) dty func MessagePool(lc fx.Lifecycle, sm *stmgr.StateManager, ps *pubsub.PubSub, ds dtypes.MetadataDS, nn dtypes.NetworkName, j journal.Journal) (*messagepool.MessagePool, error) { mpp := messagepool.NewProvider(sm, ps) - mp, err := messagepool.New(mpp, ds, nn, j) + // TODO: I still don't know how go works -- should this context be part of the builder? + mp, err := messagepool.New(context.TODO(), mpp, ds, nn, j) if err != nil { return nil, xerrors.Errorf("constructing mpool: %w", err) } diff --git a/node/modules/mpoolnonceapi.go b/node/modules/mpoolnonceapi.go index efcb14037..3d670611b 100644 --- a/node/modules/mpoolnonceapi.go +++ b/node/modules/mpoolnonceapi.go @@ -23,7 +23,7 @@ type MpoolNonceAPI struct { } // GetNonce gets the nonce from current chain head. -func (a *MpoolNonceAPI) GetNonce(addr address.Address) (uint64, error) { +func (a *MpoolNonceAPI) GetNonce(ctx context.Context, addr address.Address) (uint64, error) { ts := a.StateAPI.Chain.GetHeaviestTipSet() // make sure we have a key address so we can compare with messages