fix(auth): audit issues with unordered txs (#23392)
Co-authored-by: Alex | Interchain Labs <alex@interchainlabs.io> Co-authored-by: Alexander Peters <alpe@users.noreply.github.com>
This commit is contained in:
parent
8eb6822d25
commit
ddf9e18ee4
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@ -55,6 +56,21 @@ type testTx struct {
|
||||
address sdk.AccAddress
|
||||
// useful for debugging
|
||||
strAddress string
|
||||
unordered bool
|
||||
timeout *time.Time
|
||||
}
|
||||
|
||||
// GetTimeoutTimeStamp implements types.TxWithUnordered.
|
||||
func (tx testTx) GetTimeoutTimeStamp() time.Time {
|
||||
if tx.timeout == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return *tx.timeout
|
||||
}
|
||||
|
||||
// GetUnordered implements types.TxWithUnordered.
|
||||
func (tx testTx) GetUnordered() bool {
|
||||
return tx.unordered
|
||||
}
|
||||
|
||||
func (tx testTx) GetSigners() ([][]byte, error) { panic("not implemented") }
|
||||
@ -73,6 +89,7 @@ func (tx testTx) GetSignaturesV2() (res []txsigning.SignatureV2, err error) {
|
||||
|
||||
var (
|
||||
_ sdk.Tx = (*testTx)(nil)
|
||||
_ sdk.TxWithUnordered = (*testTx)(nil)
|
||||
_ signing.SigVerifiableTx = (*testTx)(nil)
|
||||
_ cryptotypes.PubKey = (*testPubKey)(nil)
|
||||
)
|
||||
|
||||
@ -224,13 +224,13 @@ func (mp *PriorityNonceMempool[C]) Insert(ctx context.Context, tx sdk.Tx) error
|
||||
priority := mp.cfg.TxPriority.GetTxPriority(ctx, tx)
|
||||
nonce := sig.Sequence
|
||||
|
||||
// if it's an unordered tx, we use the gas instead of the nonce
|
||||
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
|
||||
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
|
||||
gasLimit, err := unordered.GetGasLimit()
|
||||
nonce = gasLimit
|
||||
if err != nil {
|
||||
return err
|
||||
timestamp := unordered.GetTimeoutTimeStamp().Unix()
|
||||
if timestamp < 0 {
|
||||
return errors.New("invalid timestamp value")
|
||||
}
|
||||
nonce = uint64(timestamp)
|
||||
}
|
||||
|
||||
key := txMeta[C]{nonce: nonce, priority: priority, sender: sender}
|
||||
@ -469,13 +469,13 @@ func (mp *PriorityNonceMempool[C]) Remove(tx sdk.Tx) error {
|
||||
sender := sig.Signer.String()
|
||||
nonce := sig.Sequence
|
||||
|
||||
// if it's an unordered tx, we use the gas instead of the nonce
|
||||
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
|
||||
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
|
||||
gasLimit, err := unordered.GetGasLimit()
|
||||
nonce = gasLimit
|
||||
if err != nil {
|
||||
return err
|
||||
timestamp := unordered.GetTimeoutTimeStamp().Unix()
|
||||
if timestamp < 0 {
|
||||
return errors.New("invalid timestamp value")
|
||||
}
|
||||
nonce = uint64(timestamp)
|
||||
}
|
||||
|
||||
scoreKey := txMeta[C]{nonce: nonce, sender: sender}
|
||||
|
||||
@ -970,3 +970,40 @@ func TestNextSenderTx_TxReplacement(t *testing.T) {
|
||||
iter := mp.Select(ctx, nil)
|
||||
require.Equal(t, txs[3], iter.Tx())
|
||||
}
|
||||
|
||||
func TestPriorityNonceMempool_UnorderedTx(t *testing.T) {
|
||||
ctx := sdk.NewContext(nil, false, log.NewNopLogger())
|
||||
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
|
||||
sa := accounts[0].Address
|
||||
sb := accounts[1].Address
|
||||
|
||||
mp := mempool.DefaultPriorityMempool()
|
||||
|
||||
now := time.Now()
|
||||
oneHour := now.Add(1 * time.Hour)
|
||||
thirtyMin := now.Add(30 * time.Minute)
|
||||
twoHours := now.Add(2 * time.Hour)
|
||||
fifteenMin := now.Add(15 * time.Minute)
|
||||
|
||||
txs := []testTx{
|
||||
{id: 1, priority: 0, address: sa, timeout: &thirtyMin, unordered: true},
|
||||
{id: 0, priority: 0, address: sa, timeout: &oneHour, unordered: true},
|
||||
{id: 3, priority: 0, address: sb, timeout: &fifteenMin, unordered: true},
|
||||
{id: 2, priority: 0, address: sb, timeout: &twoHours, unordered: true},
|
||||
}
|
||||
|
||||
for _, tx := range txs {
|
||||
c := ctx.WithPriority(tx.priority)
|
||||
require.NoError(t, mp.Insert(c, tx))
|
||||
}
|
||||
|
||||
require.Equal(t, 4, mp.CountTx())
|
||||
|
||||
orderedTxs := fetchTxs(mp.Select(ctx, nil), 100000)
|
||||
require.Equal(t, len(txs), len(orderedTxs))
|
||||
|
||||
// check order
|
||||
for i, tx := range orderedTxs {
|
||||
require.Equal(t, txs[i].id, tx.(testTx).id)
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,21 +139,21 @@ func (snm *SenderNonceMempool) Insert(_ context.Context, tx sdk.Tx) error {
|
||||
sender := sdk.AccAddress(sig.PubKey.Address()).String()
|
||||
nonce := sig.Sequence
|
||||
|
||||
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
|
||||
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
|
||||
timestamp := unordered.GetTimeoutTimeStamp().Unix()
|
||||
if timestamp < 0 {
|
||||
return errors.New("invalid timestamp value")
|
||||
}
|
||||
nonce = uint64(timestamp)
|
||||
}
|
||||
|
||||
senderTxs, found := snm.senders[sender]
|
||||
if !found {
|
||||
senderTxs = skiplist.New(skiplist.Uint64)
|
||||
snm.senders[sender] = senderTxs
|
||||
}
|
||||
|
||||
// if it's an unordered tx, we use the gas instead of the nonce
|
||||
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
|
||||
gasLimit, err := unordered.GetGasLimit()
|
||||
nonce = gasLimit
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
senderTxs.Set(nonce, tx)
|
||||
|
||||
key := txKey{nonce: nonce, address: sender}
|
||||
@ -236,13 +236,13 @@ func (snm *SenderNonceMempool) Remove(tx sdk.Tx) error {
|
||||
sender := sdk.AccAddress(sig.PubKey.Address()).String()
|
||||
nonce := sig.Sequence
|
||||
|
||||
// if it's an unordered tx, we use the gas instead of the nonce
|
||||
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
|
||||
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
|
||||
gasLimit, err := unordered.GetGasLimit()
|
||||
nonce = gasLimit
|
||||
if err != nil {
|
||||
return err
|
||||
timestamp := unordered.GetTimeoutTimeStamp().Unix()
|
||||
if timestamp < 0 {
|
||||
return errors.New("invalid timestamp value")
|
||||
}
|
||||
nonce = uint64(timestamp)
|
||||
}
|
||||
|
||||
senderTxs, found := snm.senders[sender]
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@ -192,3 +193,67 @@ func (s *MempoolTestSuite) TestTxNotFoundOnSender() {
|
||||
err = mp.Remove(tx)
|
||||
require.Equal(t, mempool.ErrTxNotFound, err)
|
||||
}
|
||||
|
||||
func (s *MempoolTestSuite) TestUnorderedTx() {
|
||||
t := s.T()
|
||||
|
||||
ctx := sdk.NewContext(nil, false, log.NewNopLogger())
|
||||
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
|
||||
sa := accounts[0].Address
|
||||
sb := accounts[1].Address
|
||||
|
||||
mp := mempool.NewSenderNonceMempool(mempool.SenderNonceMaxTxOpt(5000))
|
||||
|
||||
now := time.Now()
|
||||
oneHour := now.Add(1 * time.Hour)
|
||||
thirtyMin := now.Add(30 * time.Minute)
|
||||
twoHours := now.Add(2 * time.Hour)
|
||||
fifteenMin := now.Add(15 * time.Minute)
|
||||
|
||||
txs := []testTx{
|
||||
{id: 0, address: sa, timeout: &oneHour, unordered: true},
|
||||
{id: 1, address: sa, timeout: &thirtyMin, unordered: true},
|
||||
{id: 2, address: sb, timeout: &twoHours, unordered: true},
|
||||
{id: 3, address: sb, timeout: &fifteenMin, unordered: true},
|
||||
}
|
||||
|
||||
for _, tx := range txs {
|
||||
c := ctx.WithPriority(tx.priority)
|
||||
require.NoError(t, mp.Insert(c, tx))
|
||||
}
|
||||
|
||||
require.Equal(t, 4, mp.CountTx())
|
||||
|
||||
orderedTxs := fetchTxs(mp.Select(ctx, nil), 100000)
|
||||
require.Equal(t, len(txs), len(orderedTxs))
|
||||
|
||||
// Because the sender is selected randomly it can be any of these options
|
||||
acceptableOptions := [][]int{
|
||||
{3, 1, 2, 0},
|
||||
{3, 1, 0, 2},
|
||||
{3, 2, 1, 0},
|
||||
{1, 3, 0, 2},
|
||||
{1, 3, 2, 0},
|
||||
{1, 0, 3, 2},
|
||||
}
|
||||
|
||||
orderedTxsIds := make([]int, len(orderedTxs))
|
||||
for i, tx := range orderedTxs {
|
||||
orderedTxsIds[i] = tx.(testTx).id
|
||||
}
|
||||
|
||||
anyAcceptableOrder := false
|
||||
for _, option := range acceptableOptions {
|
||||
for i, tx := range orderedTxs {
|
||||
if tx.(testTx).id != txs[option[i]].id {
|
||||
break
|
||||
}
|
||||
|
||||
if i == len(orderedTxs)-1 {
|
||||
anyAcceptableOrder = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, anyAcceptableOrder, "expected any of %v but got %v", acceptableOptions, orderedTxsIds)
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
@ -1384,3 +1385,34 @@ func TestAnteHandlerReCheck(t *testing.T) {
|
||||
_, err = suite.anteHandler(suite.ctx, tx, false)
|
||||
require.NotNil(t, err, "antehandler on recheck did not fail once feePayer no longer has sufficient funds")
|
||||
}
|
||||
|
||||
func TestAnteHandlerUnorderedTx(t *testing.T) {
|
||||
suite := SetupTestSuite(t, false)
|
||||
accs := suite.CreateTestAccounts(1)
|
||||
msg := testdata.NewTestMsg(accs[0].acc.GetAddress())
|
||||
|
||||
// First send a normal sequential tx with sequence 0
|
||||
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), authtypes.FeeCollectorName, testdata.NewTestFeeAmount()).Return(nil).AnyTimes()
|
||||
|
||||
privs, accNums, accSeqs := []cryptotypes.PrivKey{accs[0].priv}, []uint64{1000}, []uint64{0}
|
||||
_, err := suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// we try to send another tx with the same sequence, it will fail
|
||||
_, err = suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
|
||||
require.Error(t, err)
|
||||
|
||||
// now we'll still use the same sequence but because it's unordered, it will be ignored and accepted anyway
|
||||
msgs := []sdk.Msg{msg}
|
||||
require.NoError(t, suite.txBuilder.SetMsgs(msgs...))
|
||||
suite.txBuilder.SetFeeAmount(testdata.NewTestFeeAmount())
|
||||
suite.txBuilder.SetGasLimit(testdata.NewTestGasLimit())
|
||||
|
||||
tx, txErr := suite.CreateTestUnorderedTx(suite.ctx, privs, accNums, accSeqs, suite.ctx.ChainID(), apisigning.SignMode_SIGN_MODE_DIRECT, true, time.Now().Add(time.Minute))
|
||||
require.NoError(t, txErr)
|
||||
txBytes, err := suite.clientCtx.TxConfig.TxEncoder()(tx)
|
||||
bytesCtx := suite.ctx.WithTxBytes(txBytes)
|
||||
require.NoError(t, err)
|
||||
_, err = suite.anteHandler(bytesCtx, tx, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -320,18 +320,24 @@ func (svd SigVerificationDecorator) consumeSignatureGas(
|
||||
// verifySig will verify the signature of the provided signer account.
|
||||
func (svd SigVerificationDecorator) verifySig(ctx context.Context, tx sdk.Tx, acc sdk.AccountI, sig signing.SignatureV2, newlyCreated bool) error {
|
||||
execMode := svd.ak.GetEnvironment().TransactionService.ExecMode(ctx)
|
||||
if execMode == transaction.ExecModeCheck {
|
||||
if sig.Sequence < acc.GetSequence() {
|
||||
unorderedTx, ok := tx.(sdk.TxWithUnordered)
|
||||
isUnordered := ok && unorderedTx.GetUnordered()
|
||||
|
||||
// only check sequence if the tx is not unordered
|
||||
if !isUnordered {
|
||||
if execMode == transaction.ExecModeCheck {
|
||||
if sig.Sequence < acc.GetSequence() {
|
||||
return errorsmod.Wrapf(
|
||||
sdkerrors.ErrWrongSequence,
|
||||
"account sequence mismatch: expected higher than or equal to %d, got %d", acc.GetSequence(), sig.Sequence,
|
||||
)
|
||||
}
|
||||
} else if sig.Sequence != acc.GetSequence() {
|
||||
return errorsmod.Wrapf(
|
||||
sdkerrors.ErrWrongSequence,
|
||||
"account sequence mismatch, expected higher than or equal to %d, got %d", acc.GetSequence(), sig.Sequence,
|
||||
"account sequence mismatch: expected %d, got %d", acc.GetSequence(), sig.Sequence,
|
||||
)
|
||||
}
|
||||
} else if sig.Sequence != acc.GetSequence() {
|
||||
return errorsmod.Wrapf(
|
||||
sdkerrors.ErrWrongSequence,
|
||||
"account sequence mismatch: expected %d, got %d", acc.GetSequence(), sig.Sequence,
|
||||
)
|
||||
}
|
||||
|
||||
// we're in simulation mode, or in ReCheckTx, or context is not
|
||||
|
||||
@ -3,6 +3,7 @@ package ante_test
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
@ -241,6 +242,67 @@ func (suite *AnteTestSuite) RunTestCase(t *testing.T, tc TestCase, args TestCase
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *AnteTestSuite) CreateTestUnorderedTx(
|
||||
ctx sdk.Context, privs []cryptotypes.PrivKey,
|
||||
accNums, accSeqs []uint64,
|
||||
chainID string, signMode apisigning.SignMode,
|
||||
unordered bool, unorderedTimeout time.Time,
|
||||
) (xauthsigning.Tx, error) {
|
||||
suite.txBuilder.SetUnordered(unordered)
|
||||
suite.txBuilder.SetTimeoutTimestamp(unorderedTimeout)
|
||||
|
||||
// First round: we gather all the signer infos. We use the "set empty
|
||||
// signature" hack to do that.
|
||||
var sigsV2 []signing.SignatureV2
|
||||
for i, priv := range privs {
|
||||
sigV2 := signing.SignatureV2{
|
||||
PubKey: priv.PubKey(),
|
||||
Data: &signing.SingleSignatureData{
|
||||
SignMode: signMode,
|
||||
Signature: nil,
|
||||
},
|
||||
Sequence: accSeqs[i],
|
||||
}
|
||||
|
||||
sigsV2 = append(sigsV2, sigV2)
|
||||
}
|
||||
err := suite.txBuilder.SetSignatures(sigsV2...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Second round: all signer infos are set, so each signer can sign.
|
||||
sigsV2 = []signing.SignatureV2{}
|
||||
for i, priv := range privs {
|
||||
anyPk, err := codectypes.NewAnyWithValue(priv.PubKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signerData := txsigning.SignerData{
|
||||
Address: sdk.AccAddress(priv.PubKey().Address()).String(),
|
||||
ChainID: chainID,
|
||||
AccountNumber: accNums[i],
|
||||
Sequence: accSeqs[i],
|
||||
PubKey: &anypb.Any{TypeUrl: anyPk.TypeUrl, Value: anyPk.Value},
|
||||
}
|
||||
sigV2, err := tx.SignWithPrivKey(
|
||||
ctx, signMode, signerData,
|
||||
suite.txBuilder, priv, suite.clientCtx.TxConfig, accSeqs[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sigsV2 = append(sigsV2, sigV2)
|
||||
}
|
||||
err = suite.txBuilder.SetSignatures(sigsV2...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return suite.txBuilder.GetTx(), nil
|
||||
}
|
||||
|
||||
// CreateTestTx is a helper function to create a tx given multiple inputs.
|
||||
func (suite *AnteTestSuite) CreateTestTx(
|
||||
ctx sdk.Context, privs []cryptotypes.PrivKey,
|
||||
|
||||
@ -17,7 +17,9 @@ import (
|
||||
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
|
||||
"github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/ante/unorderedtx"
|
||||
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
)
|
||||
|
||||
// bufPool is a pool of bytes.Buffer objects to reduce memory allocations.
|
||||
@ -146,8 +148,12 @@ func (d *UnorderedTxDecorator) ValidateTx(ctx context.Context, tx transaction.Tx
|
||||
|
||||
// TxIdentifier returns a unique identifier for a transaction that is intended to be unordered.
|
||||
func TxIdentifier(timeout uint64, tx sdk.Tx) ([32]byte, error) {
|
||||
feetx := tx.(sdk.FeeTx)
|
||||
if feetx.GetFee().IsZero() {
|
||||
sigTx, ok := tx.(authsigning.Tx)
|
||||
if !ok {
|
||||
return [32]byte{}, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
|
||||
}
|
||||
|
||||
if sigTx.GetFee().IsZero() {
|
||||
return [32]byte{}, errorsmod.Wrap(
|
||||
sdkerrors.ErrInvalidRequest,
|
||||
"unordered transaction must have a fee",
|
||||
@ -159,6 +165,18 @@ func TxIdentifier(timeout uint64, tx sdk.Tx) ([32]byte, error) {
|
||||
buf.Reset()
|
||||
defer bufPool.Put(buf)
|
||||
|
||||
// Add signatures to the transaction identifier
|
||||
signatures, err := sigTx.GetSignaturesV2()
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
|
||||
for _, sig := range signatures {
|
||||
if err := addSignatures(sig.Data, buf); err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Use the buffer
|
||||
for _, msg := range tx.GetMsgs() {
|
||||
// loop through the messages and write them to the buffer
|
||||
@ -189,7 +207,7 @@ func TxIdentifier(timeout uint64, tx sdk.Tx) ([32]byte, error) {
|
||||
}
|
||||
|
||||
// write gas to the buffer
|
||||
if err := binary.Write(buf, binary.LittleEndian, feetx.GetGas()); err != nil {
|
||||
if err := binary.Write(buf, binary.LittleEndian, sigTx.GetGas()); err != nil {
|
||||
return [32]byte{}, errorsmod.Wrap(
|
||||
sdkerrors.ErrInvalidRequest,
|
||||
"failed to write unordered to buffer",
|
||||
@ -201,3 +219,27 @@ func TxIdentifier(timeout uint64, tx sdk.Tx) ([32]byte, error) {
|
||||
// Return the Buffer to the pool
|
||||
return txHash, nil
|
||||
}
|
||||
|
||||
func addSignatures(sig signing.SignatureData, buf *bytes.Buffer) error {
|
||||
switch data := sig.(type) {
|
||||
case *signing.SingleSignatureData:
|
||||
if _, err := buf.Write(data.Signature); err != nil {
|
||||
return errorsmod.Wrap(
|
||||
sdkerrors.ErrInvalidRequest,
|
||||
"failed to write single signature to buffer",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
|
||||
case *signing.MultiSignatureData:
|
||||
for _, sigdata := range data.Signatures {
|
||||
if err := addSignatures(sigdata, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unexpected SignatureData %T", data)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user