From ed93d0725ffc17ab06cf690d32683b8f8753df7a Mon Sep 17 00:00:00 2001 From: Aayush Rajasekaran Date: Fri, 28 May 2021 20:35:50 -0400 Subject: [PATCH] Protect mp.localAddrs and mp.pending behind helper functions --- chain/messagepool/messagepool.go | 240 +++++++++++++++++++------- chain/messagepool/messagepool_test.go | 4 +- chain/messagepool/pruning.go | 4 +- chain/messagepool/repub.go | 18 +- chain/messagepool/selection.go | 5 +- node/impl/full/mpool.go | 2 +- 6 files changed, 192 insertions(+), 81 deletions(-) diff --git a/chain/messagepool/messagepool.go b/chain/messagepool/messagepool.go index 0c8569a1f..12e8f24c2 100644 --- a/chain/messagepool/messagepool.go +++ b/chain/messagepool/messagepool.go @@ -126,12 +126,14 @@ type MessagePool struct { republished map[cid.Cid]struct{} - // only pubkey addresses + // do NOT access this map directly, use isLocal, setLocal, and forEachLocal respectively localAddrs map[address.Address]struct{} - // only pubkey addresses + // do NOT access this map directly, use getPendingMset, setPendingMset, deletePendingMset, forEachPending, and clearPending respectively pending map[address.Address]*msgSet + keyCache map[address.Address]address.Address + curTsLk sync.Mutex // DO NOT LOCK INSIDE lk curTs *types.TipSet @@ -331,6 +333,20 @@ func (ms *msgSet) getRequiredFunds(nonce uint64) types.BigInt { return types.BigInt{Int: requiredFunds} } +func (ms *msgSet) toSlice() []*types.SignedMessage { + set := make([]*types.SignedMessage, 0, len(ms.msgs)) + + for _, m := range ms.msgs { + set = append(set, m) + } + + sort.Slice(set, func(i, j int) bool { + return set[i].Message.Nonce < set[j].Message.Nonce + }) + + return set +} + func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) { cache, _ := lru.New2Q(build.BlsSignatureCacheSize) verifcache, _ := lru.New2Q(build.VerifSigCacheSize) @@ -352,6 +368,7 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes repubTrigger: make(chan struct{}, 1), localAddrs: make(map[address.Address]struct{}), pending: make(map[address.Address]*msgSet), + keyCache: make(map[address.Address]address.Address), minGasPrice: types.NewInt(0), pruneTrigger: make(chan struct{}, 1), pruneCooldown: make(chan struct{}, 1), @@ -397,12 +414,106 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes log.Info("mpool ready") - mp.runLoop() + mp.runLoop(context.Background()) }() return mp, nil } +func (mp *MessagePool) resolveToKey(ctx context.Context, addr address.Address) (address.Address, error) { + // check the cache + a, f := mp.keyCache[addr] + if f { + return a, nil + } + + // resolve the address + ka, err := mp.api.StateAccountKey(ctx, addr, mp.curTs) + if err != nil { + return address.Undef, err + } + + // place both entries in the cache (may both be key addresses, which is fine) + mp.keyCache[addr] = ka + mp.keyCache[ka] = ka + + return ka, nil +} + +func (mp *MessagePool) getPendingMset(ctx context.Context, addr address.Address) (*msgSet, bool, error) { + ra, err := mp.resolveToKey(ctx, addr) + if err != nil { + return nil, false, err + } + + ms, f := mp.pending[ra] + + return ms, f, nil +} + +func (mp *MessagePool) setPendingMset(ctx context.Context, addr address.Address, ms *msgSet) error { + ra, err := mp.resolveToKey(ctx, addr) + if err != nil { + return err + } + + mp.pending[ra] = ms + + return nil +} + +// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have +func (mp *MessagePool) forEachPending(f func(address.Address, *msgSet)) { + for la, ms := range mp.pending { + f(la, ms) + } +} + +func (mp *MessagePool) deletePendingMset(ctx context.Context, addr address.Address) error { + ra, err := mp.resolveToKey(ctx, addr) + if err != nil { + return err + } + + delete(mp.pending, ra) + + return nil +} + +// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have +func (mp *MessagePool) clearPending() { + mp.pending = make(map[address.Address]*msgSet) +} + +func (mp *MessagePool) isLocal(ctx context.Context, addr address.Address) (bool, error) { + ra, err := mp.resolveToKey(ctx, addr) + if err != nil { + return false, err + } + + _, f := mp.localAddrs[ra] + + return f, nil +} + +func (mp *MessagePool) setLocal(ctx context.Context, addr address.Address) error { + ra, err := mp.resolveToKey(ctx, addr) + if err != nil { + return err + } + + mp.localAddrs[ra] = struct{}{} + + return nil +} + +// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have +func (mp *MessagePool) forEachLocal(ctx context.Context, f func(context.Context, address.Address)) { + for la := range mp.localAddrs { + f(ctx, la) + } +} + func (mp *MessagePool) Close() error { close(mp.closer) return nil @@ -420,15 +531,15 @@ func (mp *MessagePool) Prune() { mp.pruneTrigger <- struct{}{} } -func (mp *MessagePool) runLoop() { +func (mp *MessagePool) runLoop(ctx context.Context) { for { select { case <-mp.repubTk.C: - if err := mp.republishPendingMessages(); err != nil { + if err := mp.republishPendingMessages(ctx); err != nil { log.Errorf("error while republishing messages: %s", err) } case <-mp.repubTrigger: - if err := mp.republishPendingMessages(); err != nil { + if err := mp.republishPendingMessages(ctx); err != nil { log.Errorf("error while republishing messages: %s", err) } @@ -445,14 +556,10 @@ func (mp *MessagePool) runLoop() { } func (mp *MessagePool) addLocal(ctx context.Context, m *types.SignedMessage) error { - sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs) - if err != nil { - log.Debugf("mpooladdlocal failed to resolve sender: %s", err) + if err := mp.setLocal(ctx, m.Message.From); err != nil { return err } - mp.localAddrs[sk] = struct{}{} - msgb, err := m.Serialize() if err != nil { return xerrors.Errorf("error serializing message: %w", err) @@ -653,13 +760,12 @@ func (mp *MessagePool) checkBalance(ctx context.Context, m *types.SignedMessage, // add Value for soft failure check //requiredFunds = types.BigAdd(requiredFunds, m.Message.Value) - sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs) + mset, ok, err := mp.getPendingMset(ctx, m.Message.From) if err != nil { - log.Debugf("mpoolcheckbalance failed to resolve sender: %s", err) + log.Debugf("mpoolcheckbalance failed to get pending mset: %s", err) return err } - mset, ok := mp.pending[sk] if ok { requiredFunds = types.BigAdd(requiredFunds, mset.getRequiredFunds(m.Message.Nonce)) } @@ -766,21 +872,22 @@ func (mp *MessagePool) addLocked(ctx context.Context, m *types.SignedMessage, st return err } - sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs) + mset, ok, err := mp.getPendingMset(ctx, m.Message.From) if err != nil { - log.Debugf("mpooladd failed to resolve sender: %s", err) + log.Debug(err) return err } - mset, ok := mp.pending[sk] if !ok { - nonce, err := mp.getStateNonce(sk, mp.curTs) + nonce, err := mp.getStateNonce(m.Message.From, mp.curTs) if err != nil { return xerrors.Errorf("failed to get initial actor nonce: %w", err) } mset = newMsgSet(nonce) - mp.pending[sk] = mset + if err = mp.setPendingMset(ctx, m.Message.From, mset); err != nil { + return xerrors.Errorf("failed to set pending mset: %w", err) + } } incr, err := mset.add(m, mp, strict, untrusted) @@ -831,13 +938,12 @@ func (mp *MessagePool) getNonceLocked(ctx context.Context, addr address.Address, return 0, err } - sk, err := mp.api.StateAccountKey(ctx, addr, mp.curTs) + mset, ok, err := mp.getPendingMset(ctx, addr) if err != nil { - log.Debugf("mpoolgetnonce failed to resolve sender: %s", err) + log.Debugf("mpoolgetnonce failed to get mset: %s", err) return 0, err } - mset, ok := mp.pending[sk] if ok { if stateNonce > mset.nextNonce { log.Errorf("state nonce was larger than mset.nextNonce (%d > %d)", stateNonce, mset.nextNonce) @@ -917,13 +1023,12 @@ func (mp *MessagePool) Remove(ctx context.Context, from address.Address, nonce u } func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce uint64, applied bool) { - sk, err := mp.api.StateAccountKey(ctx, from, mp.curTs) + mset, ok, err := mp.getPendingMset(ctx, from) if err != nil { - log.Debugf("mpoolremove failed to resolve sender: %s", err) + log.Debugf("mpoolremove failed to get mset: %s", err) return } - mset, ok := mp.pending[sk] if !ok { return } @@ -948,7 +1053,10 @@ func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce u mset.rm(nonce, applied) if len(mset.msgs) == 0 { - delete(mp.pending, from) + if err = mp.deletePendingMset(ctx, from); err != nil { + log.Debugf("mpoolremove failed to delete mset: %s", err) + return + } } } @@ -964,9 +1072,10 @@ func (mp *MessagePool) Pending(ctx context.Context) ([]*types.SignedMessage, *ty func (mp *MessagePool) allPending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) { out := make([]*types.SignedMessage, 0) - for a := range mp.pending { - out = append(out, mp.pendingFor(ctx, a)...) - } + + mp.forEachPending(func(a address.Address, mset *msgSet) { + out = append(out, mset.toSlice()...) + }) return out, mp.curTs } @@ -981,28 +1090,17 @@ func (mp *MessagePool) PendingFor(ctx context.Context, a address.Address) ([]*ty } func (mp *MessagePool) pendingFor(ctx context.Context, a address.Address) []*types.SignedMessage { - sk, err := mp.api.StateAccountKey(ctx, a, mp.curTs) + mset, ok, err := mp.getPendingMset(ctx, a) if err != nil { - log.Debugf("mpoolpendingfor failed to resolve sender: %s", err) + log.Debugf("mpoolpendingfor failed to get mset: %s", err) return nil } - mset := mp.pending[sk] - if mset == nil || len(mset.msgs) == 0 { + if mset == nil || !ok || len(mset.msgs) == 0 { return nil } - set := make([]*types.SignedMessage, 0, len(mset.msgs)) - - for _, m := range mset.msgs { - set = append(set, m) - } - - sort.Slice(set, func(i, j int) bool { - return set[i].Message.Nonce < set[j].Message.Nonce - }) - - return set + return mset.toSlice() } func (mp *MessagePool) HeadChange(ctx context.Context, revert []*types.TipSet, apply []*types.TipSet) error { @@ -1341,53 +1439,61 @@ func (mp *MessagePool) loadLocal(ctx context.Context) error { log.Errorf("adding local message: %+v", err) } - sk, err := mp.api.StateAccountKey(ctx, sm.Message.From, mp.curTs) - if err != nil { - log.Debugf("mpoolloadLocal failed to resolve sender: %s", err) + if err = mp.setLocal(ctx, sm.Message.From); err != nil { + log.Debugf("mpoolloadLocal errored: %s", err) return err } - - mp.localAddrs[sk] = struct{}{} } return nil } -func (mp *MessagePool) Clear(local bool) { +func (mp *MessagePool) Clear(ctx context.Context, local bool) { mp.lk.Lock() defer mp.lk.Unlock() // remove everything if local is true, including removing local messages from // the datastore if local { - for a := range mp.localAddrs { - mset, ok := mp.pending[a] - if !ok { - continue + mp.forEachLocal(ctx, func(ctx context.Context, la address.Address) { + mset, ok, err := mp.getPendingMset(ctx, la) + if err != nil { + log.Warnf("errored while getting pending mset: %w", err) + return } - for _, m := range mset.msgs { - err := mp.localMsgs.Delete(datastore.NewKey(string(m.Cid().Bytes()))) - if err != nil { - log.Warnf("error deleting local message: %s", err) + if ok { + for _, m := range mset.msgs { + err := mp.localMsgs.Delete(datastore.NewKey(string(m.Cid().Bytes()))) + if err != nil { + log.Warnf("error deleting local message: %s", err) + } } } - } + }) - mp.pending = make(map[address.Address]*msgSet) + mp.clearPending() mp.republished = nil return } - // remove everything except the local messages - for a := range mp.pending { - _, isLocal := mp.localAddrs[a] - if isLocal { - continue + mp.forEachPending(func(a address.Address, ms *msgSet) { + isLocal, err := mp.isLocal(ctx, a) + if err != nil { + log.Warnf("errored while determining isLocal: %w", err) + return } - delete(mp.pending, a) - } + + if isLocal { + return + } + + if err = mp.deletePendingMset(ctx, a); err != nil { + log.Warnf("errored while deleting mset: %w", err) + return + } + }) } func getBaseFeeLowerBound(baseFee, factor types.BigInt) types.BigInt { diff --git a/chain/messagepool/messagepool_test.go b/chain/messagepool/messagepool_test.go index 3e5bad81f..aa3331c11 100644 --- a/chain/messagepool/messagepool_test.go +++ b/chain/messagepool/messagepool_test.go @@ -537,7 +537,7 @@ func TestClearAll(t *testing.T) { mustAdd(t, mp, m) } - mp.Clear(true) + mp.Clear(context.Background(), true) pending, _ := mp.Pending(context.TODO()) if len(pending) > 0 { @@ -592,7 +592,7 @@ func TestClearNonLocal(t *testing.T) { mustAdd(t, mp, m) } - mp.Clear(false) + mp.Clear(context.Background(), false) pending, _ := mp.Pending(context.TODO()) if len(pending) != 10 { diff --git a/chain/messagepool/pruning.go b/chain/messagepool/pruning.go index ad8f38c50..6802e23f3 100644 --- a/chain/messagepool/pruning.go +++ b/chain/messagepool/pruning.go @@ -61,9 +61,9 @@ func (mp *MessagePool) pruneMessages(ctx context.Context, ts *types.TipSet) erro } // we also never prune locally published messages - for actor := range mp.localAddrs { + mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) { protected[actor] = struct{}{} - } + }) // Collect all messages to track which ones to remove and create chains for block inclusion pruneMsgs := make(map[cid.Cid]*types.SignedMessage, mp.currentSize) diff --git a/chain/messagepool/repub.go b/chain/messagepool/repub.go index 5fa68aa53..4323bdee1 100644 --- a/chain/messagepool/repub.go +++ b/chain/messagepool/repub.go @@ -18,7 +18,7 @@ const repubMsgLimit = 30 var RepublishBatchDelay = 100 * time.Millisecond -func (mp *MessagePool) republishPendingMessages() error { +func (mp *MessagePool) republishPendingMessages(ctx context.Context) error { mp.curTsLk.Lock() ts := mp.curTs @@ -32,13 +32,18 @@ func (mp *MessagePool) republishPendingMessages() error { pending := make(map[address.Address]map[uint64]*types.SignedMessage) mp.lk.Lock() mp.republished = nil // clear this to avoid races triggering an early republish - for actor := range mp.localAddrs { - mset, ok := mp.pending[actor] + mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) { + mset, ok, err := mp.getPendingMset(ctx, actor) + if err != nil { + log.Debugf("failed to get mset: %w", err) + return + } + if !ok { - continue + return } if len(mset.msgs) == 0 { - continue + return } // we need to copy this while holding the lock to avoid races with concurrent modification pend := make(map[uint64]*types.SignedMessage, len(mset.msgs)) @@ -46,7 +51,8 @@ func (mp *MessagePool) republishPendingMessages() error { pend[nonce] = m } pending[actor] = pend - } + }) + mp.lk.Unlock() mp.curTsLk.Unlock() diff --git a/chain/messagepool/selection.go b/chain/messagepool/selection.go index af450645f..dfed2b6b5 100644 --- a/chain/messagepool/selection.go +++ b/chain/messagepool/selection.go @@ -654,8 +654,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address. inSync = true } - // first add our current pending messages - for a, mset := range mp.pending { + mp.forEachPending(func(a address.Address, mset *msgSet) { if inSync { // no need to copy the map result[a] = mset.msgs @@ -668,7 +667,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address. result[a] = msetCopy } - } + }) // we are in sync, that's the happy path if inSync { diff --git a/node/impl/full/mpool.go b/node/impl/full/mpool.go index 9aa5371e9..e91fc8b9e 100644 --- a/node/impl/full/mpool.go +++ b/node/impl/full/mpool.go @@ -120,7 +120,7 @@ func (a *MpoolAPI) MpoolPending(ctx context.Context, tsk types.TipSetKey) ([]*ty } func (a *MpoolAPI) MpoolClear(ctx context.Context, local bool) error { - a.Mpool.Clear(local) + a.Mpool.Clear(ctx, local) return nil }