Merge pull request #4165 from filecoin-project/fix/message-signer-nonce-generation

fix: make message signer nonce generation transactional
This commit is contained in:
Łukasz Magiera 2020-10-05 18:56:35 +02:00 committed by GitHub
commit fee84c4b44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 102 additions and 32 deletions

View File

@ -3,6 +3,7 @@ package messagesigner
import ( import (
"bytes" "bytes"
"context" "context"
"sync"
"github.com/filecoin-project/go-address" "github.com/filecoin-project/go-address"
"github.com/filecoin-project/lotus/chain/messagepool" "github.com/filecoin-project/lotus/chain/messagepool"
@ -16,7 +17,7 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
) )
const dsKeyActorNonce = "ActorNonce" const dsKeyActorNonce = "ActorNextNonce"
var log = logging.Logger("messagesigner") var log = logging.Logger("messagesigner")
@ -28,6 +29,7 @@ type mpoolAPI interface {
// when signing a message // when signing a message
type MessageSigner struct { type MessageSigner struct {
wallet *wallet.Wallet wallet *wallet.Wallet
lk sync.Mutex
mpool mpoolAPI mpool mpoolAPI
ds datastore.Batching ds datastore.Batching
} }
@ -47,25 +49,42 @@ func newMessageSigner(wallet *wallet.Wallet, mpool mpoolAPI, ds dtypes.MetadataD
// SignMessage increments the nonce for the message From address, and signs // SignMessage increments the nonce for the message From address, and signs
// the message // the message
func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) { func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message, cb func(*types.SignedMessage) error) (*types.SignedMessage, error) {
ms.lk.Lock()
defer ms.lk.Unlock()
// Get the next message nonce
nonce, err := ms.nextNonce(msg.From) nonce, err := ms.nextNonce(msg.From)
if err != nil { if err != nil {
return nil, xerrors.Errorf("failed to create nonce: %w", err) return nil, xerrors.Errorf("failed to create nonce: %w", err)
} }
// Sign the message with the nonce
msg.Nonce = nonce msg.Nonce = nonce
sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes()) sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes())
if err != nil { if err != nil {
return nil, xerrors.Errorf("failed to sign message: %w", err) return nil, xerrors.Errorf("failed to sign message: %w", err)
} }
return &types.SignedMessage{ // Callback with the signed message
smsg := &types.SignedMessage{
Message: *msg, Message: *msg,
Signature: *sig, Signature: *sig,
}, nil }
err = cb(smsg)
if err != nil {
return nil, err
}
// If the callback executed successfully, write the nonce to the datastore
if err := ms.saveNonce(msg.From, nonce); err != nil {
return nil, xerrors.Errorf("failed to save nonce: %w", err)
}
return smsg, nil
} }
// nextNonce increments the nonce. // nextNonce gets the next nonce for the given address.
// If there is no nonce in the datastore, gets the nonce from the message pool. // If there is no nonce in the datastore, gets the nonce from the message pool.
func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
// Nonces used to be created by the mempool and we need to support nodes // Nonces used to be created by the mempool and we need to support nodes
@ -77,21 +96,22 @@ func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err) return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err)
} }
// Get the nonce for this address from the datastore // Get the next nonce for this address from the datastore
addrNonceKey := datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()}) addrNonceKey := ms.dstoreKey(addr)
dsNonceBytes, err := ms.ds.Get(addrNonceKey) dsNonceBytes, err := ms.ds.Get(addrNonceKey)
switch { switch {
case xerrors.Is(err, datastore.ErrNotFound): case xerrors.Is(err, datastore.ErrNotFound):
// If a nonce for this address hasn't yet been created in the // If a nonce for this address hasn't yet been created in the
// datastore, just use the nonce from the mempool // datastore, just use the nonce from the mempool
return nonce, nil
case err != nil: case err != nil:
return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err) return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err)
default: default:
// There is a nonce in the datastore, so unmarshall and increment it // There is a nonce in the datastore, so unmarshall it
maj, val, err := cbg.CborReadHeader(bytes.NewReader(dsNonceBytes)) maj, dsNonce, err := cbg.CborReadHeader(bytes.NewReader(dsNonceBytes))
if err != nil { if err != nil {
return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err) return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err)
} }
@ -99,26 +119,37 @@ func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) {
return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore") return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore")
} }
dsNonce := val + 1
// The message pool nonce should be <= than the datastore nonce // The message pool nonce should be <= than the datastore nonce
if nonce <= dsNonce { if nonce <= dsNonce {
nonce = dsNonce nonce = dsNonce
} else { } else {
log.Warnf("mempool nonce was larger than datastore nonce (%d > %d)", nonce, dsNonce) log.Warnf("mempool nonce was larger than datastore nonce (%d > %d)", nonce, dsNonce)
} }
}
// Write the nonce for this address to the datastore return nonce, nil
}
}
// saveNonce increments the nonce for this address and writes it to the
// datastore
func (ms *MessageSigner) saveNonce(addr address.Address, nonce uint64) error {
// Increment the nonce
nonce++
// Write the nonce to the datastore
addrNonceKey := ms.dstoreKey(addr)
buf := bytes.Buffer{} buf := bytes.Buffer{}
_, err = buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce)) _, err := buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce))
if err != nil { if err != nil {
return 0, xerrors.Errorf("failed to marshall nonce: %w", err) return xerrors.Errorf("failed to marshall nonce: %w", err)
} }
err = ms.ds.Put(addrNonceKey, buf.Bytes()) err = ms.ds.Put(addrNonceKey, buf.Bytes())
if err != nil { if err != nil {
return 0, xerrors.Errorf("failed to write nonce to datastore: %w", err) return xerrors.Errorf("failed to write nonce to datastore: %w", err)
} }
return nil
return nonce, nil }
func (ms *MessageSigner) dstoreKey(addr address.Address) datastore.Key {
return datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()})
} }

