From 0579f8b1a2d164e4e8dfaddc52a1cc020f189b0c Mon Sep 17 00:00:00 2001 From: David Terpay <35130517+davidterpay@users.noreply.github.com> Date: Mon, 10 Apr 2023 12:22:55 -0400 Subject: [PATCH] [ENG-681]: Comet ReCheckTx fix for the app-side mempool (#53) Co-authored-by: Aleksandr Bezobchuk --- abci/abci_test.go | 2 +- mempool/mempool.go | 66 +++++++++++++++++++++++++++++---- mempool/mempool_test.go | 22 ++++++++++- x/builder/ante/ante.go | 12 ++++++ x/builder/ante/ante_test.go | 2 +- x/builder/keeper/keeper_test.go | 2 +- 6 files changed, 94 insertions(+), 12 deletions(-) diff --git a/abci/abci_test.go b/abci/abci_test.go index 0e551c2..9974d92 100644 --- a/abci/abci_test.go +++ b/abci/abci_test.go @@ -65,7 +65,7 @@ func (suite *ABCITestSuite) SetupTest() { suite.ctx = testCtx.Ctx // Mempool set up - suite.mempool = mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), 0) + suite.mempool = mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), suite.encodingConfig.TxConfig.TxEncoder(), 0) suite.auctionBidAmount = sdk.NewCoin("foo", sdk.NewInt(1000000000)) suite.minBidIncrement = sdk.NewCoin("foo", sdk.NewInt(1000)) diff --git a/mempool/mempool.go b/mempool/mempool.go index 7a2e451..2e4bcac 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -2,6 +2,8 @@ package mempool import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" @@ -27,6 +29,14 @@ type AuctionMempool struct { // txDecoder defines the sdk.Tx decoder that allows us to decode transactions // and construct sdk.Txs from the bundled transactions. txDecoder sdk.TxDecoder + + // txEncoder defines the sdk.Tx encoder that allows us to encode transactions + // to bytes. + txEncoder sdk.TxEncoder + + // txIndex is a map of all transactions in the mempool. It is used + // to quickly check if a transaction is already in the mempool. + txIndex map[string]struct{} } // AuctionTxPriority returns a TxPriority over auction bid transactions only. It @@ -72,7 +82,7 @@ func AuctionTxPriority() TxPriority[string] { } } -func NewAuctionMempool(txDecoder sdk.TxDecoder, maxTx int) *AuctionMempool { +func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx int) *AuctionMempool { return &AuctionMempool{ globalIndex: NewPriorityMempool( PriorityNonceMempoolConfig[int64]{ @@ -87,6 +97,8 @@ func NewAuctionMempool(txDecoder sdk.TxDecoder, maxTx int) *AuctionMempool { }, ), txDecoder: txDecoder, + txEncoder: txEncoder, + txIndex: make(map[string]struct{}), } } @@ -105,11 +117,18 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { if msg != nil { if err := am.auctionIndex.Insert(ctx, tx); err != nil { - removeTx(am.globalIndex, tx) + am.removeTx(am.globalIndex, tx) return fmt.Errorf("failed to insert tx into auction index: %w", err) } } + txHashStr, err := am.getTxHashStr(tx) + if err != nil { + return err + } + + am.txIndex[txHashStr] = struct{}{} + return nil } @@ -118,7 +137,7 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { // referenced transactions from the global mempool. func (am *AuctionMempool) Remove(tx sdk.Tx) error { // 1. Remove the tx from the global index - removeTx(am.globalIndex, tx) + am.removeTx(am.globalIndex, tx) msg, err := GetMsgAuctionBidFromTx(tx) if err != nil { @@ -128,7 +147,7 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error { // 2. Remove the bid from the auction index (if applicable). In addition, we // remove all referenced transactions from the global mempool. if msg != nil { - removeTx(am.auctionIndex, tx) + am.removeTx(am.auctionIndex, tx) for _, refRawTx := range msg.GetTransactions() { refTx, err := am.txDecoder(refRawTx) @@ -136,7 +155,7 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error { return fmt.Errorf("failed to decode referenced tx: %w", err) } - removeTx(am.globalIndex, refTx) + am.removeTx(am.globalIndex, refTx) } } @@ -149,7 +168,7 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error { // API is used to ensure that searchers are unable to remove valid transactions // from the global mempool. func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error { - removeTx(am.globalIndex, tx) + am.removeTx(am.globalIndex, tx) msg, err := GetMsgAuctionBidFromTx(tx) if err != nil { @@ -157,7 +176,7 @@ func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error { } if msg != nil { - removeTx(am.auctionIndex, tx) + am.removeTx(am.auctionIndex, tx) } return nil @@ -190,9 +209,40 @@ func (am *AuctionMempool) CountTx() int { return am.globalIndex.CountTx() } -func removeTx(mp sdkmempool.Mempool, tx sdk.Tx) { +// Contains returns true if the transaction is contained in the mempool. +func (am *AuctionMempool) Contains(tx sdk.Tx) (bool, error) { + txHashStr, err := am.getTxHashStr(tx) + if err != nil { + return false, fmt.Errorf("failed to get tx hash string: %w", err) + } + + _, ok := am.txIndex[txHashStr] + return ok, nil +} + +// getTxHashStr returns the transaction hash string for a given transaction. +func (am *AuctionMempool) getTxHashStr(tx sdk.Tx) (string, error) { + txBz, err := am.txEncoder(tx) + if err != nil { + return "", fmt.Errorf("failed to encode transaction: %w", err) + } + + txHash := sha256.Sum256(txBz) + txHashStr := hex.EncodeToString(txHash[:]) + + return txHashStr, nil +} + +func (am *AuctionMempool) removeTx(mp sdkmempool.Mempool, tx sdk.Tx) { err := mp.Remove(tx) if err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { panic(fmt.Errorf("failed to remove invalid transaction from the mempool: %w", err)) } + + txHashStr, err := am.getTxHashStr(tx) + if err != nil { + panic(fmt.Errorf("failed to get tx hash string: %w", err)) + } + + delete(am.txIndex, txHashStr) } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index cf1c079..81ea12a 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -32,7 +32,7 @@ func TestMempoolTestSuite(t *testing.T) { func (suite *IntegrationTestSuite) SetupTest() { // Mempool setup suite.encCfg = testutils.CreateTestEncodingConfig() - suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), 0) + suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), suite.encCfg.TxConfig.TxEncoder(), 0) suite.ctx = sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger()) // Init accounts @@ -64,6 +64,9 @@ func (suite *IntegrationTestSuite) CreateFilledMempool(numNormalTxs, numAuctionT suite.nonces[acc.Address.String()]++ priority := suite.random.Int63n(100) + 1 suite.Require().NoError(suite.mempool.Insert(suite.ctx.WithPriority(priority), randomTx)) + contains, err := suite.mempool.Contains(randomTx) + suite.Require().NoError(err) + suite.Require().True(contains) } suite.Require().Equal(numNormalTxs, suite.mempool.CountTx()) @@ -89,6 +92,9 @@ func (suite *IntegrationTestSuite) CreateFilledMempool(numNormalTxs, numAuctionT // insert the auction tx into the global mempool suite.Require().NoError(suite.mempool.Insert(suite.ctx.WithPriority(priority), auctionTx)) + contains, err := suite.mempool.Contains(auctionTx) + suite.Require().NoError(err) + suite.Require().True(contains) suite.nonces[acc.Address.String()]++ if insertRefTxs { @@ -96,6 +102,9 @@ func (suite *IntegrationTestSuite) CreateFilledMempool(numNormalTxs, numAuctionT refTx, err := suite.encCfg.TxConfig.TxDecoder()(refRawTx) suite.Require().NoError(err) suite.Require().NoError(suite.mempool.Insert(suite.ctx.WithPriority(priority), refTx)) + contains, err = suite.mempool.Contains(refTx) + suite.Require().NoError(err) + suite.Require().True(contains) } } } @@ -130,6 +139,9 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() { // Ensure that the auction tx was removed from the auction and global mempool suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) suite.Require().Equal(numMempoolTxs-1, suite.mempool.CountTx()) + contains, err := suite.mempool.Contains(tx) + suite.Require().NoError(err) + suite.Require().False(contains) // Attempt to remove again and ensure that the tx is not found suite.Require().NoError(suite.mempool.RemoveWithoutRefTx(tx)) @@ -140,6 +152,14 @@ func (suite *IntegrationTestSuite) TestAuctionMempoolRemove() { suite.Require().NoError(suite.mempool.Remove(tx)) suite.Require().Equal(numberAuctionTxs-1, suite.mempool.CountAuctionTx()) suite.Require().Equal(numMempoolTxs-numberBundledTxs-1, suite.mempool.CountTx()) + + auctionMsg, err := mempool.GetMsgAuctionBidFromTx(tx) + suite.Require().NoError(err) + for _, refTx := range auctionMsg.GetTransactions() { + tx, err := suite.encCfg.TxConfig.TxDecoder()(refTx) + suite.Require().NoError(err) + suite.Require().False(suite.mempool.Contains(tx)) + } } func (suite *IntegrationTestSuite) TestAuctionMempoolSelect() { diff --git a/x/builder/ante/ante.go b/x/builder/ante/ante.go index d0e169d..9058f01 100644 --- a/x/builder/ante/ante.go +++ b/x/builder/ante/ante.go @@ -39,6 +39,18 @@ func NewBuilderDecorator(ak keeper.Keeper, txDecoder sdk.TxDecoder, txEncoder sd // AnteHandle validates that the auction bid is valid if one exists. If valid it will deduct the entrance fee from the // bidder's account. func (ad BuilderDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { + // If comet is re-checking a transaction, we only need to check if the transaction is in the application-side mempool. + if ctx.IsReCheckTx() { + contains, err := ad.mempool.Contains(tx) + if err != nil { + return ctx, err + } + + if !contains { + return ctx, fmt.Errorf("transaction not found in application mempool") + } + } + auctionMsg, err := mempool.GetMsgAuctionBidFromTx(tx) if err != nil { return ctx, err diff --git a/x/builder/ante/ante_test.go b/x/builder/ante/ante_test.go index 5838306..a0829b2 100644 --- a/x/builder/ante/ante_test.go +++ b/x/builder/ante/ante_test.go @@ -244,7 +244,7 @@ func (suite *AnteTestSuite) TestAnteHandler() { suite.Require().NoError(err) // Insert the top bid into the mempool - mempool := mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), 0) + mempool := mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), suite.encodingConfig.TxConfig.TxEncoder(), 0) if insertTopBid { topAuctionTx, err := testutils.CreateAuctionTxWithSigners(suite.encodingConfig.TxConfig, topBidder, topBid, 0, timeout, []testutils.Account{}) suite.Require().NoError(err) diff --git a/x/builder/keeper/keeper_test.go b/x/builder/keeper/keeper_test.go index 3c168fb..2dcb7d5 100644 --- a/x/builder/keeper/keeper_test.go +++ b/x/builder/keeper/keeper_test.go @@ -64,6 +64,6 @@ func (suite *KeeperTestSuite) SetupTest() { err := suite.builderKeeper.SetParams(suite.ctx, types.DefaultParams()) suite.Require().NoError(err) - suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), 0) + suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), suite.encCfg.TxConfig.TxEncoder(), 0) suite.msgServer = keeper.NewMsgServerImpl(suite.builderKeeper) }