mpool: Fix racy nonce logic

This commit is contained in:
Łukasz Magiera 2019-09-20 11:01:49 +02:00
parent 7073676983
commit 113f6f4791
2 changed files with 45 additions and 18 deletions

View File

@ -2,13 +2,15 @@ package chain
import ( import (
"encoding/base64" "encoding/base64"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"sync" "sync"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/pkg/errors"
"golang.org/x/xerrors"
"github.com/filecoin-project/go-lotus/chain/address" "github.com/filecoin-project/go-lotus/chain/address"
"github.com/filecoin-project/go-lotus/chain/stmgr" "github.com/filecoin-project/go-lotus/chain/stmgr"
"github.com/filecoin-project/go-lotus/chain/types" "github.com/filecoin-project/go-lotus/chain/types"
"github.com/pkg/errors"
) )
type MessagePool struct { type MessagePool struct {
@ -22,8 +24,8 @@ type MessagePool struct {
} }
type msgSet struct { type msgSet struct {
msgs map[uint64]*types.SignedMessage msgs map[uint64]*types.SignedMessage
startNonce uint64 nextNonce uint64
} }
func newMsgSet() *msgSet { func newMsgSet() *msgSet {
@ -32,11 +34,20 @@ func newMsgSet() *msgSet {
} }
} }
func (ms *msgSet) add(m *types.SignedMessage) { func (ms *msgSet) add(m *types.SignedMessage) error {
if len(ms.msgs) == 0 || m.Message.Nonce < ms.startNonce { if len(ms.msgs) == 0 || m.Message.Nonce >= ms.nextNonce {
ms.startNonce = m.Message.Nonce ms.nextNonce = m.Message.Nonce + 1
}
if _, has := ms.msgs[m.Message.Nonce]; has {
if m.Cid() != ms.msgs[m.Message.Nonce].Cid() {
log.Error("Add with duplicate nonce")
return xerrors.Errorf("message to %s with nonce %d already in mpool")
}
log.Warn("Add called with the same message multiple times")
} }
ms.msgs[m.Message.Nonce] = m ms.msgs[m.Message.Nonce] = m
return nil
} }
func NewMessagePool(sm *stmgr.StateManager, ps *pubsub.PubSub) *MessagePool { func NewMessagePool(sm *stmgr.StateManager, ps *pubsub.PubSub) *MessagePool {
@ -76,13 +87,15 @@ func (mp *MessagePool) addLocked(m *types.SignedMessage) error {
return err return err
} }
log.Info("mpooladd: %s", base64.StdEncoding.EncodeToString(data)) log.Infof("mpooladd: %d %s", m.Message.Nonce, base64.StdEncoding.EncodeToString(data))
if err := m.Signature.Verify(m.Message.From, data); err != nil { if err := m.Signature.Verify(m.Message.From, data); err != nil {
log.Warnf("mpooladd signature verification failed: %s", err)
return err return err
} }
if _, err := mp.sm.ChainStore().PutMessage(m); err != nil { if _, err := mp.sm.ChainStore().PutMessage(m); err != nil {
log.Warnf("mpooladd cs.PutMessage failed: %s", err)
return err return err
} }
@ -106,7 +119,7 @@ func (mp *MessagePool) GetNonce(addr address.Address) (uint64, error) {
func (mp *MessagePool) getNonceLocked(addr address.Address) (uint64, error) { func (mp *MessagePool) getNonceLocked(addr address.Address) (uint64, error) {
mset, ok := mp.pending[addr] mset, ok := mp.pending[addr]
if ok { if ok {
return mset.startNonce + uint64(len(mset.msgs)), nil return mset.nextNonce, nil
} }
act, err := mp.sm.GetActor(addr) act, err := mp.sm.GetActor(addr)
@ -157,7 +170,16 @@ func (mp *MessagePool) Remove(from address.Address, nonce uint64) {
delete(mset.msgs, nonce) delete(mset.msgs, nonce)
if len(mset.msgs) == 0 { if len(mset.msgs) == 0 {
delete(mp.pending, from) // FIXME: This is racy
//delete(mp.pending, from)
} else {
var max uint64
for nonce := range mset.msgs {
if max < nonce {
max = nonce
}
}
mset.nextNonce = max + 1
} }
} }
@ -166,13 +188,18 @@ func (mp *MessagePool) Pending() []*types.SignedMessage {
defer mp.lk.Unlock() defer mp.lk.Unlock()
out := make([]*types.SignedMessage, 0) out := make([]*types.SignedMessage, 0)
for _, mset := range mp.pending { for _, mset := range mp.pending {
for i := mset.startNonce; true; i++ { if len(mset.msgs) == 0 {
m, ok := mset.msgs[i] continue
if !ok {
break
}
out = append(out, m)
} }
set := make([]*types.SignedMessage, len(mset.msgs))
var i uint64
for i = mset.nextNonce - 1; mset.msgs[i] != nil; i-- {
set[len(mset.msgs)-int(mset.nextNonce-i)] = mset.msgs[i]
}
out = append(out, set[len(mset.msgs)-int(mset.nextNonce-i-1):]...)
} }
return out return out

View File

@ -25,8 +25,8 @@ class ConnMgr extends React.Component {
const nodes = this.props.nodes const nodes = this.props.nodes
let keys = Object.keys(nodes) let keys = Object.keys(nodes)
const newConns = await keys.filter((_, i) => i > 0).map(async (kfrom, i) => { const newConns = await keys.filter((_, i) => i > 0).filter(kfrom => this.props.nodes[kfrom].conn !== undefined).map(async (kfrom, i) => {
return keys.filter((_, j) => i >= j).map(async kto => { return keys.filter((_, j) => i >= j).filter(kto => this.props.nodes[kto].conn !== undefined).map(async kto => {
const fromNd = this.props.nodes[kfrom] const fromNd = this.props.nodes[kfrom]
const toNd = this.props.nodes[kto] const toNd = this.props.nodes[kto]