Protect mp.localAddrs and mp.pending behind helper functions

Aayush Rajasekaran 2021-05-28 20:35:50 -04:00
parent 1f03a618f9
commit ed93d0725f
6 changed files with 192 additions and 81 deletions
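The shape of the change, viewed on its own: both maps become keyed strictly by resolved (pubkey) addresses, so every accessor first pushes the address through a small resolution cache before touching the map. Below is a minimal standalone sketch of that pattern, using string stand-ins for address.Address and a trimmed-down msgSet rather than the real lotus types; the helper names match the diff, everything else is illustrative.

package main

import "fmt"

// Toy stand-ins for address.Address and *msgSet; not the lotus types.
type addr = string

type msgSet struct{ msgs map[uint64]string }

type pool struct {
	keyCache map[addr]addr    // resolved-address cache, as in the new keyCache field
	pending  map[addr]*msgSet // keyed by resolved addresses; accessed only via helpers
}

// resolveToKey mirrors the shape of the helper added in this commit: consult the
// cache first, otherwise "resolve" (here a fake lookup) and cache both forms.
func (p *pool) resolveToKey(a addr) addr {
	if ka, ok := p.keyCache[a]; ok {
		return ka
	}
	ka := "key:" + a // stand-in for the StateAccountKey call
	p.keyCache[a] = ka
	p.keyCache[ka] = ka
	return ka
}

func (p *pool) getPendingMset(a addr) (*msgSet, bool) {
	ms, ok := p.pending[p.resolveToKey(a)]
	return ms, ok
}

func (p *pool) setPendingMset(a addr, ms *msgSet) {
	p.pending[p.resolveToKey(a)] = ms
}

func main() {
	p := &pool{keyCache: map[addr]addr{}, pending: map[addr]*msgSet{}}
	p.setPendingMset("t01001", &msgSet{msgs: map[uint64]string{0: "msg0"}})

	// Either the original form or the resolved form finds the same entry.
	_, ok := p.getPendingMset("key:t01001")
	fmt.Println(ok) // true
}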

View File

@@ -126,12 +126,14 @@ type MessagePool struct {
 	republished map[cid.Cid]struct{}
 
-	// only pubkey addresses
+	// do NOT access this map directly, use isLocal, setLocal, and forEachLocal respectively
 	localAddrs map[address.Address]struct{}
 
-	// only pubkey addresses
+	// do NOT access this map directly, use getPendingMset, setPendingMset, deletePendingMset, forEachPending, and clearPending respectively
 	pending map[address.Address]*msgSet
 
+	keyCache map[address.Address]address.Address
+
 	curTsLk sync.Mutex // DO NOT LOCK INSIDE lk
 	curTs   *types.TipSet
@@ -331,6 +333,20 @@ func (ms *msgSet) getRequiredFunds(nonce uint64) types.BigInt {
 	return types.BigInt{Int: requiredFunds}
 }
 
+func (ms *msgSet) toSlice() []*types.SignedMessage {
+	set := make([]*types.SignedMessage, 0, len(ms.msgs))
+
+	for _, m := range ms.msgs {
+		set = append(set, m)
+	}
+
+	sort.Slice(set, func(i, j int) bool {
+		return set[i].Message.Nonce < set[j].Message.Nonce
+	})
+
+	return set
+}
+
 func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) {
 	cache, _ := lru.New2Q(build.BlsSignatureCacheSize)
 	verifcache, _ := lru.New2Q(build.VerifSigCacheSize)
@@ -352,6 +368,7 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes
 		repubTrigger:  make(chan struct{}, 1),
 		localAddrs:    make(map[address.Address]struct{}),
 		pending:       make(map[address.Address]*msgSet),
+		keyCache:      make(map[address.Address]address.Address),
 		minGasPrice:   types.NewInt(0),
 		pruneTrigger:  make(chan struct{}, 1),
 		pruneCooldown: make(chan struct{}, 1),
@@ -397,12 +414,106 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes
 		log.Info("mpool ready")
 
-		mp.runLoop()
+		mp.runLoop(context.Background())
 	}()
 
 	return mp, nil
 }
 
+func (mp *MessagePool) resolveToKey(ctx context.Context, addr address.Address) (address.Address, error) {
+	// check the cache
+	a, f := mp.keyCache[addr]
+	if f {
+		return a, nil
+	}
+
+	// resolve the address
+	ka, err := mp.api.StateAccountKey(ctx, addr, mp.curTs)
+	if err != nil {
+		return address.Undef, err
+	}
+
+	// place both entries in the cache (may both be key addresses, which is fine)
+	mp.keyCache[addr] = ka
+	mp.keyCache[ka] = ka
+
+	return ka, nil
+}
+
+func (mp *MessagePool) getPendingMset(ctx context.Context, addr address.Address) (*msgSet, bool, error) {
+	ra, err := mp.resolveToKey(ctx, addr)
+	if err != nil {
+		return nil, false, err
+	}
+
+	ms, f := mp.pending[ra]
+	return ms, f, nil
+}
+
+func (mp *MessagePool) setPendingMset(ctx context.Context, addr address.Address, ms *msgSet) error {
+	ra, err := mp.resolveToKey(ctx, addr)
+	if err != nil {
+		return err
+	}
+
+	mp.pending[ra] = ms
+	return nil
+}
+
+// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
+func (mp *MessagePool) forEachPending(f func(address.Address, *msgSet)) {
+	for la, ms := range mp.pending {
+		f(la, ms)
+	}
+}
+
+func (mp *MessagePool) deletePendingMset(ctx context.Context, addr address.Address) error {
+	ra, err := mp.resolveToKey(ctx, addr)
+	if err != nil {
+		return err
+	}
+
+	delete(mp.pending, ra)
+	return nil
+}
+
+// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
+func (mp *MessagePool) clearPending() {
+	mp.pending = make(map[address.Address]*msgSet)
+}
+
+func (mp *MessagePool) isLocal(ctx context.Context, addr address.Address) (bool, error) {
+	ra, err := mp.resolveToKey(ctx, addr)
+	if err != nil {
+		return false, err
+	}
+
+	_, f := mp.localAddrs[ra]
+	return f, nil
+}
+
+func (mp *MessagePool) setLocal(ctx context.Context, addr address.Address) error {
+	ra, err := mp.resolveToKey(ctx, addr)
+	if err != nil {
+		return err
+	}
+
+	mp.localAddrs[ra] = struct{}{}
+	return nil
+}
+
+// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
+func (mp *MessagePool) forEachLocal(ctx context.Context, f func(context.Context, address.Address)) {
+	for la := range mp.localAddrs {
+		f(ctx, la)
+	}
+}
+
 func (mp *MessagePool) Close() error {
 	close(mp.closer)
 	return nil
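One detail worth noting about the callback style of forEachPending/forEachLocal, which shows up in the Clear and republish hunks further down: inside the closure, the old loop's `continue` becomes a `return`, and there is no direct equivalent of `break`. A tiny self-contained illustration of that translation (toy map and callback, not lotus code):

package main

import "fmt"

// forEach mirrors the callback shape of forEachPending/forEachLocal.
func forEach(m map[string]int, f func(string, int)) {
	for k, v := range m {
		f(k, v)
	}
}

func main() {
	m := map[string]int{"a": 1, "b": 0, "c": 3}

	sum := 0
	forEach(m, func(k string, v int) {
		if v == 0 {
			return // plays the role of `continue`: skips this entry only
		}
		sum += v
	})
	fmt.Println(sum) // 4
}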
@@ -420,15 +531,15 @@ func (mp *MessagePool) Prune() {
 	mp.pruneTrigger <- struct{}{}
 }
 
-func (mp *MessagePool) runLoop() {
+func (mp *MessagePool) runLoop(ctx context.Context) {
 	for {
 		select {
 		case <-mp.repubTk.C:
-			if err := mp.republishPendingMessages(); err != nil {
+			if err := mp.republishPendingMessages(ctx); err != nil {
 				log.Errorf("error while republishing messages: %s", err)
 			}
 		case <-mp.repubTrigger:
-			if err := mp.republishPendingMessages(); err != nil {
+			if err := mp.republishPendingMessages(ctx); err != nil {
 				log.Errorf("error while republishing messages: %s", err)
 			}
@@ -445,14 +556,10 @@ func (mp *MessagePool) runLoop() {
 }
 
 func (mp *MessagePool) addLocal(ctx context.Context, m *types.SignedMessage) error {
-	sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
-	if err != nil {
-		log.Debugf("mpooladdlocal failed to resolve sender: %s", err)
+	if err := mp.setLocal(ctx, m.Message.From); err != nil {
 		return err
 	}
 
-	mp.localAddrs[sk] = struct{}{}
-
 	msgb, err := m.Serialize()
 	if err != nil {
 		return xerrors.Errorf("error serializing message: %w", err)
@@ -653,13 +760,12 @@ func (mp *MessagePool) checkBalance(ctx context.Context, m *types.SignedMessage,
 	// add Value for soft failure check
 	//requiredFunds = types.BigAdd(requiredFunds, m.Message.Value)
 
-	sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
+	mset, ok, err := mp.getPendingMset(ctx, m.Message.From)
 	if err != nil {
-		log.Debugf("mpoolcheckbalance failed to resolve sender: %s", err)
+		log.Debugf("mpoolcheckbalance failed to get pending mset: %s", err)
 		return err
 	}
 
-	mset, ok := mp.pending[sk]
 	if ok {
 		requiredFunds = types.BigAdd(requiredFunds, mset.getRequiredFunds(m.Message.Nonce))
 	}
@@ -766,21 +872,22 @@ func (mp *MessagePool) addLocked(ctx context.Context, m *types.SignedMessage, st
 		return err
 	}
 
-	sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
+	mset, ok, err := mp.getPendingMset(ctx, m.Message.From)
 	if err != nil {
-		log.Debugf("mpooladd failed to resolve sender: %s", err)
+		log.Debug(err)
 		return err
 	}
 
-	mset, ok := mp.pending[sk]
 	if !ok {
-		nonce, err := mp.getStateNonce(sk, mp.curTs)
+		nonce, err := mp.getStateNonce(m.Message.From, mp.curTs)
 		if err != nil {
 			return xerrors.Errorf("failed to get initial actor nonce: %w", err)
 		}
 
 		mset = newMsgSet(nonce)
-		mp.pending[sk] = mset
+		if err = mp.setPendingMset(ctx, m.Message.From, mset); err != nil {
+			return xerrors.Errorf("failed to set pending mset: %w", err)
+		}
 	}
 
 	incr, err := mset.add(m, mp, strict, untrusted)
@@ -831,13 +938,12 @@ func (mp *MessagePool) getNonceLocked(ctx context.Context, addr address.Address,
 		return 0, err
 	}
 
-	sk, err := mp.api.StateAccountKey(ctx, addr, mp.curTs)
+	mset, ok, err := mp.getPendingMset(ctx, addr)
 	if err != nil {
-		log.Debugf("mpoolgetnonce failed to resolve sender: %s", err)
+		log.Debugf("mpoolgetnonce failed to get mset: %s", err)
 		return 0, err
 	}
 
-	mset, ok := mp.pending[sk]
 	if ok {
 		if stateNonce > mset.nextNonce {
 			log.Errorf("state nonce was larger than mset.nextNonce (%d > %d)", stateNonce, mset.nextNonce)
@@ -917,13 +1023,12 @@ func (mp *MessagePool) Remove(ctx context.Context, from address.Address, nonce u
 }
 
 func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce uint64, applied bool) {
-	sk, err := mp.api.StateAccountKey(ctx, from, mp.curTs)
+	mset, ok, err := mp.getPendingMset(ctx, from)
 	if err != nil {
-		log.Debugf("mpoolremove failed to resolve sender: %s", err)
+		log.Debugf("mpoolremove failed to get mset: %s", err)
 		return
 	}
 
-	mset, ok := mp.pending[sk]
 	if !ok {
 		return
 	}
@@ -948,7 +1053,10 @@ func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce u
 	mset.rm(nonce, applied)
 
 	if len(mset.msgs) == 0 {
-		delete(mp.pending, from)
+		if err = mp.deletePendingMset(ctx, from); err != nil {
+			log.Debugf("mpoolremove failed to delete mset: %s", err)
+			return
+		}
 	}
 }
@@ -964,9 +1072,10 @@ func (mp *MessagePool) Pending(ctx context.Context) ([]*types.SignedMessage, *ty
 func (mp *MessagePool) allPending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) {
 	out := make([]*types.SignedMessage, 0)
-	for a := range mp.pending {
-		out = append(out, mp.pendingFor(ctx, a)...)
-	}
+
+	mp.forEachPending(func(a address.Address, mset *msgSet) {
+		out = append(out, mset.toSlice()...)
+	})
 
 	return out, mp.curTs
 }
@@ -981,28 +1090,17 @@ func (mp *MessagePool) PendingFor(ctx context.Context, a address.Address) ([]*ty
 }
 
 func (mp *MessagePool) pendingFor(ctx context.Context, a address.Address) []*types.SignedMessage {
-	sk, err := mp.api.StateAccountKey(ctx, a, mp.curTs)
+	mset, ok, err := mp.getPendingMset(ctx, a)
 	if err != nil {
-		log.Debugf("mpoolpendingfor failed to resolve sender: %s", err)
+		log.Debugf("mpoolpendingfor failed to get mset: %s", err)
 		return nil
 	}
 
-	mset := mp.pending[sk]
-	if mset == nil || len(mset.msgs) == 0 {
+	if mset == nil || !ok || len(mset.msgs) == 0 {
 		return nil
 	}
 
-	set := make([]*types.SignedMessage, 0, len(mset.msgs))
-
-	for _, m := range mset.msgs {
-		set = append(set, m)
-	}
-
-	sort.Slice(set, func(i, j int) bool {
-		return set[i].Message.Nonce < set[j].Message.Nonce
-	})
-
-	return set
+	return mset.toSlice()
 }
 
 func (mp *MessagePool) HeadChange(ctx context.Context, revert []*types.TipSet, apply []*types.TipSet) error {
@@ -1341,31 +1439,30 @@ func (mp *MessagePool) loadLocal(ctx context.Context) error {
 			log.Errorf("adding local message: %+v", err)
 		}
 
-		sk, err := mp.api.StateAccountKey(ctx, sm.Message.From, mp.curTs)
-		if err != nil {
-			log.Debugf("mpoolloadLocal failed to resolve sender: %s", err)
+		if err = mp.setLocal(ctx, sm.Message.From); err != nil {
+			log.Debugf("mpoolloadLocal errored: %s", err)
 			return err
 		}
-
-		mp.localAddrs[sk] = struct{}{}
 	}
 
 	return nil
 }
 
-func (mp *MessagePool) Clear(local bool) {
+func (mp *MessagePool) Clear(ctx context.Context, local bool) {
 	mp.lk.Lock()
 	defer mp.lk.Unlock()
 
 	// remove everything if local is true, including removing local messages from
 	// the datastore
 	if local {
-		for a := range mp.localAddrs {
-			mset, ok := mp.pending[a]
-			if !ok {
-				continue
+		mp.forEachLocal(ctx, func(ctx context.Context, la address.Address) {
+			mset, ok, err := mp.getPendingMset(ctx, la)
+			if err != nil {
+				log.Warnf("errored while getting pending mset: %w", err)
+				return
 			}
 
+			if ok {
 				for _, m := range mset.msgs {
 					err := mp.localMsgs.Delete(datastore.NewKey(string(m.Cid().Bytes())))
 					if err != nil {
@@ -1373,21 +1470,30 @@ func (mp *MessagePool) Clear(local bool) {
 					}
 				}
 			}
+		})
 
-		mp.pending = make(map[address.Address]*msgSet)
+		mp.clearPending()
 		mp.republished = nil
 
 		return
 	}
 
-	// remove everything except the local messages
-	for a := range mp.pending {
-		_, isLocal := mp.localAddrs[a]
+	mp.forEachPending(func(a address.Address, ms *msgSet) {
+		isLocal, err := mp.isLocal(ctx, a)
+		if err != nil {
+			log.Warnf("errored while determining isLocal: %w", err)
+			return
+		}
+
 		if isLocal {
-			continue
+			return
 		}
 
-		delete(mp.pending, a)
-	}
+		if err = mp.deletePendingMset(ctx, a); err != nil {
+			log.Warnf("errored while deleting mset: %w", err)
+			return
+		}
+	})
 }
 
 func getBaseFeeLowerBound(baseFee, factor types.BigInt) types.BigInt {
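A caution that follows from reading this diff rather than from anything it states explicitly: the new helpers touch mp.pending, mp.localAddrs, and mp.keyCache without taking mp.lk themselves, so the existing convention that callers hold mp.lk appears to still apply (Clear, for example, still locks before using them). A toy sketch of that caller-holds-the-lock convention, with simplified types that are not the lotus code:

package main

import (
	"fmt"
	"sync"
)

type pool struct {
	lk      sync.Mutex
	pending map[string]int
}

// setPending assumes p.lk is already held by the caller; it does no locking itself.
func (p *pool) setPending(k string, v int) {
	p.pending[k] = v
}

// Clear takes the lock and then manipulates the map, mirroring how
// MessagePool.Clear locks mp.lk before calling the new accessors.
func (p *pool) Clear() {
	p.lk.Lock()
	defer p.lk.Unlock()
	p.pending = map[string]int{}
}

func main() {
	p := &pool{pending: map[string]int{}}

	p.lk.Lock()
	p.setPending("a", 1)
	p.lk.Unlock()

	p.Clear()
	fmt.Println(len(p.pending)) // 0
}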

View File

@@ -537,7 +537,7 @@ func TestClearAll(t *testing.T) {
 		mustAdd(t, mp, m)
 	}
 
-	mp.Clear(true)
+	mp.Clear(context.Background(), true)
 
 	pending, _ := mp.Pending(context.TODO())
 	if len(pending) > 0 {
@@ -592,7 +592,7 @@ func TestClearNonLocal(t *testing.T) {
 		mustAdd(t, mp, m)
 	}
 
-	mp.Clear(false)
+	mp.Clear(context.Background(), false)
 
 	pending, _ := mp.Pending(context.TODO())
 	if len(pending) != 10 {

View File

@@ -61,9 +61,9 @@ func (mp *MessagePool) pruneMessages(ctx context.Context, ts *types.TipSet) erro
 	}
 
 	// we also never prune locally published messages
-	for actor := range mp.localAddrs {
+	mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) {
 		protected[actor] = struct{}{}
-	}
+	})
 
 	// Collect all messages to track which ones to remove and create chains for block inclusion
 	pruneMsgs := make(map[cid.Cid]*types.SignedMessage, mp.currentSize)

View File

@@ -18,7 +18,7 @@ const repubMsgLimit = 30
 
 var RepublishBatchDelay = 100 * time.Millisecond
 
-func (mp *MessagePool) republishPendingMessages() error {
+func (mp *MessagePool) republishPendingMessages(ctx context.Context) error {
 	mp.curTsLk.Lock()
 	ts := mp.curTs
@@ -32,13 +32,18 @@ func (mp *MessagePool) republishPendingMessages() error {
 	pending := make(map[address.Address]map[uint64]*types.SignedMessage)
 	mp.lk.Lock()
 	mp.republished = nil // clear this to avoid races triggering an early republish
-	for actor := range mp.localAddrs {
-		mset, ok := mp.pending[actor]
+	mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) {
+		mset, ok, err := mp.getPendingMset(ctx, actor)
+		if err != nil {
+			log.Debugf("failed to get mset: %w", err)
+			return
+		}
+
 		if !ok {
-			continue
+			return
 		}
 		if len(mset.msgs) == 0 {
-			continue
+			return
 		}
 		// we need to copy this while holding the lock to avoid races with concurrent modification
 		pend := make(map[uint64]*types.SignedMessage, len(mset.msgs))
@@ -46,7 +51,8 @@ func (mp *MessagePool) republishPendingMessages() error {
 			pend[nonce] = m
 		}
 
 		pending[actor] = pend
-	}
+	})
+
 	mp.lk.Unlock()
 	mp.curTsLk.Unlock()

View File

@@ -654,8 +654,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address.
 		inSync = true
 	}
 
-	// first add our current pending messages
-	for a, mset := range mp.pending {
+	mp.forEachPending(func(a address.Address, mset *msgSet) {
 		if inSync {
 			// no need to copy the map
 			result[a] = mset.msgs
@@ -668,7 +667,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address.
 			result[a] = msetCopy
 		}
-	}
+	})
 
 	// we are in sync, that's the happy path
 	if inSync {

View File

@@ -120,7 +120,7 @@ func (a *MpoolAPI) MpoolPending(ctx context.Context, tsk types.TipSetKey) ([]*ty
 }
 
 func (a *MpoolAPI) MpoolClear(ctx context.Context, local bool) error {
-	a.Mpool.Clear(local)
+	a.Mpool.Clear(ctx, local)
 	return nil
 }