Protect mp.localAddrs and mp.pending behind helper functions

Aayush Rajasekaran 2021-05-28 20:35:50 -04:00
parent 1f03a618f9
commit ed93d0725f
6 changed files with 192 additions and 81 deletions
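For context on the pattern this commit introduces: every read or write of mp.pending and mp.localAddrs now goes through a helper that first normalizes the sender address to its pubkey form via resolveToKey, which memoizes both the input and the resolved key in keyCache. A minimal, self-contained sketch of that shape, using simplified stand-in types (Addr, pool, and resolve here are hypothetical; the real code uses address.Address and resolves via mp.api.StateAccountKey):

```go
package main

import "fmt"

type Addr string

type msgSet struct {
	msgs map[uint64]string // simplified stand-in for the real per-actor message set
}

type pool struct {
	keyCache map[Addr]Addr    // memoized address -> pubkey-address resolutions
	pending  map[Addr]*msgSet // do NOT access directly; use the helpers below
}

// resolveToKey normalizes an address to its pubkey form, caching both
// entries (the input may already be a key address, which is fine).
// The error return mirrors the real signature, where resolution can fail.
func (p *pool) resolveToKey(addr Addr) (Addr, error) {
	if ka, ok := p.keyCache[addr]; ok {
		return ka, nil
	}
	ka := resolve(addr) // stand-in for mp.api.StateAccountKey
	p.keyCache[addr] = ka
	p.keyCache[ka] = ka
	return ka, nil
}

// getPendingMset resolves first, then reads the map, so callers never
// have to know which address form the map is keyed by.
func (p *pool) getPendingMset(addr Addr) (*msgSet, bool, error) {
	ka, err := p.resolveToKey(addr)
	if err != nil {
		return nil, false, err
	}
	ms, ok := p.pending[ka]
	return ms, ok, nil
}

func resolve(a Addr) Addr { return "key:" + a } // hypothetical resolver

func main() {
	p := &pool{keyCache: map[Addr]Addr{}, pending: map[Addr]*msgSet{}}
	p.pending["key:t01"] = &msgSet{}
	ms, ok, _ := p.getPendingMset("t01") // ID-style address resolves to the key form
	fmt.Println(ms != nil, ok)           // true true
}
```

The iteration helpers below (forEachPending, forEachLocal, clearPending) do no resolution at all, but as the comments in the diff note, routing every map access through one choke point is the safer convention.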


@ -126,12 +126,14 @@ type MessagePool struct {
republished map[cid.Cid]struct{}
// only pubkey addresses
// do NOT access this map directly, use isLocal, setLocal, and forEachLocal instead
localAddrs map[address.Address]struct{}
// only pubkey addresses
// do NOT access this map directly, use getPendingMset, setPendingMset, deletePendingMset, forEachPending, and clearPending instead
pending map[address.Address]*msgSet
keyCache map[address.Address]address.Address
curTsLk sync.Mutex // DO NOT LOCK INSIDE lk
curTs *types.TipSet
@ -331,6 +333,20 @@ func (ms *msgSet) getRequiredFunds(nonce uint64) types.BigInt {
return types.BigInt{Int: requiredFunds}
}
func (ms *msgSet) toSlice() []*types.SignedMessage {
set := make([]*types.SignedMessage, 0, len(ms.msgs))
for _, m := range ms.msgs {
set = append(set, m)
}
sort.Slice(set, func(i, j int) bool {
return set[i].Message.Nonce < set[j].Message.Nonce
})
return set
}
func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) {
cache, _ := lru.New2Q(build.BlsSignatureCacheSize)
verifcache, _ := lru.New2Q(build.VerifSigCacheSize)
@ -352,6 +368,7 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes
repubTrigger: make(chan struct{}, 1),
localAddrs: make(map[address.Address]struct{}),
pending: make(map[address.Address]*msgSet),
keyCache: make(map[address.Address]address.Address),
minGasPrice: types.NewInt(0),
pruneTrigger: make(chan struct{}, 1),
pruneCooldown: make(chan struct{}, 1),
@ -397,12 +414,106 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes
log.Info("mpool ready")
mp.runLoop()
mp.runLoop(context.Background())
}()
return mp, nil
}
func (mp *MessagePool) resolveToKey(ctx context.Context, addr address.Address) (address.Address, error) {
// check the cache
a, f := mp.keyCache[addr]
if f {
return a, nil
}
// resolve the address
ka, err := mp.api.StateAccountKey(ctx, addr, mp.curTs)
if err != nil {
return address.Undef, err
}
// place both entries in the cache (may both be key addresses, which is fine)
mp.keyCache[addr] = ka
mp.keyCache[ka] = ka
return ka, nil
}
func (mp *MessagePool) getPendingMset(ctx context.Context, addr address.Address) (*msgSet, bool, error) {
ra, err := mp.resolveToKey(ctx, addr)
if err != nil {
return nil, false, err
}
ms, f := mp.pending[ra]
return ms, f, nil
}
func (mp *MessagePool) setPendingMset(ctx context.Context, addr address.Address, ms *msgSet) error {
ra, err := mp.resolveToKey(ctx, addr)
if err != nil {
return err
}
mp.pending[ra] = ms
return nil
}
// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
func (mp *MessagePool) forEachPending(f func(address.Address, *msgSet)) {
for la, ms := range mp.pending {
f(la, ms)
}
}
func (mp *MessagePool) deletePendingMset(ctx context.Context, addr address.Address) error {
ra, err := mp.resolveToKey(ctx, addr)
if err != nil {
return err
}
delete(mp.pending, ra)
return nil
}
// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
func (mp *MessagePool) clearPending() {
mp.pending = make(map[address.Address]*msgSet)
}
func (mp *MessagePool) isLocal(ctx context.Context, addr address.Address) (bool, error) {
ra, err := mp.resolveToKey(ctx, addr)
if err != nil {
return false, err
}
_, f := mp.localAddrs[ra]
return f, nil
}
func (mp *MessagePool) setLocal(ctx context.Context, addr address.Address) error {
ra, err := mp.resolveToKey(ctx, addr)
if err != nil {
return err
}
mp.localAddrs[ra] = struct{}{}
return nil
}
// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
func (mp *MessagePool) forEachLocal(ctx context.Context, f func(context.Context, address.Address)) {
for la := range mp.localAddrs {
f(ctx, la)
}
}
func (mp *MessagePool) Close() error {
close(mp.closer)
return nil
@ -420,15 +531,15 @@ func (mp *MessagePool) Prune() {
mp.pruneTrigger <- struct{}{}
}
func (mp *MessagePool) runLoop() {
func (mp *MessagePool) runLoop(ctx context.Context) {
for {
select {
case <-mp.repubTk.C:
if err := mp.republishPendingMessages(); err != nil {
if err := mp.republishPendingMessages(ctx); err != nil {
log.Errorf("error while republishing messages: %s", err)
}
case <-mp.repubTrigger:
if err := mp.republishPendingMessages(); err != nil {
if err := mp.republishPendingMessages(ctx); err != nil {
log.Errorf("error while republishing messages: %s", err)
}
@ -445,14 +556,10 @@ func (mp *MessagePool) runLoop() {
}
func (mp *MessagePool) addLocal(ctx context.Context, m *types.SignedMessage) error {
sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
if err != nil {
log.Debugf("mpooladdlocal failed to resolve sender: %s", err)
if err := mp.setLocal(ctx, m.Message.From); err != nil {
return err
}
mp.localAddrs[sk] = struct{}{}
msgb, err := m.Serialize()
if err != nil {
return xerrors.Errorf("error serializing message: %w", err)
@ -653,13 +760,12 @@ func (mp *MessagePool) checkBalance(ctx context.Context, m *types.SignedMessage,
// add Value for soft failure check
//requiredFunds = types.BigAdd(requiredFunds, m.Message.Value)
sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
mset, ok, err := mp.getPendingMset(ctx, m.Message.From)
if err != nil {
log.Debugf("mpoolcheckbalance failed to resolve sender: %s", err)
log.Debugf("mpoolcheckbalance failed to get pending mset: %s", err)
return err
}
mset, ok := mp.pending[sk]
if ok {
requiredFunds = types.BigAdd(requiredFunds, mset.getRequiredFunds(m.Message.Nonce))
}
@ -766,21 +872,22 @@ func (mp *MessagePool) addLocked(ctx context.Context, m *types.SignedMessage, st
return err
}
sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
mset, ok, err := mp.getPendingMset(ctx, m.Message.From)
if err != nil {
log.Debugf("mpooladd failed to resolve sender: %s", err)
log.Debug(err)
return err
}
mset, ok := mp.pending[sk]
if !ok {
nonce, err := mp.getStateNonce(sk, mp.curTs)
nonce, err := mp.getStateNonce(m.Message.From, mp.curTs)
if err != nil {
return xerrors.Errorf("failed to get initial actor nonce: %w", err)
}
mset = newMsgSet(nonce)
mp.pending[sk] = mset
if err = mp.setPendingMset(ctx, m.Message.From, mset); err != nil {
return xerrors.Errorf("failed to set pending mset: %w", err)
}
}
incr, err := mset.add(m, mp, strict, untrusted)
@ -831,13 +938,12 @@ func (mp *MessagePool) getNonceLocked(ctx context.Context, addr address.Address,
return 0, err
}
sk, err := mp.api.StateAccountKey(ctx, addr, mp.curTs)
mset, ok, err := mp.getPendingMset(ctx, addr)
if err != nil {
log.Debugf("mpoolgetnonce failed to resolve sender: %s", err)
log.Debugf("mpoolgetnonce failed to get mset: %s", err)
return 0, err
}
mset, ok := mp.pending[sk]
if ok {
if stateNonce > mset.nextNonce {
log.Errorf("state nonce was larger than mset.nextNonce (%d > %d)", stateNonce, mset.nextNonce)
@ -917,13 +1023,12 @@ func (mp *MessagePool) Remove(ctx context.Context, from address.Address, nonce u
}
func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce uint64, applied bool) {
sk, err := mp.api.StateAccountKey(ctx, from, mp.curTs)
mset, ok, err := mp.getPendingMset(ctx, from)
if err != nil {
log.Debugf("mpoolremove failed to resolve sender: %s", err)
log.Debugf("mpoolremove failed to get mset: %s", err)
return
}
mset, ok := mp.pending[sk]
if !ok {
return
}
@ -948,7 +1053,10 @@ func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce u
mset.rm(nonce, applied)
if len(mset.msgs) == 0 {
delete(mp.pending, from)
if err = mp.deletePendingMset(ctx, from); err != nil {
log.Debugf("mpoolremove failed to delete mset: %s", err)
return
}
}
}
@ -964,9 +1072,10 @@ func (mp *MessagePool) Pending(ctx context.Context) ([]*types.SignedMessage, *ty
func (mp *MessagePool) allPending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) {
out := make([]*types.SignedMessage, 0)
for a := range mp.pending {
out = append(out, mp.pendingFor(ctx, a)...)
}
mp.forEachPending(func(a address.Address, mset *msgSet) {
out = append(out, mset.toSlice()...)
})
return out, mp.curTs
}
@ -981,28 +1090,17 @@ func (mp *MessagePool) PendingFor(ctx context.Context, a address.Address) ([]*ty
}
func (mp *MessagePool) pendingFor(ctx context.Context, a address.Address) []*types.SignedMessage {
sk, err := mp.api.StateAccountKey(ctx, a, mp.curTs)
mset, ok, err := mp.getPendingMset(ctx, a)
if err != nil {
log.Debugf("mpoolpendingfor failed to resolve sender: %s", err)
log.Debugf("mpoolpendingfor failed to get mset: %s", err)
return nil
}
mset := mp.pending[sk]
if mset == nil || len(mset.msgs) == 0 {
if mset == nil || !ok || len(mset.msgs) == 0 {
return nil
}
set := make([]*types.SignedMessage, 0, len(mset.msgs))
for _, m := range mset.msgs {
set = append(set, m)
}
sort.Slice(set, func(i, j int) bool {
return set[i].Message.Nonce < set[j].Message.Nonce
})
return set
return mset.toSlice()
}
func (mp *MessagePool) HeadChange(ctx context.Context, revert []*types.TipSet, apply []*types.TipSet) error {
@ -1341,53 +1439,61 @@ func (mp *MessagePool) loadLocal(ctx context.Context) error {
log.Errorf("adding local message: %+v", err)
}
sk, err := mp.api.StateAccountKey(ctx, sm.Message.From, mp.curTs)
if err != nil {
log.Debugf("mpoolloadLocal failed to resolve sender: %s", err)
if err = mp.setLocal(ctx, sm.Message.From); err != nil {
log.Debugf("mpoolloadLocal errored: %s", err)
return err
}
mp.localAddrs[sk] = struct{}{}
}
return nil
}
func (mp *MessagePool) Clear(local bool) {
func (mp *MessagePool) Clear(ctx context.Context, local bool) {
mp.lk.Lock()
defer mp.lk.Unlock()
// remove everything if local is true, including removing local messages from
// the datastore
if local {
for a := range mp.localAddrs {
mset, ok := mp.pending[a]
if !ok {
continue
mp.forEachLocal(ctx, func(ctx context.Context, la address.Address) {
mset, ok, err := mp.getPendingMset(ctx, la)
if err != nil {
log.Warnf("errored while getting pending mset: %w", err)
return
}
for _, m := range mset.msgs {
err := mp.localMsgs.Delete(datastore.NewKey(string(m.Cid().Bytes())))
if err != nil {
log.Warnf("error deleting local message: %s", err)
if ok {
for _, m := range mset.msgs {
err := mp.localMsgs.Delete(datastore.NewKey(string(m.Cid().Bytes())))
if err != nil {
log.Warnf("error deleting local message: %s", err)
}
}
}
}
})
mp.pending = make(map[address.Address]*msgSet)
mp.clearPending()
mp.republished = nil
return
}
// remove everything except the local messages
for a := range mp.pending {
_, isLocal := mp.localAddrs[a]
if isLocal {
continue
mp.forEachPending(func(a address.Address, ms *msgSet) {
isLocal, err := mp.isLocal(ctx, a)
if err != nil {
log.Warnf("errored while determining isLocal: %w", err)
return
}
delete(mp.pending, a)
}
if isLocal {
return
}
if err = mp.deletePendingMset(ctx, a); err != nil {
log.Warnf("errored while deleting mset: %w", err)
return
}
})
}
func getBaseFeeLowerBound(baseFee, factor types.BigInt) types.BigInt {


@ -537,7 +537,7 @@ func TestClearAll(t *testing.T) {
mustAdd(t, mp, m)
}
mp.Clear(true)
mp.Clear(context.Background(), true)
pending, _ := mp.Pending(context.TODO())
if len(pending) > 0 {
@ -592,7 +592,7 @@ func TestClearNonLocal(t *testing.T) {
mustAdd(t, mp, m)
}
mp.Clear(false)
mp.Clear(context.Background(), false)
pending, _ := mp.Pending(context.TODO())
if len(pending) != 10 {


@ -61,9 +61,9 @@ func (mp *MessagePool) pruneMessages(ctx context.Context, ts *types.TipSet) erro
}
// we also never prune locally published messages
for actor := range mp.localAddrs {
mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) {
protected[actor] = struct{}{}
}
})
// Collect all messages to track which ones to remove and create chains for block inclusion
pruneMsgs := make(map[cid.Cid]*types.SignedMessage, mp.currentSize)


@ -18,7 +18,7 @@ const repubMsgLimit = 30
var RepublishBatchDelay = 100 * time.Millisecond
func (mp *MessagePool) republishPendingMessages() error {
func (mp *MessagePool) republishPendingMessages(ctx context.Context) error {
mp.curTsLk.Lock()
ts := mp.curTs
@ -32,13 +32,18 @@ func (mp *MessagePool) republishPendingMessages() error {
pending := make(map[address.Address]map[uint64]*types.SignedMessage)
mp.lk.Lock()
mp.republished = nil // clear this to avoid races triggering an early republish
for actor := range mp.localAddrs {
mset, ok := mp.pending[actor]
mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) {
mset, ok, err := mp.getPendingMset(ctx, actor)
if err != nil {
log.Debugf("failed to get mset: %w", err)
return
}
if !ok {
continue
return
}
if len(mset.msgs) == 0 {
continue
return
}
// we need to copy this while holding the lock to avoid races with concurrent modification
pend := make(map[uint64]*types.SignedMessage, len(mset.msgs))
@ -46,7 +51,8 @@ func (mp *MessagePool) republishPendingMessages() error {
pend[nonce] = m
}
pending[actor] = pend
}
})
mp.lk.Unlock()
mp.curTsLk.Unlock()
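The "copy this while holding the lock" comment above is a standard Go idiom worth isolating: snapshot the shared map under the mutex, then release the lock before doing anything slow with the copy. A minimal sketch with hypothetical names (pool, snapshot):

```go
package main

import (
	"fmt"
	"sync"
)

type pool struct {
	lk      sync.Mutex
	pending map[string]uint64 // shared state, guarded by lk
}

// snapshot copies pending under the lock so callers can iterate the
// result without racing against concurrent writers.
func (p *pool) snapshot() map[string]uint64 {
	p.lk.Lock()
	defer p.lk.Unlock()
	out := make(map[string]uint64, len(p.pending))
	for k, v := range p.pending {
		out[k] = v
	}
	return out
}

func main() {
	p := &pool{pending: map[string]uint64{"t01": 7}}
	snap := p.snapshot()
	// snap is safe to use with no locks held; p.pending may change underneath.
	fmt.Println(snap["t01"]) // 7
}
```

republishPendingMessages does the same thing per local actor, building the pending copy before unlocking mp.lk and mp.curTsLk.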


@ -654,8 +654,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address.
inSync = true
}
// first add our current pending messages
for a, mset := range mp.pending {
mp.forEachPending(func(a address.Address, mset *msgSet) {
if inSync {
// no need to copy the map
result[a] = mset.msgs
@ -668,7 +667,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address.
result[a] = msetCopy
}
}
})
// we are in sync, that's the happy path
if inSync {


@ -120,7 +120,7 @@ func (a *MpoolAPI) MpoolPending(ctx context.Context, tsk types.TipSetKey) ([]*ty
}
func (a *MpoolAPI) MpoolClear(ctx context.Context, local bool) error {
a.Mpool.Clear(local)
a.Mpool.Clear(ctx, local)
return nil
}