Protect mp.localAddrs and mp.pending behind helper functions
This commit is contained in:
parent
1f03a618f9
commit
ed93d0725f
@ -126,12 +126,14 @@ type MessagePool struct {
|
||||
|
||||
republished map[cid.Cid]struct{}
|
||||
|
||||
// only pubkey addresses
|
||||
// do NOT access this map directly, use isLocal, setLocal, and forEachLocal respectively
|
||||
localAddrs map[address.Address]struct{}
|
||||
|
||||
// only pubkey addresses
|
||||
// do NOT access this map directly, use getPendingMset, setPendingMset, deletePendingMset, forEachPending, and clearPending respectively
|
||||
pending map[address.Address]*msgSet
|
||||
|
||||
keyCache map[address.Address]address.Address
|
||||
|
||||
curTsLk sync.Mutex // DO NOT LOCK INSIDE lk
|
||||
curTs *types.TipSet
|
||||
|
||||
@ -331,6 +333,20 @@ func (ms *msgSet) getRequiredFunds(nonce uint64) types.BigInt {
|
||||
return types.BigInt{Int: requiredFunds}
|
||||
}
|
||||
|
||||
func (ms *msgSet) toSlice() []*types.SignedMessage {
|
||||
set := make([]*types.SignedMessage, 0, len(ms.msgs))
|
||||
|
||||
for _, m := range ms.msgs {
|
||||
set = append(set, m)
|
||||
}
|
||||
|
||||
sort.Slice(set, func(i, j int) bool {
|
||||
return set[i].Message.Nonce < set[j].Message.Nonce
|
||||
})
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes.NetworkName, j journal.Journal) (*MessagePool, error) {
|
||||
cache, _ := lru.New2Q(build.BlsSignatureCacheSize)
|
||||
verifcache, _ := lru.New2Q(build.VerifSigCacheSize)
|
||||
@ -352,6 +368,7 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes
|
||||
repubTrigger: make(chan struct{}, 1),
|
||||
localAddrs: make(map[address.Address]struct{}),
|
||||
pending: make(map[address.Address]*msgSet),
|
||||
keyCache: make(map[address.Address]address.Address),
|
||||
minGasPrice: types.NewInt(0),
|
||||
pruneTrigger: make(chan struct{}, 1),
|
||||
pruneCooldown: make(chan struct{}, 1),
|
||||
@ -397,12 +414,106 @@ func New(ctx context.Context, api Provider, ds dtypes.MetadataDS, netName dtypes
|
||||
|
||||
log.Info("mpool ready")
|
||||
|
||||
mp.runLoop()
|
||||
mp.runLoop(context.Background())
|
||||
}()
|
||||
|
||||
return mp, nil
|
||||
}
|
||||
|
||||
func (mp *MessagePool) resolveToKey(ctx context.Context, addr address.Address) (address.Address, error) {
|
||||
// check the cache
|
||||
a, f := mp.keyCache[addr]
|
||||
if f {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// resolve the address
|
||||
ka, err := mp.api.StateAccountKey(ctx, addr, mp.curTs)
|
||||
if err != nil {
|
||||
return address.Undef, err
|
||||
}
|
||||
|
||||
// place both entries in the cache (may both be key addresses, which is fine)
|
||||
mp.keyCache[addr] = ka
|
||||
mp.keyCache[ka] = ka
|
||||
|
||||
return ka, nil
|
||||
}
|
||||
|
||||
func (mp *MessagePool) getPendingMset(ctx context.Context, addr address.Address) (*msgSet, bool, error) {
|
||||
ra, err := mp.resolveToKey(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
ms, f := mp.pending[ra]
|
||||
|
||||
return ms, f, nil
|
||||
}
|
||||
|
||||
func (mp *MessagePool) setPendingMset(ctx context.Context, addr address.Address, ms *msgSet) error {
|
||||
ra, err := mp.resolveToKey(ctx, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp.pending[ra] = ms
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
|
||||
func (mp *MessagePool) forEachPending(f func(address.Address, *msgSet)) {
|
||||
for la, ms := range mp.pending {
|
||||
f(la, ms)
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *MessagePool) deletePendingMset(ctx context.Context, addr address.Address) error {
|
||||
ra, err := mp.resolveToKey(ctx, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(mp.pending, ra)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
|
||||
func (mp *MessagePool) clearPending() {
|
||||
mp.pending = make(map[address.Address]*msgSet)
|
||||
}
|
||||
|
||||
func (mp *MessagePool) isLocal(ctx context.Context, addr address.Address) (bool, error) {
|
||||
ra, err := mp.resolveToKey(ctx, addr)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, f := mp.localAddrs[ra]
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (mp *MessagePool) setLocal(ctx context.Context, addr address.Address) error {
|
||||
ra, err := mp.resolveToKey(ctx, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp.localAddrs[ra] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// This method isn't strictly necessary, since it doesn't resolve any addresses, but it's safer to have
|
||||
func (mp *MessagePool) forEachLocal(ctx context.Context, f func(context.Context, address.Address)) {
|
||||
for la := range mp.localAddrs {
|
||||
f(ctx, la)
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *MessagePool) Close() error {
|
||||
close(mp.closer)
|
||||
return nil
|
||||
@ -420,15 +531,15 @@ func (mp *MessagePool) Prune() {
|
||||
mp.pruneTrigger <- struct{}{}
|
||||
}
|
||||
|
||||
func (mp *MessagePool) runLoop() {
|
||||
func (mp *MessagePool) runLoop(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-mp.repubTk.C:
|
||||
if err := mp.republishPendingMessages(); err != nil {
|
||||
if err := mp.republishPendingMessages(ctx); err != nil {
|
||||
log.Errorf("error while republishing messages: %s", err)
|
||||
}
|
||||
case <-mp.repubTrigger:
|
||||
if err := mp.republishPendingMessages(); err != nil {
|
||||
if err := mp.republishPendingMessages(ctx); err != nil {
|
||||
log.Errorf("error while republishing messages: %s", err)
|
||||
}
|
||||
|
||||
@ -445,14 +556,10 @@ func (mp *MessagePool) runLoop() {
|
||||
}
|
||||
|
||||
func (mp *MessagePool) addLocal(ctx context.Context, m *types.SignedMessage) error {
|
||||
sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
|
||||
if err != nil {
|
||||
log.Debugf("mpooladdlocal failed to resolve sender: %s", err)
|
||||
if err := mp.setLocal(ctx, m.Message.From); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp.localAddrs[sk] = struct{}{}
|
||||
|
||||
msgb, err := m.Serialize()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("error serializing message: %w", err)
|
||||
@ -653,13 +760,12 @@ func (mp *MessagePool) checkBalance(ctx context.Context, m *types.SignedMessage,
|
||||
// add Value for soft failure check
|
||||
//requiredFunds = types.BigAdd(requiredFunds, m.Message.Value)
|
||||
|
||||
sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
|
||||
mset, ok, err := mp.getPendingMset(ctx, m.Message.From)
|
||||
if err != nil {
|
||||
log.Debugf("mpoolcheckbalance failed to resolve sender: %s", err)
|
||||
log.Debugf("mpoolcheckbalance failed to get pending mset: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
mset, ok := mp.pending[sk]
|
||||
if ok {
|
||||
requiredFunds = types.BigAdd(requiredFunds, mset.getRequiredFunds(m.Message.Nonce))
|
||||
}
|
||||
@ -766,21 +872,22 @@ func (mp *MessagePool) addLocked(ctx context.Context, m *types.SignedMessage, st
|
||||
return err
|
||||
}
|
||||
|
||||
sk, err := mp.api.StateAccountKey(ctx, m.Message.From, mp.curTs)
|
||||
mset, ok, err := mp.getPendingMset(ctx, m.Message.From)
|
||||
if err != nil {
|
||||
log.Debugf("mpooladd failed to resolve sender: %s", err)
|
||||
log.Debug(err)
|
||||
return err
|
||||
}
|
||||
|
||||
mset, ok := mp.pending[sk]
|
||||
if !ok {
|
||||
nonce, err := mp.getStateNonce(sk, mp.curTs)
|
||||
nonce, err := mp.getStateNonce(m.Message.From, mp.curTs)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get initial actor nonce: %w", err)
|
||||
}
|
||||
|
||||
mset = newMsgSet(nonce)
|
||||
mp.pending[sk] = mset
|
||||
if err = mp.setPendingMset(ctx, m.Message.From, mset); err != nil {
|
||||
return xerrors.Errorf("failed to set pending mset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
incr, err := mset.add(m, mp, strict, untrusted)
|
||||
@ -831,13 +938,12 @@ func (mp *MessagePool) getNonceLocked(ctx context.Context, addr address.Address,
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sk, err := mp.api.StateAccountKey(ctx, addr, mp.curTs)
|
||||
mset, ok, err := mp.getPendingMset(ctx, addr)
|
||||
if err != nil {
|
||||
log.Debugf("mpoolgetnonce failed to resolve sender: %s", err)
|
||||
log.Debugf("mpoolgetnonce failed to get mset: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
mset, ok := mp.pending[sk]
|
||||
if ok {
|
||||
if stateNonce > mset.nextNonce {
|
||||
log.Errorf("state nonce was larger than mset.nextNonce (%d > %d)", stateNonce, mset.nextNonce)
|
||||
@ -917,13 +1023,12 @@ func (mp *MessagePool) Remove(ctx context.Context, from address.Address, nonce u
|
||||
}
|
||||
|
||||
func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce uint64, applied bool) {
|
||||
sk, err := mp.api.StateAccountKey(ctx, from, mp.curTs)
|
||||
mset, ok, err := mp.getPendingMset(ctx, from)
|
||||
if err != nil {
|
||||
log.Debugf("mpoolremove failed to resolve sender: %s", err)
|
||||
log.Debugf("mpoolremove failed to get mset: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
mset, ok := mp.pending[sk]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@ -948,7 +1053,10 @@ func (mp *MessagePool) remove(ctx context.Context, from address.Address, nonce u
|
||||
mset.rm(nonce, applied)
|
||||
|
||||
if len(mset.msgs) == 0 {
|
||||
delete(mp.pending, from)
|
||||
if err = mp.deletePendingMset(ctx, from); err != nil {
|
||||
log.Debugf("mpoolremove failed to delete mset: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -964,9 +1072,10 @@ func (mp *MessagePool) Pending(ctx context.Context) ([]*types.SignedMessage, *ty
|
||||
|
||||
func (mp *MessagePool) allPending(ctx context.Context) ([]*types.SignedMessage, *types.TipSet) {
|
||||
out := make([]*types.SignedMessage, 0)
|
||||
for a := range mp.pending {
|
||||
out = append(out, mp.pendingFor(ctx, a)...)
|
||||
}
|
||||
|
||||
mp.forEachPending(func(a address.Address, mset *msgSet) {
|
||||
out = append(out, mset.toSlice()...)
|
||||
})
|
||||
|
||||
return out, mp.curTs
|
||||
}
|
||||
@ -981,28 +1090,17 @@ func (mp *MessagePool) PendingFor(ctx context.Context, a address.Address) ([]*ty
|
||||
}
|
||||
|
||||
func (mp *MessagePool) pendingFor(ctx context.Context, a address.Address) []*types.SignedMessage {
|
||||
sk, err := mp.api.StateAccountKey(ctx, a, mp.curTs)
|
||||
mset, ok, err := mp.getPendingMset(ctx, a)
|
||||
if err != nil {
|
||||
log.Debugf("mpoolpendingfor failed to resolve sender: %s", err)
|
||||
log.Debugf("mpoolpendingfor failed to get mset: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
mset := mp.pending[sk]
|
||||
if mset == nil || len(mset.msgs) == 0 {
|
||||
if mset == nil || !ok || len(mset.msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
set := make([]*types.SignedMessage, 0, len(mset.msgs))
|
||||
|
||||
for _, m := range mset.msgs {
|
||||
set = append(set, m)
|
||||
}
|
||||
|
||||
sort.Slice(set, func(i, j int) bool {
|
||||
return set[i].Message.Nonce < set[j].Message.Nonce
|
||||
})
|
||||
|
||||
return set
|
||||
return mset.toSlice()
|
||||
}
|
||||
|
||||
func (mp *MessagePool) HeadChange(ctx context.Context, revert []*types.TipSet, apply []*types.TipSet) error {
|
||||
@ -1341,53 +1439,61 @@ func (mp *MessagePool) loadLocal(ctx context.Context) error {
|
||||
log.Errorf("adding local message: %+v", err)
|
||||
}
|
||||
|
||||
sk, err := mp.api.StateAccountKey(ctx, sm.Message.From, mp.curTs)
|
||||
if err != nil {
|
||||
log.Debugf("mpoolloadLocal failed to resolve sender: %s", err)
|
||||
if err = mp.setLocal(ctx, sm.Message.From); err != nil {
|
||||
log.Debugf("mpoolloadLocal errored: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
mp.localAddrs[sk] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mp *MessagePool) Clear(local bool) {
|
||||
func (mp *MessagePool) Clear(ctx context.Context, local bool) {
|
||||
mp.lk.Lock()
|
||||
defer mp.lk.Unlock()
|
||||
|
||||
// remove everything if local is true, including removing local messages from
|
||||
// the datastore
|
||||
if local {
|
||||
for a := range mp.localAddrs {
|
||||
mset, ok := mp.pending[a]
|
||||
if !ok {
|
||||
continue
|
||||
mp.forEachLocal(ctx, func(ctx context.Context, la address.Address) {
|
||||
mset, ok, err := mp.getPendingMset(ctx, la)
|
||||
if err != nil {
|
||||
log.Warnf("errored while getting pending mset: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range mset.msgs {
|
||||
err := mp.localMsgs.Delete(datastore.NewKey(string(m.Cid().Bytes())))
|
||||
if err != nil {
|
||||
log.Warnf("error deleting local message: %s", err)
|
||||
if ok {
|
||||
for _, m := range mset.msgs {
|
||||
err := mp.localMsgs.Delete(datastore.NewKey(string(m.Cid().Bytes())))
|
||||
if err != nil {
|
||||
log.Warnf("error deleting local message: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
mp.pending = make(map[address.Address]*msgSet)
|
||||
mp.clearPending()
|
||||
mp.republished = nil
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// remove everything except the local messages
|
||||
for a := range mp.pending {
|
||||
_, isLocal := mp.localAddrs[a]
|
||||
if isLocal {
|
||||
continue
|
||||
mp.forEachPending(func(a address.Address, ms *msgSet) {
|
||||
isLocal, err := mp.isLocal(ctx, a)
|
||||
if err != nil {
|
||||
log.Warnf("errored while determining isLocal: %w", err)
|
||||
return
|
||||
}
|
||||
delete(mp.pending, a)
|
||||
}
|
||||
|
||||
if isLocal {
|
||||
return
|
||||
}
|
||||
|
||||
if err = mp.deletePendingMset(ctx, a); err != nil {
|
||||
log.Warnf("errored while deleting mset: %w", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func getBaseFeeLowerBound(baseFee, factor types.BigInt) types.BigInt {
|
||||
|
@ -537,7 +537,7 @@ func TestClearAll(t *testing.T) {
|
||||
mustAdd(t, mp, m)
|
||||
}
|
||||
|
||||
mp.Clear(true)
|
||||
mp.Clear(context.Background(), true)
|
||||
|
||||
pending, _ := mp.Pending(context.TODO())
|
||||
if len(pending) > 0 {
|
||||
@ -592,7 +592,7 @@ func TestClearNonLocal(t *testing.T) {
|
||||
mustAdd(t, mp, m)
|
||||
}
|
||||
|
||||
mp.Clear(false)
|
||||
mp.Clear(context.Background(), false)
|
||||
|
||||
pending, _ := mp.Pending(context.TODO())
|
||||
if len(pending) != 10 {
|
||||
|
@ -61,9 +61,9 @@ func (mp *MessagePool) pruneMessages(ctx context.Context, ts *types.TipSet) erro
|
||||
}
|
||||
|
||||
// we also never prune locally published messages
|
||||
for actor := range mp.localAddrs {
|
||||
mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) {
|
||||
protected[actor] = struct{}{}
|
||||
}
|
||||
})
|
||||
|
||||
// Collect all messages to track which ones to remove and create chains for block inclusion
|
||||
pruneMsgs := make(map[cid.Cid]*types.SignedMessage, mp.currentSize)
|
||||
|
@ -18,7 +18,7 @@ const repubMsgLimit = 30
|
||||
|
||||
var RepublishBatchDelay = 100 * time.Millisecond
|
||||
|
||||
func (mp *MessagePool) republishPendingMessages() error {
|
||||
func (mp *MessagePool) republishPendingMessages(ctx context.Context) error {
|
||||
mp.curTsLk.Lock()
|
||||
ts := mp.curTs
|
||||
|
||||
@ -32,13 +32,18 @@ func (mp *MessagePool) republishPendingMessages() error {
|
||||
pending := make(map[address.Address]map[uint64]*types.SignedMessage)
|
||||
mp.lk.Lock()
|
||||
mp.republished = nil // clear this to avoid races triggering an early republish
|
||||
for actor := range mp.localAddrs {
|
||||
mset, ok := mp.pending[actor]
|
||||
mp.forEachLocal(ctx, func(ctx context.Context, actor address.Address) {
|
||||
mset, ok, err := mp.getPendingMset(ctx, actor)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get mset: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
return
|
||||
}
|
||||
if len(mset.msgs) == 0 {
|
||||
continue
|
||||
return
|
||||
}
|
||||
// we need to copy this while holding the lock to avoid races with concurrent modification
|
||||
pend := make(map[uint64]*types.SignedMessage, len(mset.msgs))
|
||||
@ -46,7 +51,8 @@ func (mp *MessagePool) republishPendingMessages() error {
|
||||
pend[nonce] = m
|
||||
}
|
||||
pending[actor] = pend
|
||||
}
|
||||
})
|
||||
|
||||
mp.lk.Unlock()
|
||||
mp.curTsLk.Unlock()
|
||||
|
||||
|
@ -654,8 +654,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address.
|
||||
inSync = true
|
||||
}
|
||||
|
||||
// first add our current pending messages
|
||||
for a, mset := range mp.pending {
|
||||
mp.forEachPending(func(a address.Address, mset *msgSet) {
|
||||
if inSync {
|
||||
// no need to copy the map
|
||||
result[a] = mset.msgs
|
||||
@ -668,7 +667,7 @@ func (mp *MessagePool) getPendingMessages(curTs, ts *types.TipSet) (map[address.
|
||||
result[a] = msetCopy
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// we are in sync, that's the happy path
|
||||
if inSync {
|
||||
|
@ -120,7 +120,7 @@ func (a *MpoolAPI) MpoolPending(ctx context.Context, tsk types.TipSetKey) ([]*ty
|
||||
}
|
||||
|
||||
func (a *MpoolAPI) MpoolClear(ctx context.Context, local bool) error {
|
||||
a.Mpool.Clear(local)
|
||||
a.Mpool.Clear(ctx, local)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user