refactor: move nonce generation from mpool to wallet

This commit is contained in:
Dirk McCormick 2020-09-23 18:28:11 +02:00
parent efd2dff0ca
commit 3c524ac0e0
5 changed files with 310 additions and 128 deletions

View File

@ -75,8 +75,6 @@ var (
ErrRBFTooLowPremium = errors.New("replace by fee has too low GasPremium") ErrRBFTooLowPremium = errors.New("replace by fee has too low GasPremium")
ErrTooManyPendingMessages = errors.New("too many pending messages for actor") ErrTooManyPendingMessages = errors.New("too many pending messages for actor")
ErrNonceGap = errors.New("unfulfilled nonce gap") ErrNonceGap = errors.New("unfulfilled nonce gap")
ErrTryAgain = errors.New("state inconsistency while pushing message; please try again")
) )
const ( const (
@ -795,98 +793,6 @@ func (mp *MessagePool) getStateBalance(addr address.Address, ts *types.TipSet) (
return act.Balance, nil return act.Balance, nil
} }
func (mp *MessagePool) PushWithNonce(ctx context.Context, addr address.Address, cb func(address.Address, uint64) (*types.SignedMessage, error)) (*types.SignedMessage, error) {
// serialize push access to reduce lock contention
mp.addSema <- struct{}{}
defer func() {
<-mp.addSema
}()
mp.curTsLk.Lock()
mp.lk.Lock()
curTs := mp.curTs
fromKey := addr
if fromKey.Protocol() == address.ID {
var err error
fromKey, err = mp.api.StateAccountKey(ctx, fromKey, mp.curTs)
if err != nil {
mp.lk.Unlock()
mp.curTsLk.Unlock()
return nil, xerrors.Errorf("resolving sender key: %w", err)
}
}
nonce, err := mp.getNonceLocked(fromKey, mp.curTs)
if err != nil {
mp.lk.Unlock()
mp.curTsLk.Unlock()
return nil, xerrors.Errorf("get nonce locked failed: %w", err)
}
// release the locks for signing
mp.lk.Unlock()
mp.curTsLk.Unlock()
msg, err := cb(fromKey, nonce)
if err != nil {
return nil, err
}
err = mp.checkMessage(msg)
if err != nil {
return nil, err
}
msgb, err := msg.Serialize()
if err != nil {
return nil, err
}
// reacquire the locks and check state for consistency
mp.curTsLk.Lock()
defer mp.curTsLk.Unlock()
if mp.curTs != curTs {
return nil, ErrTryAgain
}
mp.lk.Lock()
defer mp.lk.Unlock()
nonce2, err := mp.getNonceLocked(fromKey, mp.curTs)
if err != nil {
return nil, xerrors.Errorf("get nonce locked failed: %w", err)
}
if nonce2 != nonce {
return nil, ErrTryAgain
}
publish, err := mp.verifyMsgBeforeAdd(msg, curTs, true)
if err != nil {
return nil, err
}
if err := mp.checkBalance(msg, curTs); err != nil {
return nil, err
}
if err := mp.addLocked(msg, false); err != nil {
return nil, xerrors.Errorf("add locked failed: %w", err)
}
if err := mp.addLocal(msg, msgb); err != nil {
log.Errorf("addLocal failed: %+v", err)
}
if publish {
err = mp.api.PubSubPublish(build.MessagesTopic(mp.netName), msgb)
}
return msg, err
}
func (mp *MessagePool) Remove(from address.Address, nonce uint64, applied bool) { func (mp *MessagePool) Remove(from address.Address, nonce uint64, applied bool) {
mp.lk.Lock() mp.lk.Lock()
defer mp.lk.Unlock() defer mp.lk.Unlock()

View File

@ -0,0 +1,116 @@
package messagesigner
import (
"bytes"
"context"
"github.com/filecoin-project/lotus/chain/wallet"
"github.com/filecoin-project/lotus/chain/messagepool"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/node/modules/dtypes"
"github.com/ipfs/go-datastore"
"github.com/ipfs/go-datastore/namespace"
cbg "github.com/whyrusleeping/cbor-gen"
"golang.org/x/xerrors"
)
const dsKeyActorNonce = "ActorNonce"
type mpoolAPI interface {
GetNonce(address.Address) (uint64, error)
}
// MessageSigner keeps track of nonces per address, and increments the nonce
// when signing a message
type MessageSigner struct {
wallet *wallet.Wallet
mpool mpoolAPI
ds datastore.Batching
}
func NewMessageSigner(wallet *wallet.Wallet, mpool *messagepool.MessagePool, ds dtypes.MetadataDS) *MessageSigner {
return newMessageSigner(wallet, mpool, ds)
}
func newMessageSigner(wallet *wallet.Wallet, mpool mpoolAPI, ds dtypes.MetadataDS) *MessageSigner {
ds = namespace.Wrap(ds, datastore.NewKey("/message-signer/"))
return &MessageSigner{
wallet: wallet,
mpool: mpool,
ds: ds,
}
}
// SignMessage increments the nonce for the message From address, and signs
// the message
func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) {
nonce, err := ms.nextNonce(msg.From)
if err != nil {
return nil, xerrors.Errorf("failed to create nonce: %w", err)
}
msg.Nonce = nonce
sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes())
if err != nil {
return nil, xerrors.Errorf("failed to sign message: %w", err)
}
return &types.SignedMessage{
Message: *msg,
Signature: *sig,
}, nil
}
// nextNonce increments the nonce.
// If there is no nonce in the datastore, gets the nonce from the message pool.
func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
addrNonceKey := datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()})
// Get the nonce for this address from the datastore
nonceBytes, err := ms.ds.Get(addrNonceKey)
var nonce uint64
switch {
case xerrors.Is(err, datastore.ErrNotFound):
// If a nonce for this address hasn't yet been created in the
// datastore, check the mempool - nonces used to be created by
// the mempool so we need to support nodes that still have mempool
// nonces. Note that the mempool returns the actor state's nonce by
// default.
nonce, err = ms.mpool.GetNonce(addr)
if err != nil {
return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err)
}
case err != nil:
return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err)
default:
// There is a nonce in the mempool, so unmarshall and increment it
maj, val, err := cbg.CborReadHeader(bytes.NewReader(nonceBytes))
if err != nil {
return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err)
}
if maj != cbg.MajUnsignedInt {
return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore")
}
nonce = val + 1
}
// Write the nonce for this address to the datastore
buf := bytes.Buffer{}
_, err = buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce))
if err != nil {
return 0, xerrors.Errorf("failed to marshall nonce: %w", err)
}
err = ms.ds.Put(addrNonceKey, buf.Bytes())
if err != nil {
return 0, xerrors.Errorf("failed to write nonce to datastore: %w", err)
}
return nonce, nil
}

View File

@ -0,0 +1,159 @@
package messagesigner
import (
"context"
"sync"
"testing"
"github.com/filecoin-project/lotus/chain/wallet"
"github.com/filecoin-project/go-state-types/crypto"
"github.com/stretchr/testify/require"
ds_sync "github.com/ipfs/go-datastore/sync"
"github.com/filecoin-project/go-address"
"github.com/filecoin-project/lotus/chain/types"
"github.com/ipfs/go-datastore"
)
type mockMpool struct {
lk sync.RWMutex
nonces map[address.Address]uint64
}
func newMockMpool() *mockMpool {
return &mockMpool{nonces: make(map[address.Address]uint64)}
}
func (mp *mockMpool) setNonce(addr address.Address, nonce uint64) {
mp.lk.Lock()
defer mp.lk.Unlock()
mp.nonces[addr] = nonce
}
func (mp *mockMpool) GetNonce(addr address.Address) (uint64, error) {
mp.lk.RLock()
defer mp.lk.RUnlock()
return mp.nonces[addr], nil
}
func TestMessageSignerSignMessage(t *testing.T) {
ctx := context.Background()
w, _ := wallet.NewWallet(wallet.NewMemKeyStore())
from1, err := w.GenerateKey(crypto.SigTypeSecp256k1)
require.NoError(t, err)
from2, err := w.GenerateKey(crypto.SigTypeSecp256k1)
require.NoError(t, err)
to1, err := w.GenerateKey(crypto.SigTypeSecp256k1)
require.NoError(t, err)
to2, err := w.GenerateKey(crypto.SigTypeSecp256k1)
require.NoError(t, err)
type msgSpec struct {
msg *types.Message
mpoolNonce [1]uint64
expNonce uint64
}
tests := []struct {
name string
msgs []msgSpec
}{{
// No nonce yet in datastore
name: "no nonce yet",
msgs: []msgSpec{{
msg: &types.Message{
To: to1,
From: from1,
},
expNonce: 0,
}},
}, {
// Get nonce value of zero from mpool
name: "mpool nonce zero",
msgs: []msgSpec{{
msg: &types.Message{
To: to1,
From: from1,
},
mpoolNonce: [1]uint64{0},
expNonce: 0,
}},
}, {
// Get non-zero nonce value from mpool
name: "mpool nonce set",
msgs: []msgSpec{{
msg: &types.Message{
To: to1,
From: from1,
},
mpoolNonce: [1]uint64{5},
expNonce: 5,
}, {
msg: &types.Message{
To: to1,
From: from1,
},
// Should ignore mpool nonce because after the first message nonce
// will come from the datastore
mpoolNonce: [1]uint64{10},
expNonce: 6,
}},
}, {
// Nonce should increment independently for each address
name: "nonce increments per address",
msgs: []msgSpec{{
msg: &types.Message{
To: to1,
From: from1,
},
expNonce: 0,
}, {
msg: &types.Message{
To: to1,
From: from1,
},
expNonce: 1,
}, {
msg: &types.Message{
To: to2,
From: from2,
},
mpoolNonce: [1]uint64{5},
expNonce: 5,
}, {
msg: &types.Message{
To: to2,
From: from2,
},
expNonce: 6,
}, {
msg: &types.Message{
To: to1,
From: from1,
},
expNonce: 2,
}},
}}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
mpool := newMockMpool()
ds := ds_sync.MutexWrap(datastore.NewMapDatastore())
ms := newMessageSigner(w, mpool, ds)
for _, m := range tt.msgs {
if len(m.mpoolNonce) == 1 {
mpool.setNonce(m.msg.From, m.mpoolNonce[0])
}
smsg, err := ms.SignMessage(ctx, m.msg)
require.NoError(t, err)
require.Equal(t, m.expNonce, smsg.Message.Nonce)
}
})
}
}

View File

@ -6,6 +6,8 @@ import (
"os" "os"
"time" "time"
"github.com/filecoin-project/lotus/chain/messagesigner"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
ci "github.com/libp2p/go-libp2p-core/crypto" ci "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/host"
@ -259,6 +261,7 @@ func Online() Option {
Override(new(*store.ChainStore), modules.ChainStore), Override(new(*store.ChainStore), modules.ChainStore),
Override(new(*stmgr.StateManager), stmgr.NewStateManager), Override(new(*stmgr.StateManager), stmgr.NewStateManager),
Override(new(*wallet.Wallet), wallet.NewWallet), Override(new(*wallet.Wallet), wallet.NewWallet),
Override(new(*messagesigner.MessageSigner), messagesigner.NewMessageSigner),
Override(new(dtypes.ChainGCLocker), blockstore.NewGCLocker), Override(new(dtypes.ChainGCLocker), blockstore.NewGCLocker),
Override(new(dtypes.ChainGCBlockstore), modules.ChainGCBlockstore), Override(new(dtypes.ChainGCBlockstore), modules.ChainGCBlockstore),

View File

@ -4,14 +4,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/filecoin-project/lotus/chain/messagesigner"
"github.com/filecoin-project/go-address" "github.com/filecoin-project/go-address"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
"go.uber.org/fx" "go.uber.org/fx"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/filecoin-project/lotus/api" "github.com/filecoin-project/lotus/api"
"github.com/filecoin-project/lotus/chain/messagepool"
"github.com/filecoin-project/lotus/chain/store"
"github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/node/modules/dtypes" "github.com/filecoin-project/lotus/node/modules/dtypes"
) )
@ -22,9 +22,7 @@ type MpoolAPI struct {
WalletAPI WalletAPI
GasAPI GasAPI
Chain *store.ChainStore MessageSigner *messagesigner.MessageSigner
Mpool *messagepool.MessagePool
PushLocks *dtypes.MpoolLocker PushLocks *dtypes.MpoolLocker
} }
@ -114,12 +112,14 @@ func (a *MpoolAPI) MpoolPush(ctx context.Context, smsg *types.SignedMessage) (ci
} }
func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec) (*types.SignedMessage, error) { func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec) (*types.SignedMessage, error) {
cp := *msg
msg = &cp
inMsg := *msg inMsg := *msg
{
fromA, err := a.Stmgr.ResolveToKeyAddress(ctx, msg.From, nil) fromA, err := a.Stmgr.ResolveToKeyAddress(ctx, msg.From, nil)
if err != nil { if err != nil {
return nil, xerrors.Errorf("getting key address: %w", err) return nil, xerrors.Errorf("getting key address: %w", err)
} }
{
done, err := a.PushLocks.TakeLock(ctx, fromA) done, err := a.PushLocks.TakeLock(ctx, fromA)
if err != nil { if err != nil {
return nil, xerrors.Errorf("taking lock: %w", err) return nil, xerrors.Errorf("taking lock: %w", err)
@ -131,7 +131,7 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
return nil, xerrors.Errorf("MpoolPushMessage expects message nonce to be 0, was %d", msg.Nonce) return nil, xerrors.Errorf("MpoolPushMessage expects message nonce to be 0, was %d", msg.Nonce)
} }
msg, err := a.GasAPI.GasEstimateMessageGas(ctx, msg, spec, types.EmptyTSK) msg, err = a.GasAPI.GasEstimateMessageGas(ctx, msg, spec, types.EmptyTSK)
if err != nil { if err != nil {
return nil, xerrors.Errorf("GasEstimateMessageGas error: %w", err) return nil, xerrors.Errorf("GasEstimateMessageGas error: %w", err)
} }
@ -143,11 +143,9 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
inJson, outJson) inJson, outJson)
} }
sign := func(from address.Address, nonce uint64) (*types.SignedMessage, error) {
msg.Nonce = nonce
if msg.From.Protocol() == address.ID { if msg.From.Protocol() == address.ID {
log.Warnf("Push from ID address (%s), adjusting to %s", msg.From, from) log.Warnf("Push from ID address (%s), adjusting to %s", msg.From, fromA)
msg.From = from msg.From = fromA
} }
b, err := a.WalletBalance(ctx, msg.From) b, err := a.WalletBalance(ctx, msg.From)
@ -159,17 +157,17 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value) return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value)
} }
return a.WalletSignMessage(ctx, from, msg) smsg, err := a.MessageSigner.SignMessage(ctx, msg)
if err != nil {
return nil, xerrors.Errorf("mpool push: failed to sign message: %w", err)
} }
var m *types.SignedMessage _, err = a.Mpool.Push(smsg)
again: if err != nil {
m, err = a.Mpool.PushWithNonce(ctx, msg.From, sign) return nil, xerrors.Errorf("mpool push: failed to push message: %w", err)
if err == messagepool.ErrTryAgain {
log.Debugf("temporary failure while pushing message: %s; retrying", err)
goto again
} }
return m, err
return smsg, err
} }
func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) { func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {