fix: Removing normal transactions using RemoveWithoutRefTx (#89)

Co-authored-by: Aleksandr Bezobchuk <alexanderbez@users.noreply.github.com>
This commit is contained in:
David Terpay 2023-04-26 13:34:21 -04:00 committed by GitHub
parent d3a36f5329
commit 328d28dc3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 51 deletions

View File

@ -21,7 +21,6 @@ type (
GetBundledTransactions(tx sdk.Tx) ([][]byte, error)
WrapBundleTransaction(tx []byte) (sdk.Tx, error)
IsAuctionTx(tx sdk.Tx) (bool, error)
RemoveWithoutRefTx(tx sdk.Tx) error
}
ProposalHandler struct {
@ -277,7 +276,7 @@ func (h *ProposalHandler) verifyTx(ctx sdk.Context, tx sdk.Tx) error {
}
func (h *ProposalHandler) RemoveTx(tx sdk.Tx) {
if err := h.mempool.RemoveWithoutRefTx(tx); err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) {
if err := h.mempool.Remove(tx); err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) {
panic(fmt.Errorf("failed to remove invalid transaction from the mempool: %w", err))
}
}

View File

@ -106,9 +106,7 @@ func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx i
}
}
// Insert inserts a transaction into the mempool. If the transaction is a special
// auction tx (tx that contains a single MsgAuctionBid), it will also insert the
// transaction into the auction index.
// Insert inserts a transaction into the mempool based on the transaction type (normal or auction).
func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error {
isAuctionTx, err := am.IsAuctionTx(tx)
if err != nil {
@ -137,9 +135,7 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error {
return nil
}
// Remove removes a transaction from the mempool. If the transaction is a special
// auction tx (tx that contains a single MsgAuctionBid), it will also remove all
// referenced transactions from the global mempool.
// Remove removes a transaction from the mempool based on the transaction type (normal or auction).
func (am *AuctionMempool) Remove(tx sdk.Tx) error {
isAuctionTx, err := am.IsAuctionTx(tx)
if err != nil {
@ -152,39 +148,6 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error {
am.removeTx(am.globalIndex, tx)
case isAuctionTx:
am.removeTx(am.auctionIndex, tx)
// Remove all referenced transactions from the global mempool.
bundleTxs, err := am.GetBundledTransactions(tx)
if err != nil {
return err
}
for _, refTx := range bundleTxs {
wrappedRefTx, err := am.WrapBundleTransaction(refTx)
if err != nil {
return err
}
am.removeTx(am.globalIndex, wrappedRefTx)
}
}
return nil
}
// RemoveWithoutRefTx removes a transaction from the mempool without removing
// any referenced transactions. Referenced transactions only exist in special
// auction transactions (txs that only include a single MsgAuctionBid). This
// API is used to ensure that searchers are unable to remove valid transactions
// from the global mempool.
func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error {
isAuctionTx, err := am.IsAuctionTx(tx)
if err != nil {
return err
}
if isAuctionTx {
am.removeTx(am.auctionIndex, tx)
}
return nil

View File

@ -1,6 +1,7 @@
package mempool_test
import (
"context"
"math/rand"
"testing"
"time"
@ -134,9 +135,9 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() {
suite.Require().NotNil(auctionIterator)
tx := auctionIterator.Tx()
suite.Require().Len(tx.GetMsgs(), 1)
suite.Require().NoError(suite.mempool.RemoveWithoutRefTx(tx))
suite.Require().NoError(suite.mempool.Remove(tx))
// Ensure that the auction tx was removed from the auction and global mempool
// Ensure that the auction tx was removed from the auction mempool only
suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx())
suite.Require().Equal(numMempoolTxs, suite.mempool.CountTx())
contains, err := suite.mempool.Contains(tx)
@ -144,22 +145,54 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() {
suite.Require().False(contains)
// Attempt to remove again and ensure that the tx is not found
suite.Require().NoError(suite.mempool.RemoveWithoutRefTx(tx))
suite.Require().NoError(suite.mempool.Remove(tx))
suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx())
suite.Require().Equal(numMempoolTxs, suite.mempool.CountTx())
// Attempt to remove with the bundled txs
suite.Require().NoError(suite.mempool.Remove(tx))
suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx())
suite.Require().Equal(numMempoolTxs-numberBundledTxs, suite.mempool.CountTx())
// Bundled txs should be in the global mempool
auctionMsg, err := mempool.GetMsgAuctionBidFromTx(tx)
suite.Require().NoError(err)
for _, refTx := range auctionMsg.GetTransactions() {
tx, err := suite.encCfg.TxConfig.TxDecoder()(refTx)
suite.Require().NoError(err)
suite.Require().False(suite.mempool.Contains(tx))
contains, err = suite.mempool.Contains(tx)
suite.Require().NoError(err)
suite.Require().True(contains)
}
// Attempt to remove a global tx
iterator := suite.mempool.Select(context.Background(), nil)
tx = iterator.Tx()
size := suite.mempool.CountTx()
suite.mempool.Remove(tx)
suite.Require().Equal(size-1, suite.mempool.CountTx())
// Remove the rest of the global transactions
iterator = suite.mempool.Select(context.Background(), nil)
suite.Require().NotNil(iterator)
for iterator != nil {
tx = iterator.Tx()
suite.Require().NoError(suite.mempool.Remove(tx))
iterator = suite.mempool.Select(context.Background(), nil)
}
suite.Require().Equal(0, suite.mempool.CountTx())
// Remove the rest of the auction transactions
auctionIterator = suite.mempool.AuctionBidSelect(suite.ctx)
for auctionIterator != nil {
tx = auctionIterator.Tx()
suite.Require().NoError(suite.mempool.Remove(tx))
auctionIterator = suite.mempool.AuctionBidSelect(suite.ctx)
}
suite.Require().Equal(0, suite.mempool.CountAuctionTx())
// Ensure that the mempool is empty
iterator = suite.mempool.Select(context.Background(), nil)
suite.Require().Nil(iterator)
auctionIterator = suite.mempool.AuctionBidSelect(suite.ctx)
suite.Require().Nil(auctionIterator)
suite.Require().Equal(0, suite.mempool.CountTx())
suite.Require().Equal(0, suite.mempool.CountAuctionTx())
}
func (suite *IntegrationTestSuite) TestAuctionMempoolSelect() {
@ -167,7 +200,7 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolSelect() {
numberAuctionTxs := 10
numberBundledTxs := 5
insertRefTxs := true
suite.CreateFilledMempool(numberTotalTxs, numberAuctionTxs, numberBundledTxs, insertRefTxs)
totalTxs := suite.CreateFilledMempool(numberTotalTxs, numberAuctionTxs, numberBundledTxs, insertRefTxs)
// iterate through the entire auction mempool and ensure the bids are in order
var highestBid sdk.Coin
@ -195,4 +228,12 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolSelect() {
}
suite.Require().Equal(numberAuctionTxs, numberTxsSeen)
iterator := suite.mempool.Select(context.Background(), nil)
numberTxsSeen = 0
for iterator != nil {
iterator = iterator.Next()
numberTxsSeen++
}
suite.Require().Equal(totalTxs, numberTxsSeen)
}