diff --git a/abci/abci.go b/abci/abci.go index 0b4f384..4bd9d46 100644 --- a/abci/abci.go +++ b/abci/abci.go @@ -50,16 +50,17 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { bidTxMap := make(map[string]struct{}) bidTxIterator := h.mempool.AuctionBidSelect(ctx) + txsToRemove := make(map[sdk.Tx]struct{}, 0) // Attempt to select the highest bid transaction that is valid and whose // bundled transactions are valid. selectBidTxLoop: for ; bidTxIterator != nil; bidTxIterator = bidTxIterator.Next() { - tmpBidTx := bidTxIterator.Tx() + tmpBidTx := mempool.UnwrapBidTx(bidTxIterator.Tx()) bidTxBz, err := h.txVerifier.PrepareProposalVerifyTx(tmpBidTx) if err != nil { - h.RemoveTx(tmpBidTx, true) + txsToRemove[tmpBidTx] = struct{}{} continue selectBidTxLoop } @@ -70,7 +71,7 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { // This should never happen, as CheckTx will ensure only valid bids // enter the mempool, but in case it does, we need to remove the // transaction from the mempool. - h.RemoveTx(tmpBidTx, true) + txsToRemove[tmpBidTx] = struct{}{} continue selectBidTxLoop } @@ -80,14 +81,14 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { if err != nil { // Malformed bundled transaction, so we remove the bid transaction // and try the next top bid. - h.RemoveTx(tmpBidTx, true) + txsToRemove[tmpBidTx] = struct{}{} continue selectBidTxLoop } if _, err := h.txVerifier.PrepareProposalVerifyTx(refTx); err != nil { // Invalid bundled transaction, so we remove the bid transaction // and try the next top bid. - h.RemoveTx(tmpBidTx, true) + txsToRemove[tmpBidTx] = struct{}{} continue selectBidTxLoop } @@ -125,7 +126,13 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { break selectBidTxLoop } - iterator := h.mempool.Select(ctx, req.Txs) + // Remove all invalid transactions from the mempool. + for tx := range txsToRemove { + h.RemoveTx(tx) + } + + iterator := h.mempool.Select(ctx, nil) + txsToRemove = map[sdk.Tx]struct{}{} // Select remaining transactions for the block proposal until we've reached // size capacity. @@ -135,12 +142,11 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx) if err != nil { - h.RemoveTx(memTx, false) + txsToRemove[memTx] = struct{}{} continue selectTxLoop } - // Referenced/bundled transaction should not exist in the mempool, - // however, we cannot guarantee this won't happen. So, we explicitly + // Referenced/bundled transaction may exist in the mempool, so we explicitly // check prior to considering the transaction. txHash := sha256.Sum256(txBz) txHashStr := hex.EncodeToString(txHash[:]) @@ -156,6 +162,11 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { } } + // Remove all invalid transactions from the mempool. + for tx := range txsToRemove { + h.RemoveTx(tx) + } + return abci.ResponsePrepareProposal{Txs: selectedTxs} } } @@ -168,16 +179,8 @@ func (h *ProposalHandler) ProcessProposalHandler() sdk.ProcessProposalHandler { } } -func (h *ProposalHandler) RemoveTx(tx sdk.Tx, isAuctionTx bool) { - var err error - - if isAuctionTx { - err = h.mempool.RemoveWithoutRefTx(tx) - } else { - err = h.mempool.Remove(tx) - } - - if err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { +func (h *ProposalHandler) RemoveTx(tx sdk.Tx) { + if err := h.mempool.RemoveWithoutRefTx(tx); err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { panic(fmt.Errorf("failed to remove invalid transaction from the mempool: %w", err)) } } diff --git a/abci/abci_test.go b/abci/abci_test.go index cfe6172..2bb4ebf 100644 --- a/abci/abci_test.go +++ b/abci/abci_test.go @@ -1,12 +1,514 @@ package abci_test import ( + "math/rand" "testing" + "time" + + abcitypes "github.com/cometbft/cometbft/abci/types" + "github.com/cometbft/cometbft/libs/log" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + "github.com/cosmos/cosmos-sdk/testutil" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/golang/mock/gomock" + "github.com/skip-mev/pob/abci" + "github.com/skip-mev/pob/mempool" + "github.com/skip-mev/pob/x/auction/ante" + "github.com/skip-mev/pob/x/auction/keeper" + "github.com/skip-mev/pob/x/auction/types" + "github.com/stretchr/testify/suite" ) -func TestPrepareProposalHandler(t *testing.T) { - // TODO: Implement me! - // - // Ref: ENG-569 - t.SkipNow() +type ABCITestSuite struct { + suite.Suite + ctx sdk.Context + + // mempool setup + mempool *mempool.AuctionMempool + logger log.Logger + encodingConfig encodingConfig + proposalHandler *abci.ProposalHandler + + // auction bid setup + auctionBidAmount sdk.Coins + minBidIncrement sdk.Coins + + // auction setup + auctionKeeper keeper.Keeper + bankKeeper *MockBankKeeper + accountKeeper *MockAccountKeeper + distrKeeper *MockDistributionKeeper + stakingKeeper *MockStakingKeeper + auctionDecorator ante.AuctionDecorator + key *storetypes.KVStoreKey + authorityAccount sdk.AccAddress + + // account set up + accounts []Account + balances sdk.Coins + random *rand.Rand + nonces map[string]uint64 +} + +func TestPrepareProposalSuite(t *testing.T) { + suite.Run(t, new(ABCITestSuite)) +} + +func (suite *ABCITestSuite) SetupTest() { + // General config + suite.encodingConfig = createTestEncodingConfig() + suite.random = rand.New(rand.NewSource(time.Now().Unix())) + suite.key = sdk.NewKVStoreKey(types.StoreKey) + testCtx := testutil.DefaultContextWithDB(suite.T(), suite.key, sdk.NewTransientStoreKey("transient_test")) + suite.ctx = testCtx.Ctx + + // Mempool set up + suite.mempool = mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), 0) + suite.auctionBidAmount = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000000000))) + suite.minBidIncrement = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000))) + + // Mock keepers set up + ctrl := gomock.NewController(suite.T()) + suite.accountKeeper = NewMockAccountKeeper(ctrl) + suite.accountKeeper.EXPECT().GetModuleAddress(types.ModuleName).Return(sdk.AccAddress{}).AnyTimes() + suite.bankKeeper = NewMockBankKeeper(ctrl) + suite.distrKeeper = NewMockDistributionKeeper(ctrl) + suite.stakingKeeper = NewMockStakingKeeper(ctrl) + suite.authorityAccount = sdk.AccAddress([]byte("authority")) + + // Auction keeper / decorator set up + suite.auctionKeeper = keeper.NewKeeper( + suite.encodingConfig.Codec, + suite.key, + suite.accountKeeper, + suite.bankKeeper, + suite.distrKeeper, + suite.stakingKeeper, + suite.authorityAccount.String(), + ) + err := suite.auctionKeeper.SetParams(suite.ctx, types.DefaultParams()) + suite.Require().NoError(err) + suite.auctionDecorator = ante.NewAuctionDecorator(suite.auctionKeeper, suite.encodingConfig.TxConfig.TxDecoder(), suite.mempool) + + // Accounts set up + suite.accounts = RandomAccounts(suite.random, 1) + suite.balances = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000000000000000000))) + suite.nonces = make(map[string]uint64) + for _, acc := range suite.accounts { + suite.nonces[acc.Address.String()] = 0 + } + + // Proposal handler set up + suite.logger = log.NewNopLogger() + suite.proposalHandler = abci.NewProposalHandler(suite.mempool, suite.logger, suite, suite.encodingConfig.TxConfig.TxEncoder(), suite.encodingConfig.TxConfig.TxDecoder()) +} + +func (suite *ABCITestSuite) PrepareProposalVerifyTx(tx sdk.Tx) ([]byte, error) { + _, err := suite.executeAnteHandler(tx) + if err != nil { + return nil, err + } + + txBz, err := suite.encodingConfig.TxConfig.TxEncoder()(tx) + if err != nil { + return nil, err + } + + return txBz, nil +} + +func (suite *ABCITestSuite) ProcessProposalVerifyTx(_ []byte) (sdk.Tx, error) { + return nil, nil +} + +func (suite *ABCITestSuite) executeAnteHandler(tx sdk.Tx) (sdk.Context, error) { + signer := tx.GetMsgs()[0].GetSigners()[0] + suite.bankKeeper.EXPECT().GetAllBalances(suite.ctx, signer).AnyTimes().Return(suite.balances) + + next := func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { + return ctx, nil + } + + return suite.auctionDecorator.AnteHandle(suite.ctx, tx, false, next) +} + +func (suite *ABCITestSuite) createFilledMempool(numNormalTxs, numAuctionTxs, numBundledTxs int, insertRefTxs bool) int { + // Insert a bunch of normal transactions into the global mempool + for i := 0; i < numNormalTxs; i++ { + // randomly select an account to create the tx + randomIndex := suite.random.Intn(len(suite.accounts)) + acc := suite.accounts[randomIndex] + + // create a few random msgs + randomMsgs := createRandomMsgs(acc.Address, 3) + + nonce := suite.nonces[acc.Address.String()] + randomTx, err := createTx(suite.encodingConfig.TxConfig, acc, nonce, randomMsgs) + suite.Require().NoError(err) + + suite.nonces[acc.Address.String()]++ + priority := suite.random.Int63n(100) + 1 + suite.Require().NoError(suite.mempool.Insert(suite.ctx.WithPriority(priority), randomTx)) + } + + suite.Require().Equal(numNormalTxs, suite.mempool.CountTx()) + suite.Require().Equal(0, suite.mempool.CountAuctionTx()) + + // Insert a bunch of auction transactions into the global mempool and auction mempool + for i := 0; i < numAuctionTxs; i++ { + // randomly select a bidder to create the tx + randomIndex := suite.random.Intn(len(suite.accounts)) + acc := suite.accounts[randomIndex] + + // create a new auction bid msg with numBundledTxs bundled transactions + nonce := suite.nonces[acc.Address.String()] + bidMsg, err := createMsgAuctionBid(suite.encodingConfig.TxConfig, acc, suite.auctionBidAmount, nonce, numBundledTxs) + suite.nonces[acc.Address.String()] += uint64(numBundledTxs) + suite.Require().NoError(err) + + // create the auction tx + nonce = suite.nonces[acc.Address.String()] + auctionTx, err := createTx(suite.encodingConfig.TxConfig, acc, nonce, []sdk.Msg{bidMsg}) + suite.Require().NoError(err) + + // insert the auction tx into the global mempool + priority := suite.random.Int63n(100) + 1 + suite.Require().NoError(suite.mempool.Insert(suite.ctx.WithPriority(priority), auctionTx)) + suite.nonces[acc.Address.String()]++ + + if insertRefTxs { + for _, refRawTx := range bidMsg.GetTransactions() { + refTx, err := suite.encodingConfig.TxConfig.TxDecoder()(refRawTx) + suite.Require().NoError(err) + priority := suite.random.Int63n(100) + 1 + suite.Require().NoError(suite.mempool.Insert(suite.ctx.WithPriority(priority), refTx)) + } + } + + // decrement the bid amount for the next auction tx + suite.auctionBidAmount = suite.auctionBidAmount.Sub(suite.minBidIncrement...) + } + + numSeenGlobalTxs := 0 + for iterator := suite.mempool.Select(suite.ctx, nil); iterator != nil; iterator = iterator.Next() { + numSeenGlobalTxs++ + } + + numSeenAuctionTxs := 0 + for iterator := suite.mempool.AuctionBidSelect(suite.ctx); iterator != nil; iterator = iterator.Next() { + numSeenAuctionTxs++ + } + + var totalNumTxs int + suite.Require().Equal(numAuctionTxs, suite.mempool.CountAuctionTx()) + if insertRefTxs { + totalNumTxs = numNormalTxs + numAuctionTxs*(numBundledTxs+1) + suite.Require().Equal(totalNumTxs, suite.mempool.CountTx()) + suite.Require().Equal(totalNumTxs, numSeenGlobalTxs) + } else { + totalNumTxs = numNormalTxs + numAuctionTxs + suite.Require().Equal(totalNumTxs, suite.mempool.CountTx()) + suite.Require().Equal(totalNumTxs, numSeenGlobalTxs) + } + + suite.Require().Equal(numAuctionTxs, numSeenAuctionTxs) + + return totalNumTxs +} + +func (suite *ABCITestSuite) TestPrepareProposal() { + var ( + // the modified transactions cannot exceed this size + maxTxBytes int64 = 1000000000000000000 + + // mempool configuration + numNormalTxs = 100 + numAuctionTxs = 100 + numBundledTxs = 3 + insertRefTxs = false + + // auction configuration + maxBundleSize uint32 = 10 + reserveFee = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000))) + minBuyInFee = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000))) + frontRunningProtection = true + ) + + cases := []struct { + name string + malleate func() + expectedNumberProposalTxs int + expectedNumberTxsInMempool int + isTopBidValid bool + }{ + { + "single bundle in the mempool", + func() { + numNormalTxs = 0 + numAuctionTxs = 1 + numBundledTxs = 3 + insertRefTxs = true + }, + 4, + 4, + true, + }, + { + "single bundle in the mempool, no ref txs in mempool", + func() { + numNormalTxs = 0 + numAuctionTxs = 1 + numBundledTxs = 3 + insertRefTxs = false + }, + 4, + 1, + true, + }, + { + "single bundle in the mempool, not valid", + func() { + reserveFee = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(100000))) + suite.auctionBidAmount = sdk.Coins{sdk.NewCoin("foo", sdk.NewInt(10000))} // this will fail the ante handler + numNormalTxs = 0 + numAuctionTxs = 1 + numBundledTxs = 3 + }, + 0, + 0, + false, + }, + { + "single bundle in the mempool, not valid with ref txs in mempool", + func() { + reserveFee = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(100000))) + suite.auctionBidAmount = sdk.Coins{sdk.NewCoin("foo", sdk.NewInt(10000))} // this will fail the ante handler + numNormalTxs = 0 + numAuctionTxs = 1 + numBundledTxs = 3 + insertRefTxs = true + }, + 3, + 3, + false, + }, + { + "multiple bundles in the mempool, no normal txs + no ref txs in mempool", + func() { + reserveFee = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000))) + suite.auctionBidAmount = sdk.Coins{sdk.NewCoin("foo", sdk.NewInt(10000000))} + numNormalTxs = 0 + numAuctionTxs = 10 + numBundledTxs = 3 + insertRefTxs = false + }, + 4, + 1, + true, + }, + { + "multiple bundles in the mempool, no normal txs + ref txs in mempool", + func() { + numNormalTxs = 0 + numAuctionTxs = 10 + numBundledTxs = 3 + insertRefTxs = true + }, + 31, + 31, + true, + }, + { + "normal txs only", + func() { + numNormalTxs = 1 + numAuctionTxs = 0 + numBundledTxs = 0 + }, + 1, + 1, + false, + }, + { + "many normal txs only", + func() { + numNormalTxs = 100 + numAuctionTxs = 0 + numBundledTxs = 0 + }, + 100, + 100, + false, + }, + { + "single normal tx, single auction tx", + func() { + numNormalTxs = 1 + numAuctionTxs = 1 + numBundledTxs = 0 + }, + 2, + 2, + true, + }, + { + "single normal tx, single auction tx with ref txs", + func() { + numNormalTxs = 1 + numAuctionTxs = 1 + numBundledTxs = 3 + insertRefTxs = false + }, + 5, + 2, + true, + }, + { + "single normal tx, single failing auction tx with ref txs", + func() { + numNormalTxs = 1 + numAuctionTxs = 1 + numBundledTxs = 3 + insertRefTxs = true + suite.auctionBidAmount = sdk.Coins{sdk.NewCoin("foo", sdk.NewInt(2000))} // this will fail the ante handler + reserveFee = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000000000))) + }, + 4, + 4, + false, + }, + { + "many normal tx, single auction tx with no ref txs", + func() { + reserveFee = sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(1000))) + suite.auctionBidAmount = sdk.Coins{sdk.NewCoin("foo", sdk.NewInt(2000000))} + numNormalTxs = 100 + numAuctionTxs = 1 + numBundledTxs = 0 + }, + 101, + 101, + true, + }, + { + "many normal tx, single auction tx with ref txs", + func() { + numNormalTxs = 100 + numAuctionTxs = 1 + numBundledTxs = 3 + insertRefTxs = true + }, + 104, + 104, + true, + }, + { + "many normal tx, single auction tx with ref txs", + func() { + numNormalTxs = 100 + numAuctionTxs = 1 + numBundledTxs = 3 + insertRefTxs = false + }, + 104, + 101, + true, + }, + { + "many normal tx, many auction tx with ref txs", + func() { + numNormalTxs = 100 + numAuctionTxs = 100 + numBundledTxs = 1 + insertRefTxs = true + }, + 201, + 201, + true, + }, + } + + for _, tc := range cases { + suite.Run(tc.name, func() { + suite.SetupTest() // reset + tc.malleate() + + suite.createFilledMempool(numNormalTxs, numAuctionTxs, numBundledTxs, insertRefTxs) + + // create a new auction + params := types.Params{ + MaxBundleSize: maxBundleSize, + ReserveFee: reserveFee, + MinBuyInFee: minBuyInFee, + FrontRunningProtection: frontRunningProtection, + MinBidIncrement: suite.minBidIncrement, + } + suite.auctionKeeper.SetParams(suite.ctx, params) + suite.auctionDecorator = ante.NewAuctionDecorator(suite.auctionKeeper, suite.encodingConfig.TxConfig.TxDecoder(), suite.mempool) + + handler := suite.proposalHandler.PrepareProposalHandler() + res := handler(suite.ctx, abcitypes.RequestPrepareProposal{ + MaxTxBytes: maxTxBytes, + }) + + // -------------------- Check Invariants -------------------- // + // 1. The auction tx must fail if we know it is invalid + suite.Require().Equal(tc.isTopBidValid, suite.isTopBidValid()) + + // 2. total bytes must be less than or equal to maxTxBytes + totalBytes := int64(0) + if suite.isTopBidValid() { + totalBytes += int64(len(res.Txs[0])) + + for _, tx := range res.Txs[1+numBundledTxs:] { + totalBytes += int64(len(tx)) + } + } else { + for _, tx := range res.Txs { + totalBytes += int64(len(tx)) + } + } + suite.Require().LessOrEqual(totalBytes, maxTxBytes) + + // 3. the number of transactions in the response must be equal to the number of expected transactions + suite.Require().Equal(tc.expectedNumberProposalTxs, len(res.Txs)) + + // 4. if there are auction transactions, the first transaction must be the top bid + // and the rest of the bundle must be in the response + if suite.isTopBidValid() { + auctionTx, err := suite.encodingConfig.TxConfig.TxDecoder()(res.Txs[0]) + suite.Require().NoError(err) + + msgAuctionBid, err := mempool.GetMsgAuctionBidFromTx(auctionTx) + suite.Require().NoError(err) + + for index, tx := range msgAuctionBid.GetTransactions() { + suite.Require().Equal(tx, res.Txs[index+1]) + } + } + + // 5. All of the transactions must be unique + uniqueTxs := make(map[string]bool) + for _, tx := range res.Txs { + suite.Require().False(uniqueTxs[string(tx)]) + uniqueTxs[string(tx)] = true + } + + // 6. The number of transactions in the mempool must be correct + suite.Require().Equal(tc.expectedNumberTxsInMempool, suite.mempool.CountTx()) + }) + } +} + +// isTopBidValid returns true if the top bid is valid. We purposefully insert invalid +// auction transactions into the mempool to test the handlers. +func (suite *ABCITestSuite) isTopBidValid() bool { + iterator := suite.mempool.AuctionBidSelect(suite.ctx) + if iterator == nil { + return false + } + + // check if the top bid is valid + _, err := suite.executeAnteHandler(iterator.Tx().(*mempool.WrappedBidTx).Tx) + return err == nil } diff --git a/abci/mocks_test.go b/abci/mocks_test.go new file mode 100644 index 0000000..2c18704 --- /dev/null +++ b/abci/mocks_test.go @@ -0,0 +1,144 @@ +package abci_test + +import ( + "reflect" + + sdk "github.com/cosmos/cosmos-sdk/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/golang/mock/gomock" +) + +type MockAccountKeeper struct { + ctrl *gomock.Controller + recorder *MockAccountKeeperMockRecorder +} + +type MockAccountKeeperMockRecorder struct { + mock *MockAccountKeeper +} + +func NewMockAccountKeeper(ctrl *gomock.Controller) *MockAccountKeeper { + mock := &MockAccountKeeper{ctrl: ctrl} + mock.recorder = &MockAccountKeeperMockRecorder{mock} + return mock +} + +func (m *MockAccountKeeper) EXPECT() *MockAccountKeeperMockRecorder { + return m.recorder +} + +func (m *MockAccountKeeper) GetModuleAddress(name string) sdk.AccAddress { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetModuleAddress", name) + ret0, _ := ret[0].(sdk.AccAddress) + return ret0 +} + +func (mr *MockAccountKeeperMockRecorder) GetModuleAddress(name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetModuleAddress", reflect.TypeOf((*MockAccountKeeper)(nil).GetModuleAddress), name) +} + +type MockBankKeeper struct { + ctrl *gomock.Controller + recorder *MockBankKeeperMockRecorder +} + +type MockBankKeeperMockRecorder struct { + mock *MockBankKeeper +} + +func NewMockBankKeeper(ctrl *gomock.Controller) *MockBankKeeper { + mock := &MockBankKeeper{ctrl: ctrl} + mock.recorder = &MockBankKeeperMockRecorder{mock} + return mock +} + +func (m *MockBankKeeper) EXPECT() *MockBankKeeperMockRecorder { + return m.recorder +} + +func (m *MockBankKeeper) GetAllBalances(ctx sdk.Context, addr sdk.AccAddress) sdk.Coins { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllBalances", ctx, addr) + ret0 := ret[0].(sdk.Coins) + return ret0 +} + +func (mr *MockBankKeeperMockRecorder) GetAllBalances(ctx, addr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllBalances", reflect.TypeOf((*MockBankKeeper)(nil).GetAllBalances), ctx, addr) +} + +func (m *MockBankKeeper) SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SendCoins", ctx, fromAddr, toAddr, amt) + return nil +} + +func (mr *MockBankKeeperMockRecorder) SendCoins(ctx, fromAddr, toAddr, amt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendCoins", reflect.TypeOf((*MockBankKeeper)(nil).SendCoins), ctx, fromAddr, toAddr, amt) +} + +type MockDistributionKeeperRecorder struct { + mock *MockDistributionKeeper +} + +type MockDistributionKeeper struct { + ctrl *gomock.Controller + recorder *MockDistributionKeeperRecorder +} + +func NewMockDistributionKeeper(ctrl *gomock.Controller) *MockDistributionKeeper { + mock := &MockDistributionKeeper{ctrl: ctrl} + mock.recorder = &MockDistributionKeeperRecorder{mock} + return mock +} + +func (m *MockDistributionKeeper) EXPECT() *MockDistributionKeeperRecorder { + return m.recorder +} + +func (m *MockDistributionKeeper) GetPreviousProposerConsAddr(ctx sdk.Context) sdk.ConsAddress { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPreviousProposerConsAddr", ctx) + ret0 := ret[0].(sdk.ConsAddress) + return ret0 +} + +func (mr *MockDistributionKeeperRecorder) GetPreviousProposerConsAddr(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPreviousProposerConsAddr", reflect.TypeOf((*MockDistributionKeeper)(nil).GetPreviousProposerConsAddr), ctx) +} + +type MockStakingKeeperRecorder struct { + mock *MockStakingKeeper +} + +type MockStakingKeeper struct { + ctrl *gomock.Controller + recorder *MockStakingKeeperRecorder +} + +func NewMockStakingKeeper(ctrl *gomock.Controller) *MockStakingKeeper { + mock := &MockStakingKeeper{ctrl: ctrl} + mock.recorder = &MockStakingKeeperRecorder{mock} + return mock +} + +func (m *MockStakingKeeper) EXPECT() *MockStakingKeeperRecorder { + return m.recorder +} + +func (m *MockStakingKeeper) ValidatorByConsAddr(ctx sdk.Context, consAddr sdk.ConsAddress) stakingtypes.ValidatorI { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidatorByConsAddr", ctx, consAddr) + ret0 := ret[0].(stakingtypes.ValidatorI) + return ret0 +} + +func (mr *MockStakingKeeperRecorder) ValidatorByConsAddr(ctx, consAddr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidatorByConsAddr", reflect.TypeOf((*MockStakingKeeper)(nil).ValidatorByConsAddr), ctx, consAddr) +} diff --git a/abci/utils_test.go b/abci/utils_test.go new file mode 100644 index 0000000..d1b2a53 --- /dev/null +++ b/abci/utils_test.go @@ -0,0 +1,148 @@ +package abci_test + +import ( + "math/rand" + + "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/codec/types" + cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" + "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" + "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx/signing" + authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" + "github.com/cosmos/cosmos-sdk/x/auth/tx" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + auctiontypes "github.com/skip-mev/pob/x/auction/types" +) + +type encodingConfig struct { + InterfaceRegistry types.InterfaceRegistry + Codec codec.Codec + TxConfig client.TxConfig + Amino *codec.LegacyAmino +} + +func createTestEncodingConfig() encodingConfig { + cdc := codec.NewLegacyAmino() + interfaceRegistry := types.NewInterfaceRegistry() + + banktypes.RegisterInterfaces(interfaceRegistry) + cryptocodec.RegisterInterfaces(interfaceRegistry) + auctiontypes.RegisterInterfaces(interfaceRegistry) + + codec := codec.NewProtoCodec(interfaceRegistry) + + return encodingConfig{ + InterfaceRegistry: interfaceRegistry, + Codec: codec, + TxConfig: tx.NewTxConfig(codec, tx.DefaultSignModes), + Amino: cdc, + } +} + +type Account struct { + PrivKey cryptotypes.PrivKey + PubKey cryptotypes.PubKey + Address sdk.AccAddress + ConsKey cryptotypes.PrivKey +} + +func (acc Account) Equals(acc2 Account) bool { + return acc.Address.Equals(acc2.Address) +} + +func RandomAccounts(r *rand.Rand, n int) []Account { + accs := make([]Account, n) + + for i := 0; i < n; i++ { + pkSeed := make([]byte, 15) + r.Read(pkSeed) + + accs[i].PrivKey = secp256k1.GenPrivKeyFromSecret(pkSeed) + accs[i].PubKey = accs[i].PrivKey.PubKey() + accs[i].Address = sdk.AccAddress(accs[i].PubKey.Address()) + + accs[i].ConsKey = ed25519.GenPrivKeyFromSecret(pkSeed) + } + + return accs +} + +func createTx(txCfg client.TxConfig, account Account, nonce uint64, msgs []sdk.Msg) (authsigning.Tx, error) { + txBuilder := txCfg.NewTxBuilder() + if err := txBuilder.SetMsgs(msgs...); err != nil { + return nil, err + } + + sigV2 := signing.SignatureV2{ + PubKey: account.PrivKey.PubKey(), + Data: &signing.SingleSignatureData{ + SignMode: txCfg.SignModeHandler().DefaultMode(), + Signature: nil, + }, + Sequence: nonce, + } + if err := txBuilder.SetSignatures(sigV2); err != nil { + return nil, err + } + + return txBuilder.GetTx(), nil +} + +func createRandomMsgs(acc sdk.AccAddress, numberMsgs int) []sdk.Msg { + msgs := make([]sdk.Msg, numberMsgs) + for i := 0; i < numberMsgs; i++ { + msgs[i] = &banktypes.MsgSend{ + FromAddress: acc.String(), + ToAddress: acc.String(), + } + } + + return msgs +} + +func createMsgAuctionBid(txCfg client.TxConfig, bidder Account, bid sdk.Coins, nonce uint64, numberMsgs int) (*auctiontypes.MsgAuctionBid, error) { + bidMsg := &auctiontypes.MsgAuctionBid{ + Bidder: bidder.Address.String(), + Bid: bid, + Transactions: make([][]byte, numberMsgs), + } + + for i := 0; i < numberMsgs; i++ { + txBuilder := txCfg.NewTxBuilder() + + msgs := []sdk.Msg{ + &banktypes.MsgSend{ + FromAddress: bidder.Address.String(), + ToAddress: bidder.Address.String(), + }, + } + if err := txBuilder.SetMsgs(msgs...); err != nil { + return nil, err + } + + sigV2 := signing.SignatureV2{ + PubKey: bidder.PrivKey.PubKey(), + Data: &signing.SingleSignatureData{ + SignMode: txCfg.SignModeHandler().DefaultMode(), + Signature: nil, + }, + Sequence: nonce + uint64(i), + } + if err := txBuilder.SetSignatures(sigV2); err != nil { + return nil, err + } + + bz, err := txCfg.TxEncoder()(txBuilder.GetTx()) + if err != nil { + return nil, err + } + + bidMsg.Transactions[i] = bz + } + + return bidMsg, nil +} diff --git a/mempool/tx.go b/mempool/tx.go index 24d1048..6bda137 100644 --- a/mempool/tx.go +++ b/mempool/tx.go @@ -51,3 +51,13 @@ func GetMsgAuctionBidFromTx(tx sdk.Tx) (*auctiontypes.MsgAuctionBid, error) { return nil, errors.New("invalid MsgAuctionBid transaction") } } + +// UnwrapBidTx attempts to unwrap a WrappedBidTx from an sdk.Tx if one exists. +func UnwrapBidTx(tx sdk.Tx) sdk.Tx { + wTx, ok := tx.(*WrappedBidTx) + if ok { + return wTx.Tx + } + + return tx +} diff --git a/x/auction/ante/ante.go b/x/auction/ante/ante.go index 7aaa958..aae4a60 100644 --- a/x/auction/ante/ante.go +++ b/x/auction/ante/ante.go @@ -48,7 +48,7 @@ func (ad AuctionDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, transactions[i] = decodedTx } - highestBid, err := ad.GetHighestAuctionBid(ctx) + highestBid, err := ad.GetTopAuctionBid(ctx, tx) if err != nil { return ctx, errors.Wrap(err, "failed to get highest auction bid") } @@ -61,12 +61,18 @@ func (ad AuctionDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, return next(ctx, tx, simulate) } -// GetHighestAuctionBid returns the highest auction bid if one exists. -func (ad AuctionDecorator) GetHighestAuctionBid(ctx sdk.Context) (sdk.Coins, error) { +// GetTopAuctionBid returns the highest auction bid if one exists. If the current transaction is the highest +// bidding transaction, then an empty coin set is returned. +func (ad AuctionDecorator) GetTopAuctionBid(ctx sdk.Context, currTx sdk.Tx) (sdk.Coins, error) { auctionTx := ad.mempool.GetTopAuctionTx(ctx) if auctionTx == nil { return sdk.NewCoins(), nil } - return auctionTx.(*mempool.WrappedBidTx).GetBid(), nil + wrappedTx := auctionTx.(*mempool.WrappedBidTx) + if wrappedTx.Tx == currTx { + return sdk.NewCoins(), nil + } + + return wrappedTx.GetBid(), nil } diff --git a/x/auction/keeper/keeper_test.go b/x/auction/keeper/keeper_test.go index 3508f4a..a124516 100644 --- a/x/auction/keeper/keeper_test.go +++ b/x/auction/keeper/keeper_test.go @@ -8,7 +8,6 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/golang/mock/gomock" "github.com/skip-mev/pob/mempool" - "github.com/skip-mev/pob/x/auction/ante" "github.com/skip-mev/pob/x/auction/keeper" "github.com/skip-mev/pob/x/auction/types" @@ -24,7 +23,6 @@ type KeeperTestSuite struct { distrKeeper *MockDistributionKeeper stakingKeeper *MockStakingKeeper encCfg encodingConfig - AuctionDecorator sdk.AnteDecorator ctx sdk.Context msgServer types.MsgServer key *storetypes.KVStoreKey @@ -66,6 +64,5 @@ func (suite *KeeperTestSuite) SetupTest() { suite.Require().NoError(err) suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), 0) - suite.AuctionDecorator = ante.NewAuctionDecorator(suite.auctionKeeper, suite.encCfg.TxConfig.TxDecoder(), suite.mempool) suite.msgServer = keeper.NewMsgServerImpl(suite.auctionKeeper) }