From acaf60e5b8e6e64b062f43ffcbfef18c06e243ee Mon Sep 17 00:00:00 2001 From: David Terpay <35130517+davidterpay@users.noreply.github.com> Date: Wed, 12 Apr 2023 08:11:09 -0400 Subject: [PATCH] [ENG-704]: Decoupling the auction and global mempools (#55) --- abci/abci.go | 74 +++++++++++++------------------------ abci/abci_test.go | 43 +++++++++++++++------ mempool/mempool.go | 27 +++++++------- mempool/mempool_test.go | 9 ++--- x/builder/ante/ante_test.go | 1 - 5 files changed, 73 insertions(+), 81 deletions(-) diff --git a/abci/abci.go b/abci/abci.go index bf47d84..84bef9a 100644 --- a/abci/abci.go +++ b/abci/abci.go @@ -13,7 +13,6 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" sdkmempool "github.com/cosmos/cosmos-sdk/types/mempool" "github.com/skip-mev/pob/mempool" - buildertypes "github.com/skip-mev/pob/x/builder/types" ) type ProposalHandler struct { @@ -49,9 +48,9 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { totalTxBytes int64 ) - bidTxMap := make(map[string]struct{}) bidTxIterator := h.mempool.AuctionBidSelect(ctx) txsToRemove := make(map[sdk.Tx]struct{}, 0) + seenTxs := make(map[string]struct{}, 0) // Attempt to select the highest bid transaction that is valid and whose // bundled transactions are valid. @@ -67,8 +66,8 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { bidTxSize := int64(len(bidTxBz)) if bidTxSize <= req.MaxTxBytes { - bidMsg, ok := tmpBidTx.GetMsgs()[0].(*buildertypes.MsgAuctionBid) - if !ok { + bidMsg, err := mempool.GetMsgAuctionBidFromTx(tmpBidTx) + if err != nil { // This should never happen, as CheckTx will ensure only valid bids // enter the mempool, but in case it does, we need to remove the // transaction from the mempool. @@ -76,8 +75,7 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { continue selectBidTxLoop } - bundledTxsRaw := make([][]byte, len(bidMsg.Transactions)) - for i, refTxRaw := range bidMsg.Transactions { + for _, refTxRaw := range bidMsg.Transactions { refTx, err := h.txDecoder(refTxRaw) if err != nil { // Malformed bundled transaction, so we remove the bid transaction @@ -92,39 +90,31 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { txsToRemove[tmpBidTx] = struct{}{} continue selectBidTxLoop } - - bundledTxsRaw[i] = refTxRaw } // At this point, both the bid transaction itself and all the bundled // transactions are valid. So we select the bid transaction along with - // all the bundled transactions. We also mark these transactions and + // all the bundled transactions. We also mark these transactions as seen and // update the total size selected thus far. totalTxBytes += bidTxSize - - bidTxHash := sha256.Sum256(bidTxBz) - bidTxHashStr := hex.EncodeToString(bidTxHash[:]) - - bidTxMap[bidTxHashStr] = struct{}{} selectedTxs = append(selectedTxs, bidTxBz) + selectedTxs = append(selectedTxs, bidMsg.Transactions...) - for _, refTxRaw := range bundledTxsRaw { - refTxHash := sha256.Sum256(refTxRaw) - refTxHashStr := hex.EncodeToString(refTxHash[:]) - - bidTxMap[refTxHashStr] = struct{}{} - selectedTxs = append(selectedTxs, refTxRaw) + for _, refTxRaw := range bidMsg.Transactions { + hash := sha256.Sum256(refTxRaw) + txHash := hex.EncodeToString(hash[:]) + seenTxs[txHash] = struct{}{} } break selectBidTxLoop } + txsToRemove[tmpBidTx] = struct{}{} h.logger.Info( - "failed to select auction bid tx; tx size is too large; skipping auction", + "failed to select auction bid tx; tx size is too large", "tx_size", bidTxSize, "max_size", req.MaxTxBytes, ) - break selectBidTxLoop } // Remove all invalid transactions from the mempool. @@ -141,37 +131,32 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { for ; iterator != nil; iterator = iterator.Next() { memTx := iterator.Tx() - // We've already selected the highest bid transaction, so we can skip - // all other auction transactions. - isAuctionTx, err := h.isAuctionTx(memTx) + // If the transaction is already included in the proposal, then we skip it. + txBz, err := h.txEncoder(memTx) if err != nil { txsToRemove[memTx] = struct{}{} continue selectTxLoop } - if isAuctionTx { + hash := sha256.Sum256(txBz) + txHash := hex.EncodeToString(hash[:]) + if _, ok := seenTxs[txHash]; ok { continue selectTxLoop } - txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx) + txBz, err = h.txVerifier.PrepareProposalVerifyTx(memTx) if err != nil { txsToRemove[memTx] = struct{}{} continue selectTxLoop } - // Referenced/bundled transaction may exist in the mempool, so we explicitly - // check prior to considering the transaction. - txHash := sha256.Sum256(txBz) - txHashStr := hex.EncodeToString(txHash[:]) - if _, ok := bidTxMap[txHashStr]; !ok { - txSize := int64(len(txBz)) - if totalTxBytes += txSize; totalTxBytes <= req.MaxTxBytes { - selectedTxs = append(selectedTxs, txBz) - } else { - // We've reached capacity per req.MaxTxBytes so we cannot select any - // more transactions. - break selectTxLoop - } + txSize := int64(len(txBz)) + if totalTxBytes += txSize; totalTxBytes <= req.MaxTxBytes { + selectedTxs = append(selectedTxs, txBz) + } else { + // We've reached capacity per req.MaxTxBytes so we cannot select any + // more transactions. + break selectTxLoop } } @@ -228,12 +213,3 @@ func (h *ProposalHandler) RemoveTx(tx sdk.Tx) { panic(fmt.Errorf("failed to remove invalid transaction from the mempool: %w", err)) } } - -func (h *ProposalHandler) isAuctionTx(tx sdk.Tx) (bool, error) { - msgAuctionBid, err := mempool.GetMsgAuctionBidFromTx(tx) - if err != nil { - return false, err - } - - return msgAuctionBid != nil, nil -} diff --git a/abci/abci_test.go b/abci/abci_test.go index 9974d92..e65228e 100644 --- a/abci/abci_test.go +++ b/abci/abci_test.go @@ -2,6 +2,9 @@ package abci_test import ( "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" "math/rand" "testing" "time" @@ -30,6 +33,7 @@ type ABCITestSuite struct { logger log.Logger encodingConfig testutils.EncodingConfig proposalHandler *abci.ProposalHandler + txs map[string]struct{} // auction bid setup auctionBidAmount sdk.Coin @@ -66,6 +70,7 @@ func (suite *ABCITestSuite) SetupTest() { // Mempool set up suite.mempool = mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), suite.encodingConfig.TxConfig.TxEncoder(), 0) + suite.txs = make(map[string]struct{}) suite.auctionBidAmount = sdk.NewCoin("foo", sdk.NewInt(1000000000)) suite.minBidIncrement = sdk.NewCoin("foo", sdk.NewInt(1000)) @@ -116,6 +121,13 @@ func (suite *ABCITestSuite) PrepareProposalVerifyTx(tx sdk.Tx) ([]byte, error) { return nil, err } + hash := sha256.Sum256(txBz) + txHash := hex.EncodeToString(hash[:]) + if _, ok := suite.txs[txHash]; ok { + return nil, fmt.Errorf("tx already in mempool") + } + suite.txs[txHash] = struct{}{} + return txBz, nil } @@ -130,6 +142,13 @@ func (suite *ABCITestSuite) ProcessProposalVerifyTx(txBz []byte) (sdk.Tx, error) return tx, err } + hash := sha256.Sum256(txBz) + txHash := hex.EncodeToString(hash[:]) + if _, ok := suite.txs[txHash]; ok { + return nil, fmt.Errorf("tx already in mempool") + } + suite.txs[txHash] = struct{}{} + return tx, nil } @@ -214,11 +233,11 @@ func (suite *ABCITestSuite) createFilledMempool(numNormalTxs, numAuctionTxs, num var totalNumTxs int suite.Require().Equal(numAuctionTxs, suite.mempool.CountAuctionTx()) if insertRefTxs { - totalNumTxs = numNormalTxs + numAuctionTxs*(numBundledTxs+1) + totalNumTxs = numNormalTxs + numAuctionTxs*(numBundledTxs) suite.Require().Equal(totalNumTxs, suite.mempool.CountTx()) suite.Require().Equal(totalNumTxs, numSeenGlobalTxs) } else { - totalNumTxs = numNormalTxs + numAuctionTxs + totalNumTxs = numNormalTxs suite.Require().Equal(totalNumTxs, suite.mempool.CountTx()) suite.Require().Equal(totalNumTxs, numSeenGlobalTxs) } @@ -297,7 +316,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = true }, 4, - 4, + 3, true, }, { @@ -309,7 +328,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = false }, 4, - 1, + 0, true, }, { @@ -350,7 +369,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = false }, 4, - 10, + 0, true, }, { @@ -362,7 +381,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = true }, 31, - 40, + 30, true, }, { @@ -395,7 +414,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { numBundledTxs = 0 }, 2, - 2, + 1, true, }, { @@ -407,7 +426,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = false }, 5, - 2, + 1, true, }, { @@ -434,7 +453,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { numBundledTxs = 0 }, 101, - 101, + 100, true, }, { @@ -446,7 +465,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = true }, 104, - 104, + 103, true, }, { @@ -458,7 +477,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = false }, 104, - 101, + 100, true, }, { @@ -470,7 +489,7 @@ func (suite *ABCITestSuite) TestPrepareProposal() { insertRefTxs = true }, 201, - 300, + 200, true, }, } diff --git a/mempool/mempool.go b/mempool/mempool.go index 2e4bcac..9250eb3 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -106,18 +106,19 @@ func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx i // auction tx (tx that contains a single MsgAuctionBid), it will also insert the // transaction into the auction index. func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { - if err := am.globalIndex.Insert(ctx, tx); err != nil { - return fmt.Errorf("failed to insert tx into global index: %w", err) - } - msg, err := GetMsgAuctionBidFromTx(tx) if err != nil { return err } - if msg != nil { + // Insert the transactions into the appropriate index. + switch { + case msg == nil: + if err := am.globalIndex.Insert(ctx, tx); err != nil { + return fmt.Errorf("failed to insert tx into global index: %w", err) + } + case msg != nil: if err := am.auctionIndex.Insert(ctx, tx); err != nil { - am.removeTx(am.globalIndex, tx) return fmt.Errorf("failed to insert tx into auction index: %w", err) } } @@ -136,19 +137,19 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { // auction tx (tx that contains a single MsgAuctionBid), it will also remove all // referenced transactions from the global mempool. func (am *AuctionMempool) Remove(tx sdk.Tx) error { - // 1. Remove the tx from the global index - am.removeTx(am.globalIndex, tx) - msg, err := GetMsgAuctionBidFromTx(tx) if err != nil { return err } - // 2. Remove the bid from the auction index (if applicable). In addition, we - // remove all referenced transactions from the global mempool. - if msg != nil { + // Remove the transactions from the appropriate index. + switch { + case msg == nil: + am.removeTx(am.globalIndex, tx) + case msg != nil: am.removeTx(am.auctionIndex, tx) + // Remove all referenced transactions from the global mempool. for _, refRawTx := range msg.GetTransactions() { refTx, err := am.txDecoder(refRawTx) if err != nil { @@ -168,8 +169,6 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error { // API is used to ensure that searchers are unable to remove valid transactions // from the global mempool. func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error { - am.removeTx(am.globalIndex, tx) - msg, err := GetMsgAuctionBidFromTx(tx) if err != nil { return err diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 81ea12a..a2de753 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -112,10 +112,9 @@ func (suite *IntegrationTestSuite) CreateFilledMempool(numNormalTxs, numAuctionT var totalNumTxs int suite.Require().Equal(numAuctionTxs, suite.mempool.CountAuctionTx()) if insertRefTxs { - totalNumTxs = numNormalTxs + numAuctionTxs*(numBundledTxs+1) + totalNumTxs = numNormalTxs + numAuctionTxs*(numBundledTxs) suite.Require().Equal(totalNumTxs, suite.mempool.CountTx()) } else { - totalNumTxs = numNormalTxs + numAuctionTxs suite.Require().Equal(totalNumTxs, suite.mempool.CountTx()) } @@ -138,7 +137,7 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() { // Ensure that the auction tx was removed from the auction and global mempool suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) - suite.Require().Equal(numMempoolTxs-1, suite.mempool.CountTx()) + suite.Require().Equal(numMempoolTxs, suite.mempool.CountTx()) contains, err := suite.mempool.Contains(tx) suite.Require().NoError(err) suite.Require().False(contains) @@ -146,12 +145,12 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() { // Attempt to remove again and ensure that the tx is not found suite.Require().NoError(suite.mempool.RemoveWithoutRefTx(tx)) suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) - suite.Require().Equal(numMempoolTxs-1, suite.mempool.CountTx()) + suite.Require().Equal(numMempoolTxs, suite.mempool.CountTx()) // Attempt to remove with the bundled txs suite.Require().NoError(suite.mempool.Remove(tx)) suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) - suite.Require().Equal(numMempoolTxs-numberBundledTxs-1, suite.mempool.CountTx()) + suite.Require().Equal(numMempoolTxs-numberBundledTxs, suite.mempool.CountTx()) auctionMsg, err := mempool.GetMsgAuctionBidFromTx(tx) suite.Require().NoError(err) diff --git a/x/builder/ante/ante_test.go b/x/builder/ante/ante_test.go index a0829b2..82dc102 100644 --- a/x/builder/ante/ante_test.go +++ b/x/builder/ante/ante_test.go @@ -251,7 +251,6 @@ func (suite *AnteTestSuite) TestAnteHandler() { suite.Require().Equal(0, mempool.CountTx()) suite.Require().Equal(0, mempool.CountAuctionTx()) suite.Require().NoError(mempool.Insert(suite.ctx, topAuctionTx)) - suite.Require().Equal(1, mempool.CountTx()) suite.Require().Equal(1, mempool.CountAuctionTx()) }