Merge pull request #4165 from filecoin-project/fix/message-signer-nonce-generation
fix: make message signer nonce generation transactional
This commit is contained in:
commit
fee84c4b44
@ -3,6 +3,7 @@ package messagesigner
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/filecoin-project/go-address"
|
"github.com/filecoin-project/go-address"
|
||||||
"github.com/filecoin-project/lotus/chain/messagepool"
|
"github.com/filecoin-project/lotus/chain/messagepool"
|
||||||
@ -16,7 +17,7 @@ import (
|
|||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
const dsKeyActorNonce = "ActorNonce"
|
const dsKeyActorNonce = "ActorNextNonce"
|
||||||
|
|
||||||
var log = logging.Logger("messagesigner")
|
var log = logging.Logger("messagesigner")
|
||||||
|
|
||||||
@ -28,6 +29,7 @@ type mpoolAPI interface {
|
|||||||
// when signing a message
|
// when signing a message
|
||||||
type MessageSigner struct {
|
type MessageSigner struct {
|
||||||
wallet *wallet.Wallet
|
wallet *wallet.Wallet
|
||||||
|
lk sync.Mutex
|
||||||
mpool mpoolAPI
|
mpool mpoolAPI
|
||||||
ds datastore.Batching
|
ds datastore.Batching
|
||||||
}
|
}
|
||||||
@ -47,25 +49,42 @@ func newMessageSigner(wallet *wallet.Wallet, mpool mpoolAPI, ds dtypes.MetadataD
|
|||||||
|
|
||||||
// SignMessage increments the nonce for the message From address, and signs
|
// SignMessage increments the nonce for the message From address, and signs
|
||||||
// the message
|
// the message
|
||||||
func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) {
|
func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message, cb func(*types.SignedMessage) error) (*types.SignedMessage, error) {
|
||||||
|
ms.lk.Lock()
|
||||||
|
defer ms.lk.Unlock()
|
||||||
|
|
||||||
|
// Get the next message nonce
|
||||||
nonce, err := ms.nextNonce(msg.From)
|
nonce, err := ms.nextNonce(msg.From)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("failed to create nonce: %w", err)
|
return nil, xerrors.Errorf("failed to create nonce: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sign the message with the nonce
|
||||||
msg.Nonce = nonce
|
msg.Nonce = nonce
|
||||||
sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes())
|
sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("failed to sign message: %w", err)
|
return nil, xerrors.Errorf("failed to sign message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &types.SignedMessage{
|
// Callback with the signed message
|
||||||
|
smsg := &types.SignedMessage{
|
||||||
Message: *msg,
|
Message: *msg,
|
||||||
Signature: *sig,
|
Signature: *sig,
|
||||||
}, nil
|
}
|
||||||
|
err = cb(smsg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// nextNonce increments the nonce.
|
// If the callback executed successfully, write the nonce to the datastore
|
||||||
|
if err := ms.saveNonce(msg.From, nonce); err != nil {
|
||||||
|
return nil, xerrors.Errorf("failed to save nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return smsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextNonce gets the next nonce for the given address.
|
||||||
// If there is no nonce in the datastore, gets the nonce from the message pool.
|
// If there is no nonce in the datastore, gets the nonce from the message pool.
|
||||||
func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
|
func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
|
||||||
// Nonces used to be created by the mempool and we need to support nodes
|
// Nonces used to be created by the mempool and we need to support nodes
|
||||||
@ -77,21 +96,22 @@ func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
|
|||||||
return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err)
|
return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the nonce for this address from the datastore
|
// Get the next nonce for this address from the datastore
|
||||||
addrNonceKey := datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()})
|
addrNonceKey := ms.dstoreKey(addr)
|
||||||
dsNonceBytes, err := ms.ds.Get(addrNonceKey)
|
dsNonceBytes, err := ms.ds.Get(addrNonceKey)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case xerrors.Is(err, datastore.ErrNotFound):
|
case xerrors.Is(err, datastore.ErrNotFound):
|
||||||
// If a nonce for this address hasn't yet been created in the
|
// If a nonce for this address hasn't yet been created in the
|
||||||
// datastore, just use the nonce from the mempool
|
// datastore, just use the nonce from the mempool
|
||||||
|
return nonce, nil
|
||||||
|
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err)
|
return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// There is a nonce in the datastore, so unmarshall and increment it
|
// There is a nonce in the datastore, so unmarshall it
|
||||||
maj, val, err := cbg.CborReadHeader(bytes.NewReader(dsNonceBytes))
|
maj, dsNonce, err := cbg.CborReadHeader(bytes.NewReader(dsNonceBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err)
|
return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err)
|
||||||
}
|
}
|
||||||
@ -99,26 +119,37 @@ func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
|
|||||||
return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore")
|
return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore")
|
||||||
}
|
}
|
||||||
|
|
||||||
dsNonce := val + 1
|
|
||||||
|
|
||||||
// The message pool nonce should be <= than the datastore nonce
|
// The message pool nonce should be <= than the datastore nonce
|
||||||
if nonce <= dsNonce {
|
if nonce <= dsNonce {
|
||||||
nonce = dsNonce
|
nonce = dsNonce
|
||||||
} else {
|
} else {
|
||||||
log.Warnf("mempool nonce was larger than datastore nonce (%d > %d)", nonce, dsNonce)
|
log.Warnf("mempool nonce was larger than datastore nonce (%d > %d)", nonce, dsNonce)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Write the nonce for this address to the datastore
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
_, err = buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce))
|
|
||||||
if err != nil {
|
|
||||||
return 0, xerrors.Errorf("failed to marshall nonce: %w", err)
|
|
||||||
}
|
|
||||||
err = ms.ds.Put(addrNonceKey, buf.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return 0, xerrors.Errorf("failed to write nonce to datastore: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nonce, nil
|
return nonce, nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveNonce increments the nonce for this address and writes it to the
|
||||||
|
// datastore
|
||||||
|
func (ms *MessageSigner) saveNonce(addr address.Address, nonce uint64) error {
|
||||||
|
// Increment the nonce
|
||||||
|
nonce++
|
||||||
|
|
||||||
|
// Write the nonce to the datastore
|
||||||
|
addrNonceKey := ms.dstoreKey(addr)
|
||||||
|
buf := bytes.Buffer{}
|
||||||
|
_, err := buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce))
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("failed to marshall nonce: %w", err)
|
||||||
|
}
|
||||||
|
err = ms.ds.Put(addrNonceKey, buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("failed to write nonce to datastore: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ms *MessageSigner) dstoreKey(addr address.Address) datastore.Key {
|
||||||
|
return datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()})
|
||||||
|
}
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/filecoin-project/lotus/chain/wallet"
|
"github.com/filecoin-project/lotus/chain/wallet"
|
||||||
|
|
||||||
"github.com/filecoin-project/go-state-types/crypto"
|
"github.com/filecoin-project/go-state-types/crypto"
|
||||||
@ -58,6 +60,7 @@ func TestMessageSignerSignMessage(t *testing.T) {
|
|||||||
msg *types.Message
|
msg *types.Message
|
||||||
mpoolNonce [1]uint64
|
mpoolNonce [1]uint64
|
||||||
expNonce uint64
|
expNonce uint64
|
||||||
|
cbErr error
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -137,6 +140,37 @@ func TestMessageSignerSignMessage(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expNonce: 2,
|
expNonce: 2,
|
||||||
}},
|
}},
|
||||||
|
}, {
|
||||||
|
name: "recover from callback error",
|
||||||
|
msgs: []msgSpec{{
|
||||||
|
// No nonce yet in datastore
|
||||||
|
msg: &types.Message{
|
||||||
|
To: to1,
|
||||||
|
From: from1,
|
||||||
|
},
|
||||||
|
expNonce: 0,
|
||||||
|
}, {
|
||||||
|
// Increment nonce
|
||||||
|
msg: &types.Message{
|
||||||
|
To: to1,
|
||||||
|
From: from1,
|
||||||
|
},
|
||||||
|
expNonce: 1,
|
||||||
|
}, {
|
||||||
|
// Callback returns error
|
||||||
|
msg: &types.Message{
|
||||||
|
To: to1,
|
||||||
|
From: from1,
|
||||||
|
},
|
||||||
|
cbErr: xerrors.Errorf("err"),
|
||||||
|
}, {
|
||||||
|
// Callback successful, should increment nonce in datastore
|
||||||
|
msg: &types.Message{
|
||||||
|
To: to1,
|
||||||
|
From: from1,
|
||||||
|
},
|
||||||
|
expNonce: 2,
|
||||||
|
}},
|
||||||
}}
|
}}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
@ -149,10 +183,19 @@ func TestMessageSignerSignMessage(t *testing.T) {
|
|||||||
if len(m.mpoolNonce) == 1 {
|
if len(m.mpoolNonce) == 1 {
|
||||||
mpool.setNonce(m.msg.From, m.mpoolNonce[0])
|
mpool.setNonce(m.msg.From, m.mpoolNonce[0])
|
||||||
}
|
}
|
||||||
smsg, err := ms.SignMessage(ctx, m.msg)
|
merr := m.cbErr
|
||||||
|
smsg, err := ms.SignMessage(ctx, m.msg, func(message *types.SignedMessage) error {
|
||||||
|
return merr
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.cbErr != nil {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, smsg)
|
||||||
|
} else {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, m.expNonce, smsg.Message.Nonce)
|
require.Equal(t, m.expNonce, smsg.Message.Nonce)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,17 +160,13 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
|
|||||||
return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value)
|
return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
smsg, err := a.MessageSigner.SignMessage(ctx, msg)
|
// Sign and push the message
|
||||||
if err != nil {
|
return a.MessageSigner.SignMessage(ctx, msg, func(smsg *types.SignedMessage) error {
|
||||||
return nil, xerrors.Errorf("mpool push: failed to sign message: %w", err)
|
if _, err := a.Mpool.Push(smsg); err != nil {
|
||||||
|
return xerrors.Errorf("mpool push: failed to push message: %w", err)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
_, err = a.Mpool.Push(smsg)
|
})
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("mpool push: failed to push message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return smsg, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {
|
func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user