Merge pull request #3970 from filecoin-project/feat/mpool-to-wallet
refactor: move nonce generation out of mpool
This commit is contained in:
commit
1479142047
@ -75,8 +75,6 @@ var (
|
||||
ErrRBFTooLowPremium = errors.New("replace by fee has too low GasPremium")
|
||||
ErrTooManyPendingMessages = errors.New("too many pending messages for actor")
|
||||
ErrNonceGap = errors.New("unfulfilled nonce gap")
|
||||
|
||||
ErrTryAgain = errors.New("state inconsistency while pushing message; please try again")
|
||||
)
|
||||
|
||||
const (
|
||||
@ -795,98 +793,6 @@ func (mp *MessagePool) getStateBalance(addr address.Address, ts *types.TipSet) (
|
||||
return act.Balance, nil
|
||||
}
|
||||
|
||||
func (mp *MessagePool) PushWithNonce(ctx context.Context, addr address.Address, cb func(address.Address, uint64) (*types.SignedMessage, error)) (*types.SignedMessage, error) {
|
||||
// serialize push access to reduce lock contention
|
||||
mp.addSema <- struct{}{}
|
||||
defer func() {
|
||||
<-mp.addSema
|
||||
}()
|
||||
|
||||
mp.curTsLk.Lock()
|
||||
mp.lk.Lock()
|
||||
|
||||
curTs := mp.curTs
|
||||
|
||||
fromKey := addr
|
||||
if fromKey.Protocol() == address.ID {
|
||||
var err error
|
||||
fromKey, err = mp.api.StateAccountKey(ctx, fromKey, mp.curTs)
|
||||
if err != nil {
|
||||
mp.lk.Unlock()
|
||||
mp.curTsLk.Unlock()
|
||||
return nil, xerrors.Errorf("resolving sender key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
nonce, err := mp.getNonceLocked(fromKey, mp.curTs)
|
||||
if err != nil {
|
||||
mp.lk.Unlock()
|
||||
mp.curTsLk.Unlock()
|
||||
return nil, xerrors.Errorf("get nonce locked failed: %w", err)
|
||||
}
|
||||
|
||||
// release the locks for signing
|
||||
mp.lk.Unlock()
|
||||
mp.curTsLk.Unlock()
|
||||
|
||||
msg, err := cb(fromKey, nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = mp.checkMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgb, err := msg.Serialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// reacquire the locks and check state for consistency
|
||||
mp.curTsLk.Lock()
|
||||
defer mp.curTsLk.Unlock()
|
||||
|
||||
if mp.curTs != curTs {
|
||||
return nil, ErrTryAgain
|
||||
}
|
||||
|
||||
mp.lk.Lock()
|
||||
defer mp.lk.Unlock()
|
||||
|
||||
nonce2, err := mp.getNonceLocked(fromKey, mp.curTs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get nonce locked failed: %w", err)
|
||||
}
|
||||
|
||||
if nonce2 != nonce {
|
||||
return nil, ErrTryAgain
|
||||
}
|
||||
|
||||
publish, err := mp.verifyMsgBeforeAdd(msg, curTs, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mp.checkBalance(msg, curTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mp.addLocked(msg, false); err != nil {
|
||||
return nil, xerrors.Errorf("add locked failed: %w", err)
|
||||
}
|
||||
if err := mp.addLocal(msg, msgb); err != nil {
|
||||
log.Errorf("addLocal failed: %+v", err)
|
||||
}
|
||||
|
||||
if publish {
|
||||
err = mp.api.PubSubPublish(build.MessagesTopic(mp.netName), msgb)
|
||||
}
|
||||
|
||||
return msg, err
|
||||
}
|
||||
|
||||
func (mp *MessagePool) Remove(from address.Address, nonce uint64, applied bool) {
|
||||
mp.lk.Lock()
|
||||
defer mp.lk.Unlock()
|
||||
|
124
chain/messagesigner/messagesigner.go
Normal file
124
chain/messagesigner/messagesigner.go
Normal file
@ -0,0 +1,124 @@
|
||||
package messagesigner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
"github.com/filecoin-project/lotus/chain/messagepool"
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
"github.com/filecoin-project/lotus/chain/wallet"
|
||||
"github.com/filecoin-project/lotus/node/modules/dtypes"
|
||||
"github.com/ipfs/go-datastore"
|
||||
"github.com/ipfs/go-datastore/namespace"
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
cbg "github.com/whyrusleeping/cbor-gen"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const dsKeyActorNonce = "ActorNonce"
|
||||
|
||||
var log = logging.Logger("messagesigner")
|
||||
|
||||
type mpoolAPI interface {
|
||||
GetNonce(address.Address) (uint64, error)
|
||||
}
|
||||
|
||||
// MessageSigner keeps track of nonces per address, and increments the nonce
|
||||
// when signing a message
|
||||
type MessageSigner struct {
|
||||
wallet *wallet.Wallet
|
||||
mpool mpoolAPI
|
||||
ds datastore.Batching
|
||||
}
|
||||
|
||||
func NewMessageSigner(wallet *wallet.Wallet, mpool *messagepool.MessagePool, ds dtypes.MetadataDS) *MessageSigner {
|
||||
return newMessageSigner(wallet, mpool, ds)
|
||||
}
|
||||
|
||||
func newMessageSigner(wallet *wallet.Wallet, mpool mpoolAPI, ds dtypes.MetadataDS) *MessageSigner {
|
||||
ds = namespace.Wrap(ds, datastore.NewKey("/message-signer/"))
|
||||
return &MessageSigner{
|
||||
wallet: wallet,
|
||||
mpool: mpool,
|
||||
ds: ds,
|
||||
}
|
||||
}
|
||||
|
||||
// SignMessage increments the nonce for the message From address, and signs
|
||||
// the message
|
||||
func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) {
|
||||
nonce, err := ms.nextNonce(msg.From)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create nonce: %w", err)
|
||||
}
|
||||
|
||||
msg.Nonce = nonce
|
||||
sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to sign message: %w", err)
|
||||
}
|
||||
|
||||
return &types.SignedMessage{
|
||||
Message: *msg,
|
||||
Signature: *sig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// nextNonce increments the nonce.
|
||||
// If there is no nonce in the datastore, gets the nonce from the message pool.
|
||||
func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
|
||||
// Nonces used to be created by the mempool and we need to support nodes
|
||||
// that have mempool nonces, so first check the mempool for a nonce for
|
||||
// this address. Note that the mempool returns the actor state's nonce
|
||||
// by default.
|
||||
nonce, err := ms.mpool.GetNonce(addr)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err)
|
||||
}
|
||||
|
||||
// Get the nonce for this address from the datastore
|
||||
addrNonceKey := datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()})
|
||||
dsNonceBytes, err := ms.ds.Get(addrNonceKey)
|
||||
|
||||
switch {
|
||||
case xerrors.Is(err, datastore.ErrNotFound):
|
||||
// If a nonce for this address hasn't yet been created in the
|
||||
// datastore, just use the nonce from the mempool
|
||||
|
||||
case err != nil:
|
||||
return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err)
|
||||
|
||||
default:
|
||||
// There is a nonce in the datastore, so unmarshall and increment it
|
||||
maj, val, err := cbg.CborReadHeader(bytes.NewReader(dsNonceBytes))
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err)
|
||||
}
|
||||
if maj != cbg.MajUnsignedInt {
|
||||
return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore")
|
||||
}
|
||||
|
||||
dsNonce := val + 1
|
||||
|
||||
// The message pool nonce should be <= than the datastore nonce
|
||||
if nonce <= dsNonce {
|
||||
nonce = dsNonce
|
||||
} else {
|
||||
log.Warnf("mempool nonce was larger than datastore nonce (%d > %d)", nonce, dsNonce)
|
||||
}
|
||||
}
|
||||
|
||||
// Write the nonce for this address to the datastore
|
||||
buf := bytes.Buffer{}
|
||||
_, err = buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce))
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("failed to marshall nonce: %w", err)
|
||||
}
|
||||
err = ms.ds.Put(addrNonceKey, buf.Bytes())
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("failed to write nonce to datastore: %w", err)
|
||||
}
|
||||
|
||||
return nonce, nil
|
||||
}
|
158
chain/messagesigner/messagesigner_test.go
Normal file
158
chain/messagesigner/messagesigner_test.go
Normal file
@ -0,0 +1,158 @@
|
||||
package messagesigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/filecoin-project/lotus/chain/wallet"
|
||||
|
||||
"github.com/filecoin-project/go-state-types/crypto"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
ds_sync "github.com/ipfs/go-datastore/sync"
|
||||
|
||||
"github.com/filecoin-project/go-address"
|
||||
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
"github.com/ipfs/go-datastore"
|
||||
)
|
||||
|
||||
type mockMpool struct {
|
||||
lk sync.RWMutex
|
||||
nonces map[address.Address]uint64
|
||||
}
|
||||
|
||||
func newMockMpool() *mockMpool {
|
||||
return &mockMpool{nonces: make(map[address.Address]uint64)}
|
||||
}
|
||||
|
||||
func (mp *mockMpool) setNonce(addr address.Address, nonce uint64) {
|
||||
mp.lk.Lock()
|
||||
defer mp.lk.Unlock()
|
||||
|
||||
mp.nonces[addr] = nonce
|
||||
}
|
||||
|
||||
func (mp *mockMpool) GetNonce(addr address.Address) (uint64, error) {
|
||||
mp.lk.RLock()
|
||||
defer mp.lk.RUnlock()
|
||||
|
||||
return mp.nonces[addr], nil
|
||||
}
|
||||
|
||||
func TestMessageSignerSignMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
w, _ := wallet.NewWallet(wallet.NewMemKeyStore())
|
||||
from1, err := w.GenerateKey(crypto.SigTypeSecp256k1)
|
||||
require.NoError(t, err)
|
||||
from2, err := w.GenerateKey(crypto.SigTypeSecp256k1)
|
||||
require.NoError(t, err)
|
||||
to1, err := w.GenerateKey(crypto.SigTypeSecp256k1)
|
||||
require.NoError(t, err)
|
||||
to2, err := w.GenerateKey(crypto.SigTypeSecp256k1)
|
||||
require.NoError(t, err)
|
||||
|
||||
type msgSpec struct {
|
||||
msg *types.Message
|
||||
mpoolNonce [1]uint64
|
||||
expNonce uint64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []msgSpec
|
||||
}{{
|
||||
// No nonce yet in datastore
|
||||
name: "no nonce yet",
|
||||
msgs: []msgSpec{{
|
||||
msg: &types.Message{
|
||||
To: to1,
|
||||
From: from1,
|
||||
},
|
||||
expNonce: 0,
|
||||
}},
|
||||
}, {
|
||||
// Get nonce value of zero from mpool
|
||||
name: "mpool nonce zero",
|
||||
msgs: []msgSpec{{
|
||||
msg: &types.Message{
|
||||
To: to1,
|
||||
From: from1,
|
||||
},
|
||||
mpoolNonce: [1]uint64{0},
|
||||
expNonce: 0,
|
||||
}},
|
||||
}, {
|
||||
// Get non-zero nonce value from mpool
|
||||
name: "mpool nonce set",
|
||||
msgs: []msgSpec{{
|
||||
msg: &types.Message{
|
||||
To: to1,
|
||||
From: from1,
|
||||
},
|
||||
mpoolNonce: [1]uint64{5},
|
||||
expNonce: 5,
|
||||
}, {
|
||||
msg: &types.Message{
|
||||
To: to1,
|
||||
From: from1,
|
||||
},
|
||||
// Should adjust datastore nonce because mpool nonce is higher
|
||||
mpoolNonce: [1]uint64{10},
|
||||
expNonce: 10,
|
||||
}},
|
||||
}, {
|
||||
// Nonce should increment independently for each address
|
||||
name: "nonce increments per address",
|
||||
msgs: []msgSpec{{
|
||||
msg: &types.Message{
|
||||
To: to1,
|
||||
From: from1,
|
||||
},
|
||||
expNonce: 0,
|
||||
}, {
|
||||
msg: &types.Message{
|
||||
To: to1,
|
||||
From: from1,
|
||||
},
|
||||
expNonce: 1,
|
||||
}, {
|
||||
msg: &types.Message{
|
||||
To: to2,
|
||||
From: from2,
|
||||
},
|
||||
mpoolNonce: [1]uint64{5},
|
||||
expNonce: 5,
|
||||
}, {
|
||||
msg: &types.Message{
|
||||
To: to2,
|
||||
From: from2,
|
||||
},
|
||||
expNonce: 6,
|
||||
}, {
|
||||
msg: &types.Message{
|
||||
To: to1,
|
||||
From: from1,
|
||||
},
|
||||
expNonce: 2,
|
||||
}},
|
||||
}}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mpool := newMockMpool()
|
||||
ds := ds_sync.MutexWrap(datastore.NewMapDatastore())
|
||||
ms := newMessageSigner(w, mpool, ds)
|
||||
|
||||
for _, m := range tt.msgs {
|
||||
if len(m.mpoolNonce) == 1 {
|
||||
mpool.setNonce(m.msg.From, m.mpoolNonce[0])
|
||||
}
|
||||
smsg, err := ms.SignMessage(ctx, m.msg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.expNonce, smsg.Message.Nonce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -35,6 +35,7 @@ import (
|
||||
"github.com/filecoin-project/lotus/chain/gen/slashfilter"
|
||||
"github.com/filecoin-project/lotus/chain/market"
|
||||
"github.com/filecoin-project/lotus/chain/messagepool"
|
||||
"github.com/filecoin-project/lotus/chain/messagesigner"
|
||||
"github.com/filecoin-project/lotus/chain/metrics"
|
||||
"github.com/filecoin-project/lotus/chain/stmgr"
|
||||
"github.com/filecoin-project/lotus/chain/store"
|
||||
@ -259,6 +260,7 @@ func Online() Option {
|
||||
Override(new(*store.ChainStore), modules.ChainStore),
|
||||
Override(new(*stmgr.StateManager), stmgr.NewStateManager),
|
||||
Override(new(*wallet.Wallet), wallet.NewWallet),
|
||||
Override(new(*messagesigner.MessageSigner), messagesigner.NewMessageSigner),
|
||||
|
||||
Override(new(dtypes.ChainGCLocker), blockstore.NewGCLocker),
|
||||
Override(new(dtypes.ChainGCBlockstore), modules.ChainGCBlockstore),
|
||||
|
@ -10,8 +10,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/filecoin-project/lotus/api"
|
||||
"github.com/filecoin-project/lotus/chain/messagepool"
|
||||
"github.com/filecoin-project/lotus/chain/store"
|
||||
"github.com/filecoin-project/lotus/chain/messagesigner"
|
||||
"github.com/filecoin-project/lotus/chain/types"
|
||||
"github.com/filecoin-project/lotus/node/modules/dtypes"
|
||||
)
|
||||
@ -22,9 +21,7 @@ type MpoolAPI struct {
|
||||
WalletAPI
|
||||
GasAPI
|
||||
|
||||
Chain *store.ChainStore
|
||||
|
||||
Mpool *messagepool.MessagePool
|
||||
MessageSigner *messagesigner.MessageSigner
|
||||
|
||||
PushLocks *dtypes.MpoolLocker
|
||||
}
|
||||
@ -114,12 +111,14 @@ func (a *MpoolAPI) MpoolPush(ctx context.Context, smsg *types.SignedMessage) (ci
|
||||
}
|
||||
|
||||
func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec) (*types.SignedMessage, error) {
|
||||
cp := *msg
|
||||
msg = &cp
|
||||
inMsg := *msg
|
||||
fromA, err := a.Stmgr.ResolveToKeyAddress(ctx, msg.From, nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("getting key address: %w", err)
|
||||
}
|
||||
{
|
||||
fromA, err := a.Stmgr.ResolveToKeyAddress(ctx, msg.From, nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("getting key address: %w", err)
|
||||
}
|
||||
done, err := a.PushLocks.TakeLock(ctx, fromA)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("taking lock: %w", err)
|
||||
@ -131,7 +130,7 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
|
||||
return nil, xerrors.Errorf("MpoolPushMessage expects message nonce to be 0, was %d", msg.Nonce)
|
||||
}
|
||||
|
||||
msg, err := a.GasAPI.GasEstimateMessageGas(ctx, msg, spec, types.EmptyTSK)
|
||||
msg, err = a.GasAPI.GasEstimateMessageGas(ctx, msg, spec, types.EmptyTSK)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("GasEstimateMessageGas error: %w", err)
|
||||
}
|
||||
@ -143,33 +142,31 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
|
||||
inJson, outJson)
|
||||
}
|
||||
|
||||
sign := func(from address.Address, nonce uint64) (*types.SignedMessage, error) {
|
||||
msg.Nonce = nonce
|
||||
if msg.From.Protocol() == address.ID {
|
||||
log.Warnf("Push from ID address (%s), adjusting to %s", msg.From, from)
|
||||
msg.From = from
|
||||
}
|
||||
|
||||
b, err := a.WalletBalance(ctx, msg.From)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("mpool push: getting origin balance: %w", err)
|
||||
}
|
||||
|
||||
if b.LessThan(msg.Value) {
|
||||
return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value)
|
||||
}
|
||||
|
||||
return a.WalletSignMessage(ctx, from, msg)
|
||||
if msg.From.Protocol() == address.ID {
|
||||
log.Warnf("Push from ID address (%s), adjusting to %s", msg.From, fromA)
|
||||
msg.From = fromA
|
||||
}
|
||||
|
||||
var m *types.SignedMessage
|
||||
again:
|
||||
m, err = a.Mpool.PushWithNonce(ctx, msg.From, sign)
|
||||
if err == messagepool.ErrTryAgain {
|
||||
log.Debugf("temporary failure while pushing message: %s; retrying", err)
|
||||
goto again
|
||||
b, err := a.WalletBalance(ctx, msg.From)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("mpool push: getting origin balance: %w", err)
|
||||
}
|
||||
return m, err
|
||||
|
||||
if b.LessThan(msg.Value) {
|
||||
return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value)
|
||||
}
|
||||
|
||||
smsg, err := a.MessageSigner.SignMessage(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("mpool push: failed to sign message: %w", err)
|
||||
}
|
||||
|
||||
_, err = a.Mpool.Push(smsg)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("mpool push: failed to push message: %w", err)
|
||||
}
|
||||
|
||||
return smsg, err
|
||||
}
|
||||
|
||||
func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user