View File

@ -5,6 +5,8 @@ import (
"sync" "sync"
"testing" "testing"
"golang.org/x/xerrors"
"github.com/filecoin-project/lotus/chain/wallet" "github.com/filecoin-project/lotus/chain/wallet"
"github.com/filecoin-project/go-state-types/crypto" "github.com/filecoin-project/go-state-types/crypto"
@ -58,6 +60,7 @@ func TestMessageSignerSignMessage(t *testing.T) {
msg *types.Message msg *types.Message
mpoolNonce [1]uint64 mpoolNonce [1]uint64
expNonce uint64 expNonce uint64
cbErr error
} }
tests := []struct { tests := []struct {
name string name string
@ -137,6 +140,37 @@ func TestMessageSignerSignMessage(t *testing.T) {
}, },
expNonce: 2, expNonce: 2,
}}, }},
}, {
name: "recover from callback error",
msgs: []msgSpec{{
// No nonce yet in datastore
msg: &types.Message{
To: to1,
From: from1,
},
expNonce: 0,
}, {
// Increment nonce
msg: &types.Message{
To: to1,
From: from1,
},
expNonce: 1,
}, {
// Callback returns error
msg: &types.Message{
To: to1,
From: from1,
},
cbErr: xerrors.Errorf("err"),
}, {
// Callback successful, should increment nonce in datastore
msg: &types.Message{
To: to1,
From: from1,
},
expNonce: 2,
}},
}} }}
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
@ -149,9 +183,18 @@ func TestMessageSignerSignMessage(t *testing.T) {
if len(m.mpoolNonce) == 1 { if len(m.mpoolNonce) == 1 {
mpool.setNonce(m.msg.From, m.mpoolNonce[0]) mpool.setNonce(m.msg.From, m.mpoolNonce[0])
} }
smsg, err := ms.SignMessage(ctx, m.msg) merr := m.cbErr
require.NoError(t, err) smsg, err := ms.SignMessage(ctx, m.msg, func(message *types.SignedMessage) error {
require.Equal(t, m.expNonce, smsg.Message.Nonce) return merr
})
if m.cbErr != nil {
require.Error(t, err)
require.Nil(t, smsg)
} else {
require.NoError(t, err)
require.Equal(t, m.expNonce, smsg.Message.Nonce)
}
} }
}) })
} }

View File

@ -160,17 +160,13 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe
return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value) return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value)
} }
smsg, err := a.MessageSigner.SignMessage(ctx, msg) // Sign and push the message
if err != nil { return a.MessageSigner.SignMessage(ctx, msg, func(smsg *types.SignedMessage) error {
return nil, xerrors.Errorf("mpool push: failed to sign message: %w", err) if _, err := a.Mpool.Push(smsg); err != nil {
} return xerrors.Errorf("mpool push: failed to push message: %w", err)
}
_, err = a.Mpool.Push(smsg) return nil
if err != nil { })
return nil, xerrors.Errorf("mpool push: failed to push message: %w", err)
}
return smsg, err
} }
func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) { func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {