From 3c524ac0e0d806468e1e3a28437e94115089d477 Mon Sep 17 00:00:00 2001 From: Dirk McCormick Date: Wed, 23 Sep 2020 18:28:11 +0200 Subject: [PATCH] refactor: move nonce generation from mpool to wallet --- chain/messagepool/messagepool.go | 94 ------------- chain/messagesigner/messagesigner.go | 116 ++++++++++++++++ chain/messagesigner/messagesigner_test.go | 159 ++++++++++++++++++++++ node/builder.go | 3 + node/impl/full/mpool.go | 66 +++++---- 5 files changed, 310 insertions(+), 128 deletions(-) create mode 100644 chain/messagesigner/messagesigner.go create mode 100644 chain/messagesigner/messagesigner_test.go diff --git a/chain/messagepool/messagepool.go b/chain/messagepool/messagepool.go index 96900925f..d54ea7164 100644 --- a/chain/messagepool/messagepool.go +++ b/chain/messagepool/messagepool.go @@ -75,8 +75,6 @@ var ( ErrRBFTooLowPremium = errors.New("replace by fee has too low GasPremium") ErrTooManyPendingMessages = errors.New("too many pending messages for actor") ErrNonceGap = errors.New("unfulfilled nonce gap") - - ErrTryAgain = errors.New("state inconsistency while pushing message; please try again") ) const ( @@ -795,98 +793,6 @@ func (mp *MessagePool) getStateBalance(addr address.Address, ts *types.TipSet) ( return act.Balance, nil } -func (mp *MessagePool) PushWithNonce(ctx context.Context, addr address.Address, cb func(address.Address, uint64) (*types.SignedMessage, error)) (*types.SignedMessage, error) { - // serialize push access to reduce lock contention - mp.addSema <- struct{}{} - defer func() { - <-mp.addSema - }() - - mp.curTsLk.Lock() - mp.lk.Lock() - - curTs := mp.curTs - - fromKey := addr - if fromKey.Protocol() == address.ID { - var err error - fromKey, err = mp.api.StateAccountKey(ctx, fromKey, mp.curTs) - if err != nil { - mp.lk.Unlock() - mp.curTsLk.Unlock() - return nil, xerrors.Errorf("resolving sender key: %w", err) - } - } - - nonce, err := mp.getNonceLocked(fromKey, mp.curTs) - if err != nil { - mp.lk.Unlock() - mp.curTsLk.Unlock() - return nil, xerrors.Errorf("get nonce locked failed: %w", err) - } - - // release the locks for signing - mp.lk.Unlock() - mp.curTsLk.Unlock() - - msg, err := cb(fromKey, nonce) - if err != nil { - return nil, err - } - - err = mp.checkMessage(msg) - if err != nil { - return nil, err - } - - msgb, err := msg.Serialize() - if err != nil { - return nil, err - } - - // reacquire the locks and check state for consistency - mp.curTsLk.Lock() - defer mp.curTsLk.Unlock() - - if mp.curTs != curTs { - return nil, ErrTryAgain - } - - mp.lk.Lock() - defer mp.lk.Unlock() - - nonce2, err := mp.getNonceLocked(fromKey, mp.curTs) - if err != nil { - return nil, xerrors.Errorf("get nonce locked failed: %w", err) - } - - if nonce2 != nonce { - return nil, ErrTryAgain - } - - publish, err := mp.verifyMsgBeforeAdd(msg, curTs, true) - if err != nil { - return nil, err - } - - if err := mp.checkBalance(msg, curTs); err != nil { - return nil, err - } - - if err := mp.addLocked(msg, false); err != nil { - return nil, xerrors.Errorf("add locked failed: %w", err) - } - if err := mp.addLocal(msg, msgb); err != nil { - log.Errorf("addLocal failed: %+v", err) - } - - if publish { - err = mp.api.PubSubPublish(build.MessagesTopic(mp.netName), msgb) - } - - return msg, err -} - func (mp *MessagePool) Remove(from address.Address, nonce uint64, applied bool) { mp.lk.Lock() defer mp.lk.Unlock() diff --git a/chain/messagesigner/messagesigner.go b/chain/messagesigner/messagesigner.go new file mode 100644 index 000000000..41b0edee9 --- /dev/null +++ b/chain/messagesigner/messagesigner.go @@ -0,0 +1,116 @@ +package messagesigner + +import ( + "bytes" + "context" + + "github.com/filecoin-project/lotus/chain/wallet" + + "github.com/filecoin-project/lotus/chain/messagepool" + + "github.com/filecoin-project/go-address" + "github.com/filecoin-project/lotus/chain/types" + "github.com/filecoin-project/lotus/node/modules/dtypes" + "github.com/ipfs/go-datastore" + "github.com/ipfs/go-datastore/namespace" + cbg "github.com/whyrusleeping/cbor-gen" + "golang.org/x/xerrors" +) + +const dsKeyActorNonce = "ActorNonce" + +type mpoolAPI interface { + GetNonce(address.Address) (uint64, error) +} + +// MessageSigner keeps track of nonces per address, and increments the nonce +// when signing a message +type MessageSigner struct { + wallet *wallet.Wallet + mpool mpoolAPI + ds datastore.Batching +} + +func NewMessageSigner(wallet *wallet.Wallet, mpool *messagepool.MessagePool, ds dtypes.MetadataDS) *MessageSigner { + return newMessageSigner(wallet, mpool, ds) +} + +func newMessageSigner(wallet *wallet.Wallet, mpool mpoolAPI, ds dtypes.MetadataDS) *MessageSigner { + ds = namespace.Wrap(ds, datastore.NewKey("/message-signer/")) + return &MessageSigner{ + wallet: wallet, + mpool: mpool, + ds: ds, + } +} + +// SignMessage increments the nonce for the message From address, and signs +// the message +func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) { + nonce, err := ms.nextNonce(msg.From) + if err != nil { + return nil, xerrors.Errorf("failed to create nonce: %w", err) + } + + msg.Nonce = nonce + sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes()) + if err != nil { + return nil, xerrors.Errorf("failed to sign message: %w", err) + } + + return &types.SignedMessage{ + Message: *msg, + Signature: *sig, + }, nil +} + +// nextNonce increments the nonce. +// If there is no nonce in the datastore, gets the nonce from the message pool. +func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { + addrNonceKey := datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()}) + + // Get the nonce for this address from the datastore + nonceBytes, err := ms.ds.Get(addrNonceKey) + + var nonce uint64 + switch { + case xerrors.Is(err, datastore.ErrNotFound): + // If a nonce for this address hasn't yet been created in the + // datastore, check the mempool - nonces used to be created by + // the mempool so we need to support nodes that still have mempool + // nonces. Note that the mempool returns the actor state's nonce by + // default. + nonce, err = ms.mpool.GetNonce(addr) + if err != nil { + return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err) + } + + case err != nil: + return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err) + + default: + // There is a nonce in the mempool, so unmarshall and increment it + maj, val, err := cbg.CborReadHeader(bytes.NewReader(nonceBytes)) + if err != nil { + return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err) + } + if maj != cbg.MajUnsignedInt { + return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore") + } + + nonce = val + 1 + } + + // Write the nonce for this address to the datastore + buf := bytes.Buffer{} + _, err = buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce)) + if err != nil { + return 0, xerrors.Errorf("failed to marshall nonce: %w", err) + } + err = ms.ds.Put(addrNonceKey, buf.Bytes()) + if err != nil { + return 0, xerrors.Errorf("failed to write nonce to datastore: %w", err) + } + + return nonce, nil +} diff --git a/chain/messagesigner/messagesigner_test.go b/chain/messagesigner/messagesigner_test.go new file mode 100644 index 000000000..e52137892 --- /dev/null +++ b/chain/messagesigner/messagesigner_test.go @@ -0,0 +1,159 @@ +package messagesigner + +import ( + "context" + "sync" + "testing" + + "github.com/filecoin-project/lotus/chain/wallet" + + "github.com/filecoin-project/go-state-types/crypto" + "github.com/stretchr/testify/require" + + ds_sync "github.com/ipfs/go-datastore/sync" + + "github.com/filecoin-project/go-address" + + "github.com/filecoin-project/lotus/chain/types" + "github.com/ipfs/go-datastore" +) + +type mockMpool struct { + lk sync.RWMutex + nonces map[address.Address]uint64 +} + +func newMockMpool() *mockMpool { + return &mockMpool{nonces: make(map[address.Address]uint64)} +} + +func (mp *mockMpool) setNonce(addr address.Address, nonce uint64) { + mp.lk.Lock() + defer mp.lk.Unlock() + + mp.nonces[addr] = nonce +} + +func (mp *mockMpool) GetNonce(addr address.Address) (uint64, error) { + mp.lk.RLock() + defer mp.lk.RUnlock() + + return mp.nonces[addr], nil +} + +func TestMessageSignerSignMessage(t *testing.T) { + ctx := context.Background() + + w, _ := wallet.NewWallet(wallet.NewMemKeyStore()) + from1, err := w.GenerateKey(crypto.SigTypeSecp256k1) + require.NoError(t, err) + from2, err := w.GenerateKey(crypto.SigTypeSecp256k1) + require.NoError(t, err) + to1, err := w.GenerateKey(crypto.SigTypeSecp256k1) + require.NoError(t, err) + to2, err := w.GenerateKey(crypto.SigTypeSecp256k1) + require.NoError(t, err) + + type msgSpec struct { + msg *types.Message + mpoolNonce [1]uint64 + expNonce uint64 + } + tests := []struct { + name string + msgs []msgSpec + }{{ + // No nonce yet in datastore + name: "no nonce yet", + msgs: []msgSpec{{ + msg: &types.Message{ + To: to1, + From: from1, + }, + expNonce: 0, + }}, + }, { + // Get nonce value of zero from mpool + name: "mpool nonce zero", + msgs: []msgSpec{{ + msg: &types.Message{ + To: to1, + From: from1, + }, + mpoolNonce: [1]uint64{0}, + expNonce: 0, + }}, + }, { + // Get non-zero nonce value from mpool + name: "mpool nonce set", + msgs: []msgSpec{{ + msg: &types.Message{ + To: to1, + From: from1, + }, + mpoolNonce: [1]uint64{5}, + expNonce: 5, + }, { + msg: &types.Message{ + To: to1, + From: from1, + }, + // Should ignore mpool nonce because after the first message nonce + // will come from the datastore + mpoolNonce: [1]uint64{10}, + expNonce: 6, + }}, + }, { + // Nonce should increment independently for each address + name: "nonce increments per address", + msgs: []msgSpec{{ + msg: &types.Message{ + To: to1, + From: from1, + }, + expNonce: 0, + }, { + msg: &types.Message{ + To: to1, + From: from1, + }, + expNonce: 1, + }, { + msg: &types.Message{ + To: to2, + From: from2, + }, + mpoolNonce: [1]uint64{5}, + expNonce: 5, + }, { + msg: &types.Message{ + To: to2, + From: from2, + }, + expNonce: 6, + }, { + msg: &types.Message{ + To: to1, + From: from1, + }, + expNonce: 2, + }}, + }} + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + mpool := newMockMpool() + ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) + ms := newMessageSigner(w, mpool, ds) + + for _, m := range tt.msgs { + if len(m.mpoolNonce) == 1 { + mpool.setNonce(m.msg.From, m.mpoolNonce[0]) + } + smsg, err := ms.SignMessage(ctx, m.msg) + require.NoError(t, err) + require.Equal(t, m.expNonce, smsg.Message.Nonce) + } + }) + } +} diff --git a/node/builder.go b/node/builder.go index c37a5db58..c49789a6a 100644 --- a/node/builder.go +++ b/node/builder.go @@ -6,6 +6,8 @@ import ( "os" "time" + "github.com/filecoin-project/lotus/chain/messagesigner" + logging "github.com/ipfs/go-log" ci "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/host" @@ -259,6 +261,7 @@ func Online() Option { Override(new(*store.ChainStore), modules.ChainStore), Override(new(*stmgr.StateManager), stmgr.NewStateManager), Override(new(*wallet.Wallet), wallet.NewWallet), + Override(new(*messagesigner.MessageSigner), messagesigner.NewMessageSigner), Override(new(dtypes.ChainGCLocker), blockstore.NewGCLocker), Override(new(dtypes.ChainGCBlockstore), modules.ChainGCBlockstore), diff --git a/node/impl/full/mpool.go b/node/impl/full/mpool.go index 6acb17990..003260496 100644 --- a/node/impl/full/mpool.go +++ b/node/impl/full/mpool.go @@ -4,14 +4,14 @@ import ( "context" "encoding/json" + "github.com/filecoin-project/lotus/chain/messagesigner" + "github.com/filecoin-project/go-address" "github.com/ipfs/go-cid" "go.uber.org/fx" "golang.org/x/xerrors" "github.com/filecoin-project/lotus/api" - "github.com/filecoin-project/lotus/chain/messagepool" - "github.com/filecoin-project/lotus/chain/store" "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/node/modules/dtypes" ) @@ -22,9 +22,7 @@ type MpoolAPI struct { WalletAPI GasAPI - Chain *store.ChainStore - - Mpool *messagepool.MessagePool + MessageSigner *messagesigner.MessageSigner PushLocks *dtypes.MpoolLocker } @@ -114,12 +112,14 @@ func (a *MpoolAPI) MpoolPush(ctx context.Context, smsg *types.SignedMessage) (ci } func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spec *api.MessageSendSpec) (*types.SignedMessage, error) { + cp := *msg + msg = &cp inMsg := *msg + fromA, err := a.Stmgr.ResolveToKeyAddress(ctx, msg.From, nil) + if err != nil { + return nil, xerrors.Errorf("getting key address: %w", err) + } { - fromA, err := a.Stmgr.ResolveToKeyAddress(ctx, msg.From, nil) - if err != nil { - return nil, xerrors.Errorf("getting key address: %w", err) - } done, err := a.PushLocks.TakeLock(ctx, fromA) if err != nil { return nil, xerrors.Errorf("taking lock: %w", err) @@ -131,7 +131,7 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe return nil, xerrors.Errorf("MpoolPushMessage expects message nonce to be 0, was %d", msg.Nonce) } - msg, err := a.GasAPI.GasEstimateMessageGas(ctx, msg, spec, types.EmptyTSK) + msg, err = a.GasAPI.GasEstimateMessageGas(ctx, msg, spec, types.EmptyTSK) if err != nil { return nil, xerrors.Errorf("GasEstimateMessageGas error: %w", err) } @@ -143,33 +143,31 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe inJson, outJson) } - sign := func(from address.Address, nonce uint64) (*types.SignedMessage, error) { - msg.Nonce = nonce - if msg.From.Protocol() == address.ID { - log.Warnf("Push from ID address (%s), adjusting to %s", msg.From, from) - msg.From = from - } - - b, err := a.WalletBalance(ctx, msg.From) - if err != nil { - return nil, xerrors.Errorf("mpool push: getting origin balance: %w", err) - } - - if b.LessThan(msg.Value) { - return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value) - } - - return a.WalletSignMessage(ctx, from, msg) + if msg.From.Protocol() == address.ID { + log.Warnf("Push from ID address (%s), adjusting to %s", msg.From, fromA) + msg.From = fromA } - var m *types.SignedMessage -again: - m, err = a.Mpool.PushWithNonce(ctx, msg.From, sign) - if err == messagepool.ErrTryAgain { - log.Debugf("temporary failure while pushing message: %s; retrying", err) - goto again + b, err := a.WalletBalance(ctx, msg.From) + if err != nil { + return nil, xerrors.Errorf("mpool push: getting origin balance: %w", err) } - return m, err + + if b.LessThan(msg.Value) { + return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value) + } + + smsg, err := a.MessageSigner.SignMessage(ctx, msg) + if err != nil { + return nil, xerrors.Errorf("mpool push: failed to sign message: %w", err) + } + + _, err = a.Mpool.Push(smsg) + if err != nil { + return nil, xerrors.Errorf("mpool push: failed to push message: %w", err) + } + + return smsg, err } func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {