From 328d28dc3f943feda8e6881525643f55b6dac054 Mon Sep 17 00:00:00 2001 From: David Terpay <35130517+davidterpay@users.noreply.github.com> Date: Wed, 26 Apr 2023 13:34:21 -0400 Subject: [PATCH] fix: Removing normal transactions using RemoveWithoutRefTx (#89) Co-authored-by: Aleksandr Bezobchuk --- abci/abci.go | 3 +- mempool/mempool.go | 41 ++------------------------- mempool/mempool_test.go | 61 ++++++++++++++++++++++++++++++++++------- 3 files changed, 54 insertions(+), 51 deletions(-) diff --git a/abci/abci.go b/abci/abci.go index de218e4..fea89ad 100644 --- a/abci/abci.go +++ b/abci/abci.go @@ -21,7 +21,6 @@ type ( GetBundledTransactions(tx sdk.Tx) ([][]byte, error) WrapBundleTransaction(tx []byte) (sdk.Tx, error) IsAuctionTx(tx sdk.Tx) (bool, error) - RemoveWithoutRefTx(tx sdk.Tx) error } ProposalHandler struct { @@ -277,7 +276,7 @@ func (h *ProposalHandler) verifyTx(ctx sdk.Context, tx sdk.Tx) error { } func (h *ProposalHandler) RemoveTx(tx sdk.Tx) { - if err := h.mempool.RemoveWithoutRefTx(tx); err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { + if err := h.mempool.Remove(tx); err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { panic(fmt.Errorf("failed to remove invalid transaction from the mempool: %w", err)) } } diff --git a/mempool/mempool.go b/mempool/mempool.go index 6ff89fe..14415e5 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -106,9 +106,7 @@ func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx i } } -// Insert inserts a transaction into the mempool. If the transaction is a special -// auction tx (tx that contains a single MsgAuctionBid), it will also insert the -// transaction into the auction index. +// Insert inserts a transaction into the mempool based on the transaction type (normal or auction). func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { isAuctionTx, err := am.IsAuctionTx(tx) if err != nil { @@ -137,9 +135,7 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { return nil } -// Remove removes a transaction from the mempool. If the transaction is a special -// auction tx (tx that contains a single MsgAuctionBid), it will also remove all -// referenced transactions from the global mempool. +// Remove removes a transaction from the mempool based on the transaction type (normal or auction). func (am *AuctionMempool) Remove(tx sdk.Tx) error { isAuctionTx, err := am.IsAuctionTx(tx) if err != nil { @@ -152,39 +148,6 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error { am.removeTx(am.globalIndex, tx) case isAuctionTx: am.removeTx(am.auctionIndex, tx) - - // Remove all referenced transactions from the global mempool. - bundleTxs, err := am.GetBundledTransactions(tx) - if err != nil { - return err - } - - for _, refTx := range bundleTxs { - wrappedRefTx, err := am.WrapBundleTransaction(refTx) - if err != nil { - return err - } - - am.removeTx(am.globalIndex, wrappedRefTx) - } - } - - return nil -} - -// RemoveWithoutRefTx removes a transaction from the mempool without removing -// any referenced transactions. Referenced transactions only exist in special -// auction transactions (txs that only include a single MsgAuctionBid). This -// API is used to ensure that searchers are unable to remove valid transactions -// from the global mempool. -func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error { - isAuctionTx, err := am.IsAuctionTx(tx) - if err != nil { - return err - } - - if isAuctionTx { - am.removeTx(am.auctionIndex, tx) } return nil diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index cd9c10c..33db26e 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -1,6 +1,7 @@ package mempool_test import ( + "context" "math/rand" "testing" "time" @@ -134,9 +135,9 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() { suite.Require().NotNil(auctionIterator) tx := auctionIterator.Tx() suite.Require().Len(tx.GetMsgs(), 1) - suite.Require().NoError(suite.mempool.RemoveWithoutRefTx(tx)) + suite.Require().NoError(suite.mempool.Remove(tx)) - // Ensure that the auction tx was removed from the auction and global mempool + // Ensure that the auction tx was removed from the auction mempool only suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) suite.Require().Equal(numMempoolTxs, suite.mempool.CountTx()) contains, err := suite.mempool.Contains(tx) @@ -144,22 +145,54 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() { suite.Require().False(contains) // Attempt to remove again and ensure that the tx is not found - suite.Require().NoError(suite.mempool.RemoveWithoutRefTx(tx)) + suite.Require().NoError(suite.mempool.Remove(tx)) suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) suite.Require().Equal(numMempoolTxs, suite.mempool.CountTx()) - // Attempt to remove with the bundled txs - suite.Require().NoError(suite.mempool.Remove(tx)) - suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) - suite.Require().Equal(numMempoolTxs-numberBundledTxs, suite.mempool.CountTx()) - + // Bundled txs should be in the global mempool auctionMsg, err := mempool.GetMsgAuctionBidFromTx(tx) suite.Require().NoError(err) for _, refTx := range auctionMsg.GetTransactions() { tx, err := suite.encCfg.TxConfig.TxDecoder()(refTx) suite.Require().NoError(err) - suite.Require().False(suite.mempool.Contains(tx)) + contains, err = suite.mempool.Contains(tx) + suite.Require().NoError(err) + suite.Require().True(contains) } + + // Attempt to remove a global tx + iterator := suite.mempool.Select(context.Background(), nil) + tx = iterator.Tx() + size := suite.mempool.CountTx() + suite.mempool.Remove(tx) + suite.Require().Equal(size-1, suite.mempool.CountTx()) + + // Remove the rest of the global transactions + iterator = suite.mempool.Select(context.Background(), nil) + suite.Require().NotNil(iterator) + for iterator != nil { + tx = iterator.Tx() + suite.Require().NoError(suite.mempool.Remove(tx)) + iterator = suite.mempool.Select(context.Background(), nil) + } + suite.Require().Equal(0, suite.mempool.CountTx()) + + // Remove the rest of the auction transactions + auctionIterator = suite.mempool.AuctionBidSelect(suite.ctx) + for auctionIterator != nil { + tx = auctionIterator.Tx() + suite.Require().NoError(suite.mempool.Remove(tx)) + auctionIterator = suite.mempool.AuctionBidSelect(suite.ctx) + } + suite.Require().Equal(0, suite.mempool.CountAuctionTx()) + + // Ensure that the mempool is empty + iterator = suite.mempool.Select(context.Background(), nil) + suite.Require().Nil(iterator) + auctionIterator = suite.mempool.AuctionBidSelect(suite.ctx) + suite.Require().Nil(auctionIterator) + suite.Require().Equal(0, suite.mempool.CountTx()) + suite.Require().Equal(0, suite.mempool.CountAuctionTx()) } func (suite *IntegrationTestSuite) TestAuctionMempoolSelect() { @@ -167,7 +200,7 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolSelect() { numberAuctionTxs := 10 numberBundledTxs := 5 insertRefTxs := true - suite.CreateFilledMempool(numberTotalTxs, numberAuctionTxs, numberBundledTxs, insertRefTxs) + totalTxs := suite.CreateFilledMempool(numberTotalTxs, numberAuctionTxs, numberBundledTxs, insertRefTxs) // iterate through the entire auction mempool and ensure the bids are in order var highestBid sdk.Coin @@ -195,4 +228,12 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolSelect() { } suite.Require().Equal(numberAuctionTxs, numberTxsSeen) + + iterator := suite.mempool.Select(context.Background(), nil) + numberTxsSeen = 0 + for iterator != nil { + iterator = iterator.Next() + numberTxsSeen++ + } + suite.Require().Equal(totalTxs, numberTxsSeen) }