diff --git a/abci/abci.go b/abci/abci.go index 167a198..3c39164 100644 --- a/abci/abci.go +++ b/abci/abci.go @@ -57,7 +57,7 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { // bundled transactions are valid. selectBidTxLoop: for ; bidTxIterator != nil; bidTxIterator = bidTxIterator.Next() { - tmpBidTx := mempool.UnwrapBidTx(bidTxIterator.Tx()) + tmpBidTx := bidTxIterator.Tx() bidTxBz, err := h.txVerifier.PrepareProposalVerifyTx(tmpBidTx) if err != nil { diff --git a/abci/abci_test.go b/abci/abci_test.go index e72e54e..4dd237e 100644 --- a/abci/abci_test.go +++ b/abci/abci_test.go @@ -234,7 +234,7 @@ func (suite *ABCITestSuite) exportMempool(exportRefTxs bool) [][]byte { auctionIterator := suite.mempool.AuctionBidSelect(suite.ctx) for ; auctionIterator != nil; auctionIterator = auctionIterator.Next() { - auctionTx := auctionIterator.Tx().(*mempool.WrappedBidTx).Tx + auctionTx := auctionIterator.Tx() txBz, err := suite.encodingConfig.TxConfig.TxEncoder()(auctionTx) suite.Require().NoError(err) @@ -771,6 +771,6 @@ func (suite *ABCITestSuite) isTopBidValid() bool { } // check if the top bid is valid - _, err := suite.executeAnteHandler(iterator.Tx().(*mempool.WrappedBidTx).Tx) + _, err := suite.executeAnteHandler(iterator.Tx()) return err == nil } diff --git a/mempool/mempool.go b/mempool/mempool.go index 2b06b60..7a2e451 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -34,7 +34,12 @@ type AuctionMempool struct { func AuctionTxPriority() TxPriority[string] { return TxPriority[string]{ GetTxPriority: func(goCtx context.Context, tx sdk.Tx) string { - return tx.(*WrappedBidTx).GetBid().String() + msgAuctionBid, err := GetMsgAuctionBidFromTx(tx) + if err != nil { + panic(err) + } + + return msgAuctionBid.Bid.String() }, Compare: func(a, b string) int { aCoins, _ := sdk.ParseCoinsNormalized(a) @@ -99,7 +104,7 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { } if msg != nil { - if err := am.auctionIndex.Insert(ctx, NewWrappedBidTx(tx, msg.GetBid())); err != nil { + if err := am.auctionIndex.Insert(ctx, tx); err != nil { removeTx(am.globalIndex, tx) return fmt.Errorf("failed to insert tx into auction index: %w", err) } @@ -123,7 +128,7 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error { // 2. Remove the bid from the auction index (if applicable). In addition, we // remove all referenced transactions from the global mempool. if msg != nil { - removeTx(am.auctionIndex, NewWrappedBidTx(tx, msg.GetBid())) + removeTx(am.auctionIndex, tx) for _, refRawTx := range msg.GetTransactions() { refTx, err := am.txDecoder(refRawTx) @@ -152,7 +157,7 @@ func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error { } if msg != nil { - removeTx(am.auctionIndex, NewWrappedBidTx(tx, msg.GetBid())) + removeTx(am.auctionIndex, tx) } return nil diff --git a/mempool/tx.go b/mempool/utils.go similarity index 62% rename from mempool/tx.go rename to mempool/utils.go index 5efb6e2..6bcd4ad 100644 --- a/mempool/tx.go +++ b/mempool/utils.go @@ -4,27 +4,9 @@ import ( "errors" sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/x/auth/signing" buildertypes "github.com/skip-mev/pob/x/builder/types" ) -// WrappedBidTx defines a wrapper around an sdk.Tx that contains a single -// MsgAuctionBid message with additional metadata. -type WrappedBidTx struct { - signing.Tx - - bid sdk.Coin -} - -func NewWrappedBidTx(tx sdk.Tx, bid sdk.Coin) *WrappedBidTx { - return &WrappedBidTx{ - Tx: tx.(signing.Tx), - bid: bid, - } -} - -func (wbtx *WrappedBidTx) GetBid() sdk.Coin { return wbtx.bid } - // GetMsgAuctionBidFromTx attempts to retrieve a MsgAuctionBid from an sdk.Tx if // one exists. If a MsgAuctionBid does exist and other messages are also present, // an error is returned. If no MsgAuctionBid is present, is returned. @@ -51,17 +33,3 @@ func GetMsgAuctionBidFromTx(tx sdk.Tx) (*buildertypes.MsgAuctionBid, error) { return nil, errors.New("invalid MsgAuctionBid transaction") } } - -// UnwrapBidTx attempts to unwrap a WrappedBidTx from an sdk.Tx if one exists. -func UnwrapBidTx(tx sdk.Tx) sdk.Tx { - if tx == nil { - return nil - } - - wTx, ok := tx.(*WrappedBidTx) - if ok { - return wTx.Tx - } - - return tx -} diff --git a/mempool/tx_test.go b/mempool/utils_test.go similarity index 67% rename from mempool/tx_test.go rename to mempool/utils_test.go index 00187b2..c2903aa 100644 --- a/mempool/tx_test.go +++ b/mempool/utils_test.go @@ -3,7 +3,6 @@ package mempool_test import ( "testing" - sdk "github.com/cosmos/cosmos-sdk/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" pobcodec "github.com/skip-mev/pob/codec" "github.com/skip-mev/pob/mempool" @@ -47,22 +46,3 @@ func TestGetMsgAuctionBidFromTx_NoBid(t *testing.T) { require.NoError(t, err) require.Nil(t, msg) } - -func TestGetUnwrappedTx(t *testing.T) { - encCfg := pobcodec.CreateEncodingConfig() - - txBuilder := encCfg.TxConfig.NewTxBuilder() - txBuilder.SetMsgs(&buildertypes.MsgAuctionBid{}) - tx := txBuilder.GetTx() - - bid := sdk.NewCoin("foo", sdk.NewInt(1000000)) - wrappedTx := mempool.NewWrappedBidTx(tx, bid) - unWrappedTx := mempool.UnwrapBidTx(wrappedTx) - - unwrappedBz, err := encCfg.TxConfig.TxEncoder()(unWrappedTx) - require.NoError(t, err) - - txBz, err := encCfg.TxConfig.TxEncoder()(tx) - require.NoError(t, err) - require.Equal(t, txBz, unwrappedBz) -} diff --git a/x/builder/ante/ante.go b/x/builder/ante/ante.go index 72dd852..1998876 100644 --- a/x/builder/ante/ante.go +++ b/x/builder/ante/ante.go @@ -83,7 +83,12 @@ func (ad BuilderDecorator) GetTopAuctionBid(ctx sdk.Context) (sdk.Coin, error) { return sdk.Coin{}, nil } - return auctionTx.(*mempool.WrappedBidTx).GetBid(), nil + msgAuctionBid, err := mempool.GetMsgAuctionBidFromTx(auctionTx) + if err != nil { + return sdk.Coin{}, err + } + + return msgAuctionBid.Bid, nil } // IsTopBidTx returns true if the transaction inputted is the highest bidding auction transaction in the mempool. @@ -93,8 +98,7 @@ func (ad BuilderDecorator) IsTopBidTx(ctx sdk.Context, tx sdk.Tx) (bool, error) return false, nil } - topBidTx := mempool.UnwrapBidTx(auctionTx) - topBidBz, err := ad.txEncoder(topBidTx) + topBidBz, err := ad.txEncoder(auctionTx) if err != nil { return false, err }