Plumb contexts through

This commit is contained in:
Aayush Rajasekaran 2021-05-18 14:56:42 -04:00
parent 8d991283f4
commit 1f03a618f9
12 changed files with 94 additions and 98 deletions

View File

@ -331,7 +331,7 @@ func (ms *msgSet) getRequiredFunds(nonce uint64) types.BigInt {
return types.BigInt{Int: requiredFunds} return types.BigInt{Int: requiredFunds}
} }
func New(api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) { func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) {
cache, _ := lru.New2Q(build.BlsSignatureCacheSize) cache, _ := lru.New2Q(build.BlsSignatureCacheSize)
verifcache, _ := lru.New2Q(build.VerifSigCacheSize) verifcache, _ := lru.New2Q(build.VerifSigCacheSize)
@ -375,7 +375,7 @@ func New(api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journ
// load the current tipset and subscribe to head changes _before_ loading local messages // load the current tipset and subscribe to head changes _before_ loading local messages
mp.curTs = api.SubscribeHeadChanges(func(rev, app []*types.TipSet) error { mp.curTs = api.SubscribeHeadChanges(func(rev, app []*types.TipSet) error {
err := mp.HeadChange(rev, app) err := mp.HeadChange(ctx, rev, app)
if err != nil { if err != nil {
log.Errorf("mpool head notif handler error: %+v", err) log.Errorf("mpool head notif handler error: %+v", err)
} }
@ -386,7 +386,7 @@ func New(api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journ
mp.lk.Lock() mp.lk.Lock()
go func() { go func() {
err := mp.loadLocal() err := mp.loadLocal(ctx)
mp.lk.Unlock() mp.lk.Unlock()
mp.curTsLk.Unlock() mp.curTsLk.Unlock()
@ -444,9 +444,8 @@ func (mp *MessagePool) runLoop() {
} }
} }
func (mp *MessagePool) addLocal(m *types.SignedMessage) error { func (mp *MessagePool) addLocal(ctx context.Context, m *types.SignedMessage) error {
// TODO: Is context.TODO() safe here? Idk how Go works. sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
sk, err := mp.api.StateAccountKey(context.TODO(), m.Message.From, mp.curTs)
if err != nil { if err != nil {
log.Debugf("mpooladdlocal failed to resolve sender: %s", err) log.Debugf("mpooladdlocal failed to resolve sender: %s", err)
return err return err
@ -519,7 +518,7 @@ func (mp *MessagePool) verifyMsgBeforeAdd(m *types.SignedMessage, curTs *types.T
return publish, nil return publish, nil
} }
func (mp *MessagePool) Push(m *types.SignedMessage) (cid.Cid, error) { func (mp *MessagePool) Push(ctx context.Context, m *types.SignedMessage) (cid.Cid, error) {
err := mp.checkMessage(m) err := mp.checkMessage(m)
if err != nil { if err != nil {
return cid.Undef, err return cid.Undef, err
@ -532,7 +531,7 @@ func (mp *MessagePool) Push(m *types.SignedMessage) (cid.Cid, error) {
}() }()
mp.curTsLk.Lock() mp.curTsLk.Lock()
publish, err := mp.addTs(m, mp.curTs, true, false) publish, err := mp.addTs(ctx, m, mp.curTs, true, false)
if err != nil { if err != nil {
mp.curTsLk.Unlock() mp.curTsLk.Unlock()
return cid.Undef, err return cid.Undef, err
@ -585,7 +584,7 @@ func (mp *MessagePool) checkMessage(m *types.SignedMessage) error {
return nil return nil
} }
func (mp *MessagePool) Add(m *types.SignedMessage) error { func (mp *MessagePool) Add(ctx context.Context, m *types.SignedMessage) error {
err := mp.checkMessage(m) err := mp.checkMessage(m)
if err != nil { if err != nil {
return err return err
@ -600,7 +599,7 @@ func (mp *MessagePool) Add(m *types.SignedMessage) error {
mp.curTsLk.Lock() mp.curTsLk.Lock()
defer mp.curTsLk.Unlock() defer mp.curTsLk.Unlock()
_, err = mp.addTs(m, mp.curTs, false, false) _, err = mp.addTs(ctx, m, mp.curTs, false, false)
return err return err
} }
@ -640,7 +639,7 @@ func (mp *MessagePool) VerifyMsgSig(m *types.SignedMessage) error {
return nil return nil
} }
func (mp *MessagePool) checkBalance(m *types.SignedMessage, curTs *types.TipSet) error { func (mp *MessagePool) checkBalance(ctx context.Context, m *types.SignedMessage, curTs *types.TipSet) error {
balance, err := mp.getStateBalance(m.Message.From, curTs) balance, err := mp.getStateBalance(m.Message.From, curTs)
if err != nil { if err != nil {
return xerrors.Errorf("failed to check sender balance: %s: %w", err, ErrSoftValidationFailure) return xerrors.Errorf("failed to check sender balance: %s: %w", err, ErrSoftValidationFailure)
@ -654,8 +653,7 @@ func (mp *MessagePool) checkBalance(m *types.SignedMessage, curTs *types.TipSet)
// add Value for soft failure check // add Value for soft failure check
//requiredFunds = types.BigAdd(requiredFunds, m.Message.Value) //requiredFunds = types.BigAdd(requiredFunds, m.Message.Value)
// TODO: Is context.TODO() safe here? Idk how Go works. sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
sk, err := mp.api.StateAccountKey(context.TODO(), m.Message.From, mp.curTs)
if err != nil { if err != nil {
log.Debugf("mpoolcheckbalance failed to resolve sender: %s", err) log.Debugf("mpoolcheckbalance failed to resolve sender: %s", err)
return err return err
@ -675,7 +673,7 @@ func (mp *MessagePool) checkBalance(m *types.SignedMessage, curTs *types.TipSet)
return nil return nil
} }
func (mp *MessagePool) addTs(m *types.SignedMessage, curTs *types.TipSet, local, untrusted bool) (bool, error) { func (mp *MessagePool) addTs(ctx context.Context, m *types.SignedMessage, curTs *types.TipSet, local, untrusted bool) (bool, error) {
snonce, err := mp.getStateNonce(m.Message.From, curTs) snonce, err := mp.getStateNonce(m.Message.From, curTs)
if err != nil { if err != nil {
return false, xerrors.Errorf("failed to look up actor state nonce: %s: %w", err, ErrSoftValidationFailure) return false, xerrors.Errorf("failed to look up actor state nonce: %s: %w", err, ErrSoftValidationFailure)
@ -693,17 +691,17 @@ func (mp *MessagePool) addTs(m *types.SignedMessage, curTs *types.TipSet, local,
return false, err return false, err
} }
if err := mp.checkBalance(m, curTs); err != nil { if err := mp.checkBalance(ctx, m, curTs); err != nil {
return false, err return false, err
} }
err = mp.addLocked(m, !local, untrusted) err = mp.addLocked(ctx, m, !local, untrusted)
if err != nil { if err != nil {
return false, err return false, err
} }
if local { if local {
err = mp.addLocal(m) err = mp.addLocal(ctx, m)
if err != nil { if err != nil {
return false, xerrors.Errorf("error persisting local message: %w", err) return false, xerrors.Errorf("error persisting local message: %w", err)
} }
@ -712,7 +710,7 @@ func (mp *MessagePool) addTs(m *types.SignedMessage, curTs *types.TipSet, local,
return publish, nil return publish, nil
} }
func (mp *MessagePool) addLoaded(m *types.SignedMessage) error { func (mp *MessagePool) addLoaded(ctx context.Context, m *types.SignedMessage) error {
err := mp.checkMessage(m) err := mp.checkMessage(m)
if err != nil { if err != nil {
return err return err
@ -738,21 +736,21 @@ func (mp *MessagePool) addLoaded(m *types.SignedMessage) error {
return err return err
} }
if err := mp.checkBalance(m, curTs); err != nil { if err := mp.checkBalance(ctx, m, curTs); err != nil {
return err return err
} }
return mp.addLocked(m, false, false) return mp.addLocked(ctx, m, false, false)
} }
func (mp *MessagePool) addSkipChecks(m *types.SignedMessage) error { func (mp *MessagePool) addSkipChecks(ctx context.Context, m *types.SignedMessage) error {
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()
return mp.addLocked(m, false, false) return mp.addLocked(ctx, m, false, false)
} }
func (mp *MessagePool) addLocked(m *types.SignedMessage, strict, untrusted bool) error { func (mp *MessagePool) addLocked(ctx context.Context, m *types.SignedMessage, strict, untrusted bool) error {
log.Debugf("mpooladd: %s %d", m.Message.From, m.Message.Nonce) log.Debugf("mpooladd: %s %d", m.Message.From, m.Message.Nonce)
if m.Signature.Type == crypto.SigTypeBLS { if m.Signature.Type == crypto.SigTypeBLS {
mp.blsSigCache.Add(m.Cid(), m.Signature) mp.blsSigCache.Add(m.Cid(), m.Signature)
@ -768,8 +766,7 @@ func (mp *MessagePool) addLocked(m *types.SignedMessage, strict, untrusted bool)
return err return err
} }
// TODO: Is context.TODO() safe here? Idk how Go works. sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
sk, err := mp.api.StateAccountKey(context.TODO(), m.Message.From, mp.curTs)
if err != nil { if err != nil {
log.Debugf("mpooladd failed to resolve sender: %s", err) log.Debugf("mpooladd failed to resolve sender: %s", err)
return err return err
@ -818,24 +815,23 @@ func (mp *MessagePool) addLocked(m *types.SignedMessage, strict, untrusted bool)
return nil return nil
} }
func (mp *MessagePool) GetNonce(addr address.Address) (uint64, error) { func (mp *MessagePool) GetNonce(ctx context.Context, addr address.Address) (uint64, error) {
mp.curTsLk.Lock() mp.curTsLk.Lock()
defer mp.curTsLk.Unlock() defer mp.curTsLk.Unlock()
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()
return mp.getNonceLocked(addr, mp.curTs) return mp.getNonceLocked(ctx, addr, mp.curTs)
} }
func (mp *MessagePool) getNonceLocked(addr address.Address, curTs *types.TipSet) (uint64, error) { func (mp *MessagePool) getNonceLocked(ctx context.Context, addr address.Address, curTs *types.TipSet) (uint64, error) {
stateNonce, err := mp.getStateNonce(addr, curTs) // sanity check stateNonce, err := mp.getStateNonce(addr, curTs) // sanity check
if err != nil { if err != nil {
return 0, err return 0, err
} }
// TODO: Is context.TODO() safe here? Idk how Go works. sk, err := mp.api.StateAccountKey(ctx, addr, mp.curTs)
sk, err := mp.api.StateAccountKey(context.TODO(), addr, mp.curTs)
if err != nil { if err != nil {
log.Debugf("mpoolgetnonce failed to resolve sender: %s", err) log.Debugf("mpoolgetnonce failed to resolve sender: %s", err)
return 0, err return 0, err
@ -878,7 +874,7 @@ func (mp *MessagePool) getStateBalance(addr address.Address, ts *types.TipSet) (
// - strict checks are enabled // - strict checks are enabled
// - extra strict add checks are used when adding the messages to the msgSet // - extra strict add checks are used when adding the messages to the msgSet
// that means: no nonce gaps, at most 10 pending messages for the actor // that means: no nonce gaps, at most 10 pending messages for the actor
func (mp *MessagePool) PushUntrusted(m *types.SignedMessage) (cid.Cid, error) { func (mp *MessagePool) PushUntrusted(ctx context.Context, m *types.SignedMessage) (cid.Cid, error) {
err := mp.checkMessage(m) err := mp.checkMessage(m)
if err != nil { if err != nil {
return cid.Undef, err return cid.Undef, err
@ -891,7 +887,7 @@ func (mp *MessagePool) PushUntrusted(m *types.SignedMessage) (cid.Cid, error) {
}() }()
mp.curTsLk.Lock() mp.curTsLk.Lock()
publish, err := mp.addTs(m, mp.curTs, true, true) publish, err := mp.addTs(ctx, m, mp.curTs, true, true)
if err != nil { if err != nil {
mp.curTsLk.Unlock() mp.curTsLk.Unlock()
return cid.Undef, err return cid.Undef, err
@ -913,16 +909,15 @@ func (mp *MessagePool) PushUntrusted(m *types.SignedMessage) (cid.Cid, error) {
return m.Cid(), nil return m.Cid(), nil
} }
func (mp *MessagePool) Remove(from address.Address, nonce uint64, applied bool) { func (mp *MessagePool) Remove(ctx context.Context, from address.Address, nonce uint64, applied bool) {
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()
mp.remove(from, nonce, applied) mp.remove(ctx, from, nonce, applied)
} }
func (mp *MessagePool) remove(from address.Address, nonce uint64, applied bool) { func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce uint64, applied bool) {
// TODO: Is context.TODO() safe here? Idk how Go works. sk, err := mp.api.StateAccountKey(ctx, from, mp.curTs)
sk, err := mp.api.StateAccountKey(context.TODO(), from, mp.curTs)
if err != nil { if err != nil {
log.Debugf("mpoolremove failed to resolve sender: %s", err) log.Debugf("mpoolremove failed to resolve sender: %s", err)
return return
@ -957,37 +952,36 @@ func (mp *MessagePool) remove(from address.Address, nonce uint64, applied bool)
} }
} }
func (mp *MessagePool) Pending() ([]*types.SignedMessage, *types.TipSet) { func (mp *MessagePool) Pending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) {
mp.curTsLk.Lock() mp.curTsLk.Lock()
defer mp.curTsLk.Unlock() defer mp.curTsLk.Unlock()
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()
return mp.allPending() return mp.allPending(ctx)
} }
func (mp *MessagePool) allPending() ([]*types.SignedMessage, *types.TipSet) { func (mp *MessagePool) allPending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) {
out := make([]*types.SignedMessage, 0) out := make([]*types.SignedMessage, 0)
for a := range mp.pending { for a := range mp.pending {
out = append(out, mp.pendingFor(a)...) out = append(out, mp.pendingFor(ctx, a)...)
} }
return out, mp.curTs return out, mp.curTs
} }
func (mp *MessagePool) PendingFor(a address.Address) ([]*types.SignedMessage, *types.TipSet) { func (mp *MessagePool) PendingFor(ctx context.Context, a address.Address) ([]*types.SignedMessage, *types.TipSet) {
mp.curTsLk.Lock() mp.curTsLk.Lock()
defer mp.curTsLk.Unlock() defer mp.curTsLk.Unlock()
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()
return mp.pendingFor(a), mp.curTs return mp.pendingFor(ctx, a), mp.curTs
} }
func (mp *MessagePool) pendingFor(a address.Address) []*types.SignedMessage { func (mp *MessagePool) pendingFor(ctx context.Context, a address.Address) []*types.SignedMessage {
// TODO: Is context.TODO() safe here? Idk how Go works. sk, err := mp.api.StateAccountKey(ctx, a, mp.curTs)
sk, err := mp.api.StateAccountKey(context.TODO(), a, mp.curTs)
if err != nil { if err != nil {
log.Debugf("mpoolpendingfor failed to resolve sender: %s", err) log.Debugf("mpoolpendingfor failed to resolve sender: %s", err)
return nil return nil
@ -1011,7 +1005,7 @@ func (mp *MessagePool) pendingFor(a address.Address) []*types.SignedMessage {
return set return set
} }
func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet) error { func (mp *MessagePool) HeadChange(ctx context.Context, revert []*types.TipSet, apply []*types.TipSet) error {
mp.curTsLk.Lock() mp.curTsLk.Lock()
defer mp.curTsLk.Unlock() defer mp.curTsLk.Unlock()
@ -1028,7 +1022,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet)
rm := func(from address.Address, nonce uint64) { rm := func(from address.Address, nonce uint64) {
s, ok := rmsgs[from] s, ok := rmsgs[from]
if !ok { if !ok {
mp.Remove(from, nonce, true) mp.Remove(ctx, from, nonce, true)
return return
} }
@ -1037,7 +1031,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet)
return return
} }
mp.Remove(from, nonce, true) mp.Remove(ctx, from, nonce, true)
} }
maybeRepub := func(cid cid.Cid) { maybeRepub := func(cid cid.Cid) {
@ -1108,7 +1102,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet)
for _, s := range rmsgs { for _, s := range rmsgs {
for _, msg := range s { for _, msg := range s {
if err := mp.addSkipChecks(msg); err != nil { if err := mp.addSkipChecks(ctx, msg); err != nil {
log.Errorf("Failed to readd message from reorg to mpool: %s", err) log.Errorf("Failed to readd message from reorg to mpool: %s", err)
} }
} }
@ -1116,7 +1110,7 @@ func (mp *MessagePool) HeadChange(revert []*types.TipSet, apply []*types.TipSet)
if len(revert) > 0 && futureDebug { if len(revert) > 0 && futureDebug {
mp.lk.Lock() mp.lk.Lock()
msgs, ts := mp.allPending() msgs, ts := mp.allPending(ctx)
mp.lk.Unlock() mp.lk.Unlock()
buckets := map[address.Address]*statBucket{} buckets := map[address.Address]*statBucket{}
@ -1323,7 +1317,7 @@ func (mp *MessagePool) Updates(ctx context.Context) (<-chan api.MpoolUpdate, err
return out, nil return out, nil
} }
func (mp *MessagePool) loadLocal() error { func (mp *MessagePool) loadLocal(ctx context.Context) error {
res, err := mp.localMsgs.Query(query.Query{}) res, err := mp.localMsgs.Query(query.Query{})
if err != nil { if err != nil {
return xerrors.Errorf("query local messages: %w", err) return xerrors.Errorf("query local messages: %w", err)
@ -1339,7 +1333,7 @@ func (mp *MessagePool) loadLocal() error {
return xerrors.Errorf("unmarshaling local message: %w", err) return xerrors.Errorf("unmarshaling local message: %w", err)
} }
if err := mp.addLoaded(&sm); err != nil { if err := mp.addLoaded(ctx, &sm); err != nil {
if xerrors.Is(err, ErrNonceTooLow) { if xerrors.Is(err, ErrNonceTooLow) {
continue // todo: drop the message from local cache (if above certain confidence threshold) continue // todo: drop the message from local cache (if above certain confidence threshold)
} }
@ -1347,8 +1341,7 @@ func (mp *MessagePool) loadLocal() error {
log.Errorf("adding local message: %+v", err) log.Errorf("adding local message: %+v", err)
} }
// TODO: Is context.TODO() safe here? Idk how Go works. sk, err := mp.api.StateAccountKey(ctx, sm.Message.From, mp.curTs)
sk, err := mp.api.StateAccountKey(context.TODO(), sm.Message.From, mp.curTs)
if err != nil { if err != nil {
log.Debugf("mpoolloadLocal failed to resolve sender: %s", err) log.Debugf("mpoolloadLocal failed to resolve sender: %s", err)
return err return err

View File

@ -199,7 +199,7 @@ func (tma *testMpoolAPI) ChainComputeBaseFee(ctx context.Context, ts *types.TipS
func assertNonce(t *testing.T, mp *MessagePool, addr address.Address, val uint64) { func assertNonce(t *testing.T, mp *MessagePool, addr address.Address, val uint64) {
t.Helper() t.Helper()
n, err := mp.GetNonce(addr) n, err := mp.GetNonce(context.TODO(), addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -211,7 +211,7 @@ func assertNonce(t *testing.T, mp *MessagePool, addr address.Address, val uint64
func mustAdd(t *testing.T, mp *MessagePool, msg *types.SignedMessage) { func mustAdd(t *testing.T, mp *MessagePool, msg *types.SignedMessage) {
t.Helper() t.Helper()
if err := mp.Add(msg); err != nil { if err := mp.Add(context.TODO(), msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
@ -226,7 +226,7 @@ func TestMessagePool(t *testing.T) {
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -267,7 +267,7 @@ func TestMessagePoolMessagesInEachBlock(t *testing.T) {
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -293,7 +293,7 @@ func TestMessagePoolMessagesInEachBlock(t *testing.T) {
tma.applyBlock(t, a) tma.applyBlock(t, a)
tsa := mock.TipSet(a) tsa := mock.TipSet(a)
_, _ = mp.Pending() _, _ = mp.Pending(context.TODO())
selm, _ := mp.SelectMessages(tsa, 1) selm, _ := mp.SelectMessages(tsa, 1)
if len(selm) == 0 { if len(selm) == 0 {
@ -316,7 +316,7 @@ func TestRevertMessages(t *testing.T) {
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -355,7 +355,7 @@ func TestRevertMessages(t *testing.T) {
assertNonce(t, mp, sender, 4) assertNonce(t, mp, sender, 4)
p, _ := mp.Pending() p, _ := mp.Pending(context.TODO())
fmt.Printf("%+v\n", p) fmt.Printf("%+v\n", p)
if len(p) != 3 { if len(p) != 3 {
t.Fatal("expected three messages in mempool") t.Fatal("expected three messages in mempool")
@ -379,7 +379,7 @@ func TestPruningSimple(t *testing.T) {
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -396,14 +396,14 @@ func TestPruningSimple(t *testing.T) {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
smsg := mock.MkMessage(sender, target, uint64(i), w) smsg := mock.MkMessage(sender, target, uint64(i), w)
if err := mp.Add(smsg); err != nil { if err := mp.Add(context.TODO(), smsg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
for i := 10; i < 50; i++ { for i := 10; i < 50; i++ {
smsg := mock.MkMessage(sender, target, uint64(i), w) smsg := mock.MkMessage(sender, target, uint64(i), w)
if err := mp.Add(smsg); err != nil { if err := mp.Add(context.TODO(), smsg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
@ -413,7 +413,7 @@ func TestPruningSimple(t *testing.T) {
mp.Prune() mp.Prune()
msgs, _ := mp.Pending() msgs, _ := mp.Pending(context.TODO())
if len(msgs) != 5 { if len(msgs) != 5 {
t.Fatal("expected only 5 messages in pool, got: ", len(msgs)) t.Fatal("expected only 5 messages in pool, got: ", len(msgs))
} }
@ -423,7 +423,7 @@ func TestLoadLocal(t *testing.T) {
tma := newTestMpoolAPI() tma := newTestMpoolAPI()
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -455,7 +455,7 @@ func TestLoadLocal(t *testing.T) {
msgs := make(map[cid.Cid]struct{}) msgs := make(map[cid.Cid]struct{})
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1))
cid, err := mp.Push(m) cid, err := mp.Push(context.TODO(), m)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -466,12 +466,12 @@ func TestLoadLocal(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
mp, err = New(tma, ds, "mptest", nil) mp, err = New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
pmsgs, _ := mp.Pending() pmsgs, _ := mp.Pending(context.TODO())
if len(msgs) != len(pmsgs) { if len(msgs) != len(pmsgs) {
t.Fatalf("expected %d messages, but got %d", len(msgs), len(pmsgs)) t.Fatalf("expected %d messages, but got %d", len(msgs), len(pmsgs))
} }
@ -495,7 +495,7 @@ func TestClearAll(t *testing.T) {
tma := newTestMpoolAPI() tma := newTestMpoolAPI()
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -526,7 +526,7 @@ func TestClearAll(t *testing.T) {
gasLimit := gasguess.Costs[gasguess.CostKey{Code: builtin2.StorageMarketActorCodeID, M: 2}] gasLimit := gasguess.Costs[gasguess.CostKey{Code: builtin2.StorageMarketActorCodeID, M: 2}]
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1))
_, err := mp.Push(m) _, err := mp.Push(context.TODO(), m)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -539,7 +539,7 @@ func TestClearAll(t *testing.T) {
mp.Clear(true) mp.Clear(true)
pending, _ := mp.Pending() pending, _ := mp.Pending(context.TODO())
if len(pending) > 0 { if len(pending) > 0 {
t.Fatalf("cleared the mpool, but got %d pending messages", len(pending)) t.Fatalf("cleared the mpool, but got %d pending messages", len(pending))
} }
@ -549,7 +549,7 @@ func TestClearNonLocal(t *testing.T) {
tma := newTestMpoolAPI() tma := newTestMpoolAPI()
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -581,7 +581,7 @@ func TestClearNonLocal(t *testing.T) {
gasLimit := gasguess.Costs[gasguess.CostKey{Code: builtin2.StorageMarketActorCodeID, M: 2}] gasLimit := gasguess.Costs[gasguess.CostKey{Code: builtin2.StorageMarketActorCodeID, M: 2}]
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1))
_, err := mp.Push(m) _, err := mp.Push(context.TODO(), m)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -594,7 +594,7 @@ func TestClearNonLocal(t *testing.T) {
mp.Clear(false) mp.Clear(false)
pending, _ := mp.Pending() pending, _ := mp.Pending(context.TODO())
if len(pending) != 10 { if len(pending) != 10 {
t.Fatalf("expected 10 pending messages, but got %d instead", len(pending)) t.Fatalf("expected 10 pending messages, but got %d instead", len(pending))
} }
@ -610,7 +610,7 @@ func TestUpdates(t *testing.T) {
tma := newTestMpoolAPI() tma := newTestMpoolAPI()
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -651,7 +651,7 @@ func TestUpdates(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1))
_, err := mp.Push(m) _, err := mp.Push(context.TODO(), m)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -108,7 +108,7 @@ keepLoop:
// and remove all messages that are still in pruneMsgs after processing the chains // and remove all messages that are still in pruneMsgs after processing the chains
log.Infof("Pruning %d messages", len(pruneMsgs)) log.Infof("Pruning %d messages", len(pruneMsgs))
for _, m := range pruneMsgs { for _, m := range pruneMsgs {
mp.remove(m.Message.From, m.Message.Nonce, false) mp.remove(ctx, m.Message.From, m.Message.Nonce, false)
} }
return nil return nil

View File

@ -24,7 +24,7 @@ func TestRepubMessages(t *testing.T) {
tma := newTestMpoolAPI() tma := newTestMpoolAPI()
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "mptest", nil) mp, err := New(context.TODO(), tma, ds, "mptest", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -56,7 +56,7 @@ func TestRepubMessages(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1)) m := makeTestMessage(w1, a1, a2, uint64(i), gasLimit, uint64(i+1))
_, err := mp.Push(m) _, err := mp.Push(context.TODO(), m)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -60,7 +60,7 @@ func makeTestMessage(w *wallet.LocalWallet, from, to address.Address, nonce uint
func makeTestMpool() (*MessagePool, *testMpoolAPI) { func makeTestMpool() (*MessagePool, *testMpoolAPI) {
tma := newTestMpoolAPI() tma := newTestMpoolAPI()
ds := datastore.NewMapDatastore() ds := datastore.NewMapDatastore()
mp, err := New(tma, ds, "test", nil) mp, err := New(context.TODO(), tma, ds, "test", nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -464,7 +464,7 @@ func TestBasicMessageSelection(t *testing.T) {
tma.applyBlock(t, block2) tma.applyBlock(t, block2)
// we should have no pending messages in the mpool // we should have no pending messages in the mpool
pend, _ := mp.Pending() pend, _ := mp.Pending(context.TODO())
if len(pend) != 0 { if len(pend) != 0 {
t.Fatalf("expected no pending messages, but got %d", len(pend)) t.Fatalf("expected no pending messages, but got %d", len(pend))
} }

View File

@ -23,7 +23,7 @@ const dsKeyActorNonce = "ActorNextNonce"
var log = logging.Logger("messagesigner") var log = logging.Logger("messagesigner")
type MpoolNonceAPI interface { type MpoolNonceAPI interface {
GetNonce(address.Address) (uint64, error) GetNonce(context.Context, address.Address) (uint64, error)
} }
// MessageSigner keeps track of nonces per address, and increments the nonce // MessageSigner keeps track of nonces per address, and increments the nonce
@ -51,7 +51,7 @@ func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message, cb
defer ms.lk.Unlock() defer ms.lk.Unlock()
// Get the next message nonce // Get the next message nonce
nonce, err := ms.nextNonce(msg.From) nonce, err := ms.nextNonce(ctx, msg.From)
if err != nil { if err != nil {
return nil, xerrors.Errorf("failed to create nonce: %w", err) return nil, xerrors.Errorf("failed to create nonce: %w", err)
} }
@ -92,12 +92,12 @@ func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message, cb
// nextNonce gets the next nonce for the given address. // nextNonce gets the next nonce for the given address.
// If there is no nonce in the datastore, gets the nonce from the message pool. // If there is no nonce in the datastore, gets the nonce from the message pool.
func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { func (ms *MessageSigner) nextNonce(ctx context.Context, addr address.Address) (uint64, error) {
// Nonces used to be created by the mempool and we need to support nodes // Nonces used to be created by the mempool and we need to support nodes
// that have mempool nonces, so first check the mempool for a nonce for // that have mempool nonces, so first check the mempool for a nonce for
// this address. Note that the mempool returns the actor state's nonce // this address. Note that the mempool returns the actor state's nonce
// by default. // by default.
nonce, err := ms.mpool.GetNonce(addr) nonce, err := ms.mpool.GetNonce(ctx, addr)
if err != nil { if err != nil {
return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err) return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err)
} }

View File

@ -24,6 +24,8 @@ type mockMpool struct {
nonces map[address.Address]uint64 nonces map[address.Address]uint64
} }
var _ MpoolNonceAPI = (*mockMpool)(nil)
func newMockMpool() *mockMpool { func newMockMpool() *mockMpool {
return &mockMpool{nonces: make(map[address.Address]uint64)} return &mockMpool{nonces: make(map[address.Address]uint64)}
} }
@ -35,7 +37,7 @@ func (mp *mockMpool) setNonce(addr address.Address, nonce uint64) {
mp.nonces[addr] = nonce mp.nonces[addr] = nonce
} }
func (mp *mockMpool) GetNonce(addr address.Address) (uint64, error) { func (mp *mockMpool) GetNonce(ctx context.Context, addr address.Address) (uint64, error) {
mp.lk.RLock() mp.lk.RLock()
defer mp.lk.RUnlock() defer mp.lk.RUnlock()

View File

@ -516,7 +516,7 @@ func (mv *MessageValidator) Validate(ctx context.Context, pid peer.ID, msg *pubs
return pubsub.ValidationReject return pubsub.ValidationReject
} }
if err := mv.mpool.Add(m); err != nil { if err := mv.mpool.Add(ctx, m); err != nil {
log.Debugf("failed to add message from network to message pool (From: %s, To: %s, Nonce: %d, Value: %s): %s", m.Message.From, m.Message.To, m.Message.Nonce, types.FIL(m.Message.Value), err) log.Debugf("failed to add message from network to message pool (From: %s, To: %s, Nonce: %d, Value: %s): %s", m.Message.From, m.Message.To, m.Message.Nonce, types.FIL(m.Message.Value), err)
ctx, _ = tag.New( ctx, _ = tag.New(
ctx, ctx,

View File

@ -265,7 +265,7 @@ func gasEstimateGasLimit(
return -1, xerrors.Errorf("getting key address: %w", err) return -1, xerrors.Errorf("getting key address: %w", err)
} }
pending, ts := mpool.PendingFor(fromA) pending, ts := mpool.PendingFor(ctx, fromA)
priorMsgs := make([]types.ChainMsg, 0, len(pending)) priorMsgs := make([]types.ChainMsg, 0, len(pending))
for _, m := range pending { for _, m := range pending {
if m.Message.Nonce == msg.Nonce { if m.Message.Nonce == msg.Nonce {

View File

@ -66,7 +66,7 @@ func (a *MpoolAPI) MpoolPending(ctx context.Context, tsk types.TipSetKey) ([]*ty
if err != nil { if err != nil {
return nil, xerrors.Errorf("loading tipset %s: %w", tsk, err) return nil, xerrors.Errorf("loading tipset %s: %w", tsk, err)
} }
pending, mpts := a.Mpool.Pending() pending, mpts := a.Mpool.Pending(ctx)
haveCids := map[cid.Cid]struct{}{} haveCids := map[cid.Cid]struct{}{}
for _, m := range pending { for _, m := range pending {
@ -125,11 +125,11 @@ func (a *MpoolAPI) MpoolClear(ctx context.Context, local bool) error {
} }
func (m *MpoolModule) MpoolPush(ctx context.Context, smsg *types.SignedMessage) (cid.Cid, error) { func (m *MpoolModule) MpoolPush(ctx context.Context, smsg *types.SignedMessage) (cid.Cid, error) {
return m.Mpool.Push(smsg) return m.Mpool.Push(ctx, smsg)
} }
func (a *MpoolAPI) MpoolPushUntrusted(ctx context.Context, smsg *types.SignedMessage) (cid.Cid, error) { func (a *MpoolAPI) MpoolPushUntrusted(ctx context.Context, smsg *types.SignedMessage) (cid.Cid, error) {
return a.Mpool.PushUntrusted(smsg) return a.Mpool.PushUntrusted(ctx, smsg)
} }
func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec) (*types.SignedMessage, error) { func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec) (*types.SignedMessage, error) {
@ -190,7 +190,7 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
func (a *MpoolAPI) MpoolBatchPush(ctx context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) { func (a *MpoolAPI) MpoolBatchPush(ctx context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) {
var messageCids []cid.Cid var messageCids []cid.Cid
for _, smsg := range smsgs { for _, smsg := range smsgs {
smsgCid, err := a.Mpool.Push(smsg) smsgCid, err := a.Mpool.Push(ctx, smsg)
if err != nil { if err != nil {
return messageCids, err return messageCids, err
} }
@ -202,7 +202,7 @@ func (a *MpoolAPI) MpoolBatchPush(ctx context.Context, smsgs []*types.SignedMess
func (a *MpoolAPI) MpoolBatchPushUntrusted(ctx context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) { func (a *MpoolAPI) MpoolBatchPushUntrusted(ctx context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) {
var messageCids []cid.Cid var messageCids []cid.Cid
for _, smsg := range smsgs { for _, smsg := range smsgs {
smsgCid, err := a.Mpool.PushUntrusted(smsg) smsgCid, err := a.Mpool.PushUntrusted(ctx, smsg)
if err != nil { if err != nil {
return messageCids, err return messageCids, err
} }
@ -224,7 +224,7 @@ func (a *MpoolAPI) MpoolBatchPushMessage(ctx context.Context, msgs []*types.Mess
} }
func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) { func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {
return a.Mpool.GetNonce(addr) return a.Mpool.GetNonce(ctx, addr)
} }
func (a *MpoolAPI) MpoolSub(ctx context.Context) (<-chan api.MpoolUpdate, error) { func (a *MpoolAPI) MpoolSub(ctx context.Context) (<-chan api.MpoolUpdate, error) {

View File

@ -61,7 +61,8 @@ func ChainBlockService(bs dtypes.ExposedBlockstore, rem dtypes.ChainBitswap) dty
func MessagePool(lc fx.Lifecycle, sm *stmgr.StateManager, ps *pubsub.PubSub, ds dtypes.MetadataDS, nn dtypes.NetworkName, j journal.Journal) (*messagepool.MessagePool, error) { func MessagePool(lc fx.Lifecycle, sm *stmgr.StateManager, ps *pubsub.PubSub, ds dtypes.MetadataDS, nn dtypes.NetworkName, j journal.Journal) (*messagepool.MessagePool, error) {
mpp := messagepool.NewProvider(sm, ps) mpp := messagepool.NewProvider(sm, ps)
mp, err := messagepool.New(mpp, ds, nn, j) // TODO: I still don't know how go works -- should this context be part of the builder?
mp, err := messagepool.New(context.TODO(), mpp, ds, nn, j)
if err != nil { if err != nil {
return nil, xerrors.Errorf("constructing mpool: %w", err) return nil, xerrors.Errorf("constructing mpool: %w", err)
} }

View File

@ -23,7 +23,7 @@ type MpoolNonceAPI struct {
} }
// GetNonce gets the nonce from current chain head. // GetNonce gets the nonce from current chain head.
func (a *MpoolNonceAPI) GetNonce(addr address.Address) (uint64, error) { func (a *MpoolNonceAPI) GetNonce(ctx context.Context, addr address.Address) (uint64, error) {
ts := a.StateAPI.Chain.GetHeaviestTipSet() ts := a.StateAPI.Chain.GetHeaviestTipSet()
// make sure we have a key address so we can compare with messages // make sure we have a key address so we can compare with messages