diff --git a/abci/abci.go b/abci/abci.go index a5531cc..09ec02a 100644 --- a/abci/abci.go +++ b/abci/abci.go @@ -48,7 +48,7 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { totalTxBytes int64 ) - bidTxMap := make(map[string]struct{}) //nolint + bidTxMap := make(map[string]struct{}) bidTxIterator := h.mempool.AuctionBidSelect(ctx) // Attempt to select the highest bid transaction that is valid and whose @@ -59,7 +59,7 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { bidTxBz, err := h.txVerifier.PrepareProposalVerifyTx(tmpBidTx) if err != nil { - h.RemoveTx(tmpBidTx) + h.RemoveTx(tmpBidTx, true) continue selectBidTxLoop } @@ -70,7 +70,7 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { // This should never happen, as CheckTx will ensure only valid bids // enter the mempool, but in case it does, we need to remove the // transaction from the mempool. - h.RemoveTx(tmpBidTx) + h.RemoveTx(tmpBidTx, true) continue selectBidTxLoop } @@ -80,14 +80,14 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { if err != nil { // Malformed bundled transaction, so we remove the bid transaction // and try the next top bid. - h.RemoveTx(tmpBidTx) + h.RemoveTx(tmpBidTx, true) continue selectBidTxLoop } if _, err := h.txVerifier.PrepareProposalVerifyTx(refTx); err != nil { // Invalid bundled transaction, so we remove the bid transaction // and try the next top bid. - h.RemoveTx(tmpBidTx) + h.RemoveTx(tmpBidTx, true) continue selectBidTxLoop } @@ -135,7 +135,7 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx) if err != nil { - h.RemoveTx(memTx) + h.RemoveTx(memTx, false) continue selectTxLoop } @@ -168,8 +168,15 @@ func (h *ProposalHandler) ProcessProposalHandler() sdk.ProcessProposalHandler { } } -func (h *ProposalHandler) RemoveTx(tx sdk.Tx) { - err := h.mempool.Remove(tx) +func (h *ProposalHandler) RemoveTx(tx sdk.Tx, isAuctionTx bool) { + var err error + + if isAuctionTx { + err = h.mempool.RemoveWithoutRefTx(tx) + } else { + err = h.mempool.Remove(tx) + } + if err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { panic(fmt.Errorf("failed to remove invalid transaction from the mempool: %w", err)) } diff --git a/mempool/bid_list.go b/mempool/bid_list.go deleted file mode 100644 index 9d66f93..0000000 --- a/mempool/bid_list.go +++ /dev/null @@ -1,58 +0,0 @@ -package mempool - -import ( - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/huandu/skiplist" -) - -type ( - // AuctionBidList defines a list of WrappedBidTx objects, sorted by their bids. - - AuctionBidList struct { - list *skiplist.SkipList - } - - auctionBidListKey struct { - bid sdk.Coins - hash []byte - } -) - -func NewAuctionBidList() *AuctionBidList { - return &AuctionBidList{ - list: skiplist.New(skiplist.GreaterThanFunc(func(lhs, rhs any) int { - bidA := lhs.(auctionBidListKey) - bidB := rhs.(auctionBidListKey) - - switch { - case bidA.bid.IsAllGT(bidB.bid): - return 1 - - case bidA.bid.IsAllLT(bidB.bid): - return -1 - - default: - // in case of a tie in bid, sort by hash - return skiplist.ByteAsc.Compare(bidA.hash, bidB.hash) - } - })), - } -} - -// TopBid returns the WrappedBidTx with the highest bid. -func (abl *AuctionBidList) TopBid() *WrappedBidTx { - n := abl.list.Back() - if n == nil { - return nil - } - - return n.Value.(*WrappedBidTx) -} - -func (abl *AuctionBidList) Insert(wBidTx *WrappedBidTx) { - abl.list.Set(auctionBidListKey{bid: wBidTx.bid, hash: wBidTx.hash[:]}, wBidTx) -} - -func (abl *AuctionBidList) Remove(wBidTx *WrappedBidTx) { - abl.list.Remove(auctionBidListKey{bid: wBidTx.bid, hash: wBidTx.hash[:]}) -} diff --git a/mempool/bid_list_test.go b/mempool/bid_list_test.go deleted file mode 100644 index 2951275..0000000 --- a/mempool/bid_list_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package mempool_test - -import ( - "math/rand" - "testing" - "time" - - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/skip-mev/pob/mempool" - "github.com/stretchr/testify/require" -) - -var emptyHash = [32]byte{} - -func TestAuctionBidList(t *testing.T) { - abl := mempool.NewAuctionBidList() - - require.Nil(t, abl.TopBid()) - - // insert a bid which should be the head and tail - bid1 := sdk.NewCoins(sdk.NewInt64Coin("foo", 100)) - abl.Insert(mempool.NewWrappedBidTx(nil, emptyHash, bid1)) - require.Equal(t, bid1, abl.TopBid().GetBid()) - - // insert 500 random bids between [100, 1000) - var currTopBid sdk.Coins - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - for i := 0; i < 500; i++ { - randomBid := rng.Int63n(1000-100) + 100 - - bid := sdk.NewCoins(sdk.NewInt64Coin("foo", randomBid)) - abl.Insert(mempool.NewWrappedBidTx(nil, emptyHash, bid)) - - currTopBid = abl.TopBid().GetBid() - } - - // insert a bid which should be the new tail, thus the highest bid - bid2 := sdk.NewCoins(sdk.NewInt64Coin("foo", 1000)) - abl.Insert(mempool.NewWrappedBidTx(nil, emptyHash, bid2)) - require.Equal(t, bid2, abl.TopBid().GetBid()) - - // remove the top bid and ensure the new top bid is the previous top bid - abl.Remove(mempool.NewWrappedBidTx(nil, emptyHash, bid2)) - require.Equal(t, currTopBid, abl.TopBid().GetBid()) -} diff --git a/mempool/mempool.go b/mempool/mempool.go index 24d1254..5a41a9b 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -2,8 +2,7 @@ package mempool import ( "context" - "crypto/sha256" - "encoding/base64" + "errors" "fmt" sdk "github.com/cosmos/cosmos-sdk/types" @@ -20,38 +19,76 @@ type AuctionMempool struct { // globalIndex defines the index of all transactions in the mempool. It uses // the SDK's builtin PriorityNonceMempool. Once a bid is selected for top-of-block, // all subsequent transactions in the mempool will be selected from this index. - globalIndex *sdkmempool.PriorityNonceMempool + globalIndex *PriorityNonceMempool[int64] // auctionIndex defines an index of auction bids. - auctionIndex *AuctionBidList + auctionIndex *PriorityNonceMempool[string] - // txIndex defines an index of all transactions in the mempool by hash. - txIndex map[string]sdk.Tx - - // txEncoder defines the sdk.Tx encoder that allows us to encode transactions - // and construct their hashes. - txEncoder sdk.TxEncoder + // txDecoder defines the sdk.Tx decoder that allows us to decode transactions + // and construct sdk.Txs from the bundled transactions. + txDecoder sdk.TxDecoder } -func NewAuctionMempool(txEncoder sdk.TxEncoder, opts ...sdkmempool.PriorityNonceMempoolOption) *AuctionMempool { +// AuctionTxPriority returns a TxPriority over auction bid transactions only. It +// is to be used in the auction index only. +func AuctionTxPriority() TxPriority[string] { + return TxPriority[string]{ + GetTxPriority: func(goCtx context.Context, tx sdk.Tx) string { + return tx.(*WrappedBidTx).GetBid().String() + }, + Compare: func(a, b string) int { + aCoins, _ := sdk.ParseCoinsNormalized(a) + bCoins, _ := sdk.ParseCoinsNormalized(b) + + switch { + case aCoins == nil && bCoins == nil: + return 0 + + case aCoins == nil: + return -1 + + case bCoins == nil: + return 1 + + default: + switch { + case aCoins.IsAllGT(bCoins): + return 1 + + case aCoins.IsAllLT(bCoins): + return -1 + + default: + return 0 + } + } + }, + MinValue: "", + } +} + +func NewAuctionMempool(txDecoder sdk.TxDecoder, maxTx int) *AuctionMempool { return &AuctionMempool{ - globalIndex: sdkmempool.NewPriorityMempool(opts...), - auctionIndex: NewAuctionBidList(), - txIndex: make(map[string]sdk.Tx), - txEncoder: txEncoder, + globalIndex: NewPriorityMempool( + PriorityNonceMempoolConfig[int64]{ + TxPriority: NewDefaultTxPriority(), + MaxTx: maxTx, + }, + ), + auctionIndex: NewPriorityMempool( + PriorityNonceMempoolConfig[string]{ + TxPriority: AuctionTxPriority(), + MaxTx: maxTx, + }, + ), + txDecoder: txDecoder, } } +// Insert inserts a transaction into the mempool. If the transaction is a special +// auction tx (tx that contains a single MsgAuctionBid), it will also insert the +// transaction into the auction index. func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { - hash, hashStr, err := am.getTxHash(tx) - if err != nil { - return err - } - - if _, ok := am.txIndex[hashStr]; ok { - return fmt.Errorf("tx already exists: %s", hashStr) - } - if err := am.globalIndex.Insert(ctx, tx); err != nil { return fmt.Errorf("failed to insert tx into global index: %w", err) } @@ -62,90 +99,85 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { } if msg != nil { - am.auctionIndex.Insert(NewWrappedBidTx(tx, hash, msg.GetBid())) - } - - am.txIndex[hashStr] = tx - - return nil -} - -func (am *AuctionMempool) Remove(tx sdk.Tx) error { - hash, hashStr, err := am.getTxHash(tx) - if err != nil { - return err - } - - // 1. Remove the tx from the global index - if err := am.globalIndex.Remove(tx); err != nil { - return fmt.Errorf("failed to remove tx from global index: %w", err) - } - - // 2. Remove from the transaction index - delete(am.txIndex, hashStr) - - msg, err := GetMsgAuctionBidFromTx(tx) - if err != nil { - return err - } - - // 3. Remove the bid from the auction index (if applicable). In addition, we - // remove all referenced transactions from the global and transaction indices. - if msg != nil { - am.auctionIndex.Remove(NewWrappedBidTx(tx, hash, msg.GetBid())) - - for _, refTxRaw := range msg.GetTransactions() { - refHash := sha256.Sum256(refTxRaw) - refHashStr := base64.StdEncoding.EncodeToString(refHash[:]) - - // check if we have the referenced transaction first - if refTx, ok := am.txIndex[refHashStr]; ok { - if err := am.globalIndex.Remove(refTx); err != nil { - return fmt.Errorf("failed to remove bid referenced tx from global index: %w", err) - } - } - - delete(am.txIndex, refHashStr) + if err := am.auctionIndex.Insert(ctx, NewWrappedBidTx(tx, msg.GetBid())); err != nil { + removeTx(am.globalIndex, tx) + return fmt.Errorf("failed to insert tx into auction index: %w", err) } } return nil } -// SelectTopAuctionBidTx returns the top auction bid tx in the mempool if one -// exists. -func (am *AuctionMempool) SelectTopAuctionBidTx() sdk.Tx { - wBidTx := am.auctionIndex.TopBid() - if wBidTx == nil { - return nil +// Remove removes a transaction from the mempool. If the transaction is a special +// auction tx (tx that contains a single MsgAuctionBid), it will also remove all +// referenced transactions from the global mempool. +func (am *AuctionMempool) Remove(tx sdk.Tx) error { + // 1. Remove the tx from the global index + removeTx(am.globalIndex, tx) + + msg, err := GetMsgAuctionBidFromTx(tx) + if err != nil { + return err } - return wBidTx.Tx + // 2. Remove the bid from the auction index (if applicable). In addition, we + // remove all referenced transactions from the global mempool. + if msg != nil { + removeTx(am.auctionIndex, NewWrappedBidTx(tx, msg.GetBid())) + + for _, refRawTx := range msg.GetTransactions() { + refTx, err := am.txDecoder(refRawTx) + if err != nil { + return fmt.Errorf("failed to decode referenced tx: %w", err) + } + + removeTx(am.globalIndex, refTx) + } + } + + return nil +} + +// RemoveWithoutRefTx removes a transaction from the mempool without removing +// any referenced transactions. Referenced transactions only exist in special +// auction transactions (txs that only include a single MsgAuctionBid). This +// API is used to ensure that searchers are unable to remove valid transactions +// from the global mempool. +func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error { + removeTx(am.globalIndex, tx) + + msg, err := GetMsgAuctionBidFromTx(tx) + if err != nil { + return err + } + + if msg != nil { + removeTx(am.auctionIndex, NewWrappedBidTx(tx, msg.GetBid())) + } + + return nil +} + +// AuctionBidSelect returns an iterator over auction bids transactions only. +func (am *AuctionMempool) AuctionBidSelect(ctx context.Context) sdkmempool.Iterator { + return am.auctionIndex.Select(ctx, nil) } func (am *AuctionMempool) Select(ctx context.Context, txs [][]byte) sdkmempool.Iterator { return am.globalIndex.Select(ctx, txs) } -func (am *AuctionMempool) AuctionBidSelect(ctx context.Context) sdkmempool.Iterator { - // TODO: return am.auctionIndex.Select(ctx, nil) - // - // Ref: ENG-547 - panic("not implemented") +func (am *AuctionMempool) CountAuctionTx() int { + return am.auctionIndex.CountTx() } func (am *AuctionMempool) CountTx() int { return am.globalIndex.CountTx() } -func (am *AuctionMempool) getTxHash(tx sdk.Tx) ([32]byte, string, error) { - bz, err := am.txEncoder(tx) - if err != nil { - return [32]byte{}, "", fmt.Errorf("failed to encode tx: %w", err) +func removeTx(mp sdkmempool.Mempool, tx sdk.Tx) { + err := mp.Remove(tx) + if err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { + panic(fmt.Errorf("failed to remove invalid transaction from the mempool: %w", err)) } - - hash := sha256.Sum256(bz) - hashStr := base64.StdEncoding.EncodeToString(hash[:]) - - return hash, hashStr, nil } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 0842b0e..4241288 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -18,7 +18,7 @@ import ( func TestAuctionMempool(t *testing.T) { encCfg := createTestEncodingConfig() - amp := mempool.NewAuctionMempool(encCfg.TxConfig.TxEncoder()) + amp := mempool.NewAuctionMempool(encCfg.TxConfig.TxDecoder(), 0) ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger()) rng := rand.New(rand.NewSource(time.Now().Unix())) accounts := RandomAccounts(rng, 5) @@ -60,7 +60,7 @@ func TestAuctionMempool(t *testing.T) { require.NoError(t, amp.Insert(ctx.WithPriority(p), txBuilder.GetTx())) } - require.Nil(t, amp.SelectTopAuctionBidTx()) + require.Nil(t, amp.AuctionBidSelect(ctx)) // insert bid transactions var highestBid sdk.Coins @@ -109,13 +109,26 @@ func TestAuctionMempool(t *testing.T) { require.Equal(t, expectedCount, amp.CountTx()) // select the top bid and misc txs - bidTx := amp.SelectTopAuctionBidTx() + bidTx := amp.AuctionBidSelect(ctx).Tx() require.Len(t, bidTx.GetMsgs(), 1) require.Equal(t, highestBid, bidTx.GetMsgs()[0].(*auctiontypes.MsgAuctionBid).Bid) - // remove bid tx, which should also removed the referenced txs - require.NoError(t, amp.Remove(bidTx)) - require.Equal(t, expectedCount-3, amp.CountTx()) + // remove the top bid tx (without removing the referenced txs) + prevAuctionCount := amp.CountAuctionTx() + require.NoError(t, amp.RemoveWithoutRefTx(bidTx)) + require.Equal(t, expectedCount-1, amp.CountTx()) + require.Equal(t, prevAuctionCount-1, amp.CountAuctionTx()) + + // the next bid tx should be less than or equal to the previous highest bid + nextBidTx := amp.AuctionBidSelect(ctx).Tx() + require.Len(t, nextBidTx.GetMsgs(), 1) + msgAuctionBid := nextBidTx.GetMsgs()[0].(*auctiontypes.MsgAuctionBid) + require.True(t, highestBid.IsAllGTE(msgAuctionBid.Bid)) + + // remove the top bid tx (including the ref txs) + prevGlobalCount := amp.CountTx() + require.NoError(t, amp.Remove(nextBidTx)) + require.Equal(t, prevGlobalCount-1-2, amp.CountTx()) } func createMsgAuctionBid(txCfg client.TxConfig, bidder Account, bid sdk.Coins) (*auctiontypes.MsgAuctionBid, error) { diff --git a/mempool/priority_nonce.go b/mempool/priority_nonce.go new file mode 100644 index 0000000..ccbf3cd --- /dev/null +++ b/mempool/priority_nonce.go @@ -0,0 +1,477 @@ +package mempool + +import ( + "context" + "fmt" + "math" + + "github.com/huandu/skiplist" + + sdk "github.com/cosmos/cosmos-sdk/types" + sdkmempool "github.com/cosmos/cosmos-sdk/types/mempool" + "github.com/cosmos/cosmos-sdk/x/auth/signing" +) + +var ( + _ sdkmempool.Mempool = (*PriorityNonceMempool[int64])(nil) + _ sdkmempool.Iterator = (*PriorityNonceIterator[int64])(nil) +) + +type ( + // PriorityNonceMempoolConfig defines the configuration used to configure the + // PriorityNonceMempool. + PriorityNonceMempoolConfig[C comparable] struct { + // TxPriority defines the transaction priority and comparator. + TxPriority TxPriority[C] + + // OnRead is a callback to be called when a tx is read from the mempool. + OnRead func(tx sdk.Tx) + + // TxReplacement is a callback to be called when duplicated transaction nonce + // detected during mempool insert. An application can define a transaction + // replacement rule based on tx priority or certain transaction fields. + TxReplacement func(op, np C, oTx, nTx sdk.Tx) bool + + // MaxTx sets the maximum number of transactions allowed in the mempool with + // the semantics: + // - if MaxTx == 0, there is no cap on the number of transactions in the mempool + // - if MaxTx > 0, the mempool will cap the number of transactions it stores, + // and will prioritize transactions by their priority and sender-nonce + // (sequence number) when evicting transactions. + // - if MaxTx < 0, `Insert` is a no-op. + MaxTx int + } + + // PriorityNonceMempool is a mempool implementation that stores txs + // in a partially ordered set by 2 dimensions: priority, and sender-nonce + // (sequence number). Internally it uses one priority ordered skip list and one + // skip list per sender ordered by sender-nonce (sequence number). When there + // are multiple txs from the same sender, they are not always comparable by + // priority to other sender txs and must be partially ordered by both sender-nonce + // and priority. + PriorityNonceMempool[C comparable] struct { + priorityIndex *skiplist.SkipList + priorityCounts map[C]int + senderIndices map[string]*skiplist.SkipList + scores map[txMeta[C]]txMeta[C] + cfg PriorityNonceMempoolConfig[C] + } + + // PriorityNonceIterator defines an iterator that is used for mempool iteration + // on Select(). + PriorityNonceIterator[C comparable] struct { + mempool *PriorityNonceMempool[C] + priorityNode *skiplist.Element + senderCursors map[string]*skiplist.Element + sender string + nextPriority C + } + + // TxPriority defines a type that is used to retrieve and compare transaction + // priorities. Priorities must be comparable. + TxPriority[C comparable] struct { + // GetTxPriority returns the priority of the transaction. A priority must be + // comparable via Compare. + GetTxPriority func(ctx context.Context, tx sdk.Tx) C + + // CompareTxPriority compares two transaction priorities. The result should be + // 0 if a == b, -1 if a < b, and +1 if a > b. + Compare func(a, b C) int + + // MinValue defines the minimum priority value, e.g. MinInt64. This value is + // used when instantiating a new iterator and comparing weights. + MinValue C + } + + // txMeta stores transaction metadata used in indices + txMeta[C comparable] struct { + // nonce is the sender's sequence number + nonce uint64 + // priority is the transaction's priority + priority C + // sender is the transaction's sender + sender string + // weight is the transaction's weight, used as a tiebreaker for transactions + // with the same priority + weight C + // senderElement is a pointer to the transaction's element in the sender index + senderElement *skiplist.Element + } +) + +// NewDefaultTxPriority returns a TxPriority comparator using ctx.Priority as +// the defining transaction priority. +func NewDefaultTxPriority() TxPriority[int64] { + return TxPriority[int64]{ + GetTxPriority: func(goCtx context.Context, _ sdk.Tx) int64 { + return sdk.UnwrapSDKContext(goCtx).Priority() + }, + Compare: func(a, b int64) int { + return skiplist.Int64.Compare(a, b) + }, + MinValue: math.MinInt64, + } +} + +func DefaultPriorityNonceMempoolConfig() PriorityNonceMempoolConfig[int64] { + return PriorityNonceMempoolConfig[int64]{ + TxPriority: NewDefaultTxPriority(), + } +} + +// skiplistComparable is a comparator for txKeys that first compares priority, +// then weight, then sender, then nonce, uniquely identifying a transaction. +// +// Note, skiplistComparable is used as the comparator in the priority index. +func skiplistComparable[C comparable](txPriority TxPriority[C]) skiplist.Comparable { + return skiplist.LessThanFunc(func(a, b any) int { + keyA := a.(txMeta[C]) + keyB := b.(txMeta[C]) + + res := txPriority.Compare(keyA.priority, keyB.priority) + if res != 0 { + return res + } + + // Weight is used as a tiebreaker for transactions with the same priority. + // Weight is calculated in a single pass in .Select(...) and so will be 0 + // on .Insert(...). + res = txPriority.Compare(keyA.weight, keyB.weight) + if res != 0 { + return res + } + + // Because weight will be 0 on .Insert(...), we must also compare sender and + // nonce to resolve priority collisions. If we didn't then transactions with + // the same priority would overwrite each other in the priority index. + res = skiplist.String.Compare(keyA.sender, keyB.sender) + if res != 0 { + return res + } + + return skiplist.Uint64.Compare(keyA.nonce, keyB.nonce) + }) +} + +// NewPriorityMempool returns the SDK's default mempool implementation which +// returns txs in a partial order by 2 dimensions; priority, and sender-nonce. +func NewPriorityMempool[C comparable](cfg PriorityNonceMempoolConfig[C]) *PriorityNonceMempool[C] { + mp := &PriorityNonceMempool[C]{ + priorityIndex: skiplist.New(skiplistComparable(cfg.TxPriority)), + priorityCounts: make(map[C]int), + senderIndices: make(map[string]*skiplist.SkipList), + scores: make(map[txMeta[C]]txMeta[C]), + cfg: cfg, + } + + return mp +} + +// DefaultPriorityMempool returns a priorityNonceMempool with no options. +func DefaultPriorityMempool() *PriorityNonceMempool[int64] { + return NewPriorityMempool(DefaultPriorityNonceMempoolConfig()) +} + +// NextSenderTx returns the next transaction for a given sender by nonce order, +// i.e. the next valid transaction for the sender. If no such transaction exists, +// nil will be returned. +func (mp *PriorityNonceMempool[C]) NextSenderTx(sender string) sdk.Tx { + senderIndex, ok := mp.senderIndices[sender] + if !ok { + return nil + } + + cursor := senderIndex.Front() + return cursor.Value.(sdk.Tx) +} + +// Insert attempts to insert a Tx into the app-side mempool in O(log n) time, +// returning an error if unsuccessful. Sender and nonce are derived from the +// transaction's first signature. +// +// Transactions are unique by sender and nonce. Inserting a duplicate tx is an +// O(log n) no-op. +// +// Inserting a duplicate tx with a different priority overwrites the existing tx, +// changing the total order of the mempool. +func (mp *PriorityNonceMempool[C]) Insert(ctx context.Context, tx sdk.Tx) error { + if mp.cfg.MaxTx > 0 && mp.CountTx() >= mp.cfg.MaxTx { + return sdkmempool.ErrMempoolTxMaxCapacity + } else if mp.cfg.MaxTx < 0 { + return nil + } + + sigs, err := tx.(signing.SigVerifiableTx).GetSignaturesV2() + if err != nil { + return err + } + if len(sigs) == 0 { + return fmt.Errorf("tx must have at least one signer") + } + + sig := sigs[0] + sender := sdk.AccAddress(sig.PubKey.Address()).String() + priority := mp.cfg.TxPriority.GetTxPriority(ctx, tx) + nonce := sig.Sequence + key := txMeta[C]{nonce: nonce, priority: priority, sender: sender} + + senderIndex, ok := mp.senderIndices[sender] + if !ok { + senderIndex = skiplist.New(skiplist.LessThanFunc(func(a, b any) int { + return skiplist.Uint64.Compare(b.(txMeta[C]).nonce, a.(txMeta[C]).nonce) + })) + + // initialize sender index if not found + mp.senderIndices[sender] = senderIndex + } + + // Since mp.priorityIndex is scored by priority, then sender, then nonce, a + // changed priority will create a new key, so we must remove the old key and + // re-insert it to avoid having the same tx with different priorityIndex indexed + // twice in the mempool. + // + // This O(log n) remove operation is rare and only happens when a tx's priority + // changes. + sk := txMeta[C]{nonce: nonce, sender: sender} + if oldScore, txExists := mp.scores[sk]; txExists { + if mp.cfg.TxReplacement != nil && !mp.cfg.TxReplacement(oldScore.priority, priority, senderIndex.Get(key).Value.(sdk.Tx), tx) { + return fmt.Errorf( + "tx doesn't fit the replacement rule, oldPriority: %v, newPriority: %v, oldTx: %v, newTx: %v", + oldScore.priority, + priority, + senderIndex.Get(key).Value.(sdk.Tx), + tx, + ) + } + + mp.priorityIndex.Remove(txMeta[C]{ + nonce: nonce, + sender: sender, + priority: oldScore.priority, + weight: oldScore.weight, + }) + mp.priorityCounts[oldScore.priority]-- + } + + mp.priorityCounts[priority]++ + + // Since senderIndex is scored by nonce, a changed priority will overwrite the + // existing key. + key.senderElement = senderIndex.Set(key, tx) + + mp.scores[sk] = txMeta[C]{priority: priority} + mp.priorityIndex.Set(key, tx) + + return nil +} + +func (i *PriorityNonceIterator[C]) iteratePriority() sdkmempool.Iterator { + // beginning of priority iteration + if i.priorityNode == nil { + i.priorityNode = i.mempool.priorityIndex.Front() + } else { + i.priorityNode = i.priorityNode.Next() + } + + // end of priority iteration + if i.priorityNode == nil { + return nil + } + + i.sender = i.priorityNode.Key().(txMeta[C]).sender + + nextPriorityNode := i.priorityNode.Next() + if nextPriorityNode != nil { + i.nextPriority = nextPriorityNode.Key().(txMeta[C]).priority + } else { + i.nextPriority = i.mempool.cfg.TxPriority.MinValue + } + + return i.Next() +} + +func (i *PriorityNonceIterator[C]) Next() sdkmempool.Iterator { + if i.priorityNode == nil { + return nil + } + + cursor, ok := i.senderCursors[i.sender] + if !ok { + // beginning of sender iteration + cursor = i.mempool.senderIndices[i.sender].Front() + } else { + // middle of sender iteration + cursor = cursor.Next() + } + + // end of sender iteration + if cursor == nil { + return i.iteratePriority() + } + + key := cursor.Key().(txMeta[C]) + + // We've reached a transaction with a priority lower than the next highest + // priority in the pool. + if i.mempool.cfg.TxPriority.Compare(key.priority, i.nextPriority) < 0 { + return i.iteratePriority() + } else if i.mempool.cfg.TxPriority.Compare(key.priority, i.nextPriority) == 0 { + // Weight is incorporated into the priority index key only (not sender index) + // so we must fetch it here from the scores map. + weight := i.mempool.scores[txMeta[C]{nonce: key.nonce, sender: key.sender}].weight + if i.mempool.cfg.TxPriority.Compare(weight, i.priorityNode.Next().Key().(txMeta[C]).weight) < 0 { + return i.iteratePriority() + } + } + + i.senderCursors[i.sender] = cursor + return i +} + +func (i *PriorityNonceIterator[C]) Tx() sdk.Tx { + return i.senderCursors[i.sender].Value.(sdk.Tx) +} + +// Select returns a set of transactions from the mempool, ordered by priority +// and sender-nonce in O(n) time. The passed in list of transactions are ignored. +// This is a readonly operation, the mempool is not modified. +// +// The maxBytes parameter defines the maximum number of bytes of transactions to +// return. +func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) sdkmempool.Iterator { + if mp.priorityIndex.Len() == 0 { + return nil + } + + mp.reorderPriorityTies() + + iterator := &PriorityNonceIterator[C]{ + mempool: mp, + senderCursors: make(map[string]*skiplist.Element), + } + + return iterator.iteratePriority() +} + +type reorderKey[C comparable] struct { + deleteKey txMeta[C] + insertKey txMeta[C] + tx sdk.Tx +} + +func (mp *PriorityNonceMempool[C]) reorderPriorityTies() { + node := mp.priorityIndex.Front() + + var reordering []reorderKey[C] + for node != nil { + key := node.Key().(txMeta[C]) + if mp.priorityCounts[key.priority] > 1 { + newKey := key + newKey.weight = senderWeight(mp.cfg.TxPriority, key.senderElement) + reordering = append(reordering, reorderKey[C]{deleteKey: key, insertKey: newKey, tx: node.Value.(sdk.Tx)}) + } + + node = node.Next() + } + + for _, k := range reordering { + mp.priorityIndex.Remove(k.deleteKey) + delete(mp.scores, txMeta[C]{nonce: k.deleteKey.nonce, sender: k.deleteKey.sender}) + mp.priorityIndex.Set(k.insertKey, k.tx) + mp.scores[txMeta[C]{nonce: k.insertKey.nonce, sender: k.insertKey.sender}] = k.insertKey + } +} + +// senderWeight returns the weight of a given tx (t) at senderCursor. Weight is +// defined as the first (nonce-wise) same sender tx with a priority not equal to +// t. It is used to resolve priority collisions, that is when 2 or more txs from +// different senders have the same priority. +func senderWeight[C comparable](txPriority TxPriority[C], senderCursor *skiplist.Element) C { + if senderCursor == nil { + return txPriority.MinValue + } + + weight := senderCursor.Key().(txMeta[C]).priority + senderCursor = senderCursor.Next() + for senderCursor != nil { + p := senderCursor.Key().(txMeta[C]).priority + if txPriority.Compare(p, weight) != 0 { + weight = p + } + + senderCursor = senderCursor.Next() + } + + return weight +} + +// CountTx returns the number of transactions in the mempool. +func (mp *PriorityNonceMempool[C]) CountTx() int { + return mp.priorityIndex.Len() +} + +// Remove removes a transaction from the mempool in O(log n) time, returning an +// error if unsuccessful. +func (mp *PriorityNonceMempool[C]) Remove(tx sdk.Tx) error { + sigs, err := tx.(signing.SigVerifiableTx).GetSignaturesV2() + if err != nil { + return err + } + if len(sigs) == 0 { + return fmt.Errorf("attempted to remove a tx with no signatures") + } + + sig := sigs[0] + sender := sdk.AccAddress(sig.PubKey.Address()).String() + nonce := sig.Sequence + + scoreKey := txMeta[C]{nonce: nonce, sender: sender} + score, ok := mp.scores[scoreKey] + if !ok { + return sdkmempool.ErrTxNotFound + } + tk := txMeta[C]{nonce: nonce, priority: score.priority, sender: sender, weight: score.weight} + + senderTxs, ok := mp.senderIndices[sender] + if !ok { + return fmt.Errorf("sender %s not found", sender) + } + + mp.priorityIndex.Remove(tk) + senderTxs.Remove(tk) + delete(mp.scores, scoreKey) + mp.priorityCounts[score.priority]-- + + return nil +} + +func IsEmpty[C comparable](mempool sdkmempool.Mempool) error { + mp := mempool.(*PriorityNonceMempool[C]) + if mp.priorityIndex.Len() != 0 { + return fmt.Errorf("priorityIndex not empty") + } + + var countKeys []C + for k := range mp.priorityCounts { + countKeys = append(countKeys, k) + } + + for _, k := range countKeys { + if mp.priorityCounts[k] != 0 { + return fmt.Errorf("priorityCounts not zero at %v, got %v", k, mp.priorityCounts[k]) + } + } + + var senderKeys []string + for k := range mp.senderIndices { + senderKeys = append(senderKeys, k) + } + + for _, k := range senderKeys { + if mp.senderIndices[k].Len() != 0 { + return fmt.Errorf("senderIndex not empty for sender %v", k) + } + } + + return nil +} diff --git a/mempool/tx.go b/mempool/tx.go index 12341dc..24d1048 100644 --- a/mempool/tx.go +++ b/mempool/tx.go @@ -4,27 +4,25 @@ import ( "errors" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/auth/signing" auctiontypes "github.com/skip-mev/pob/x/auction/types" ) // WrappedBidTx defines a wrapper around an sdk.Tx that contains a single // MsgAuctionBid message with additional metadata. type WrappedBidTx struct { - sdk.Tx + signing.Tx - hash [32]byte - bid sdk.Coins + bid sdk.Coins } -func NewWrappedBidTx(tx sdk.Tx, hash [32]byte, bid sdk.Coins) *WrappedBidTx { +func NewWrappedBidTx(tx sdk.Tx, bid sdk.Coins) *WrappedBidTx { return &WrappedBidTx{ - Tx: tx, - hash: hash, - bid: bid, + Tx: tx.(signing.Tx), + bid: bid, } } -func (wbtx *WrappedBidTx) GetHash() [32]byte { return wbtx.hash } func (wbtx *WrappedBidTx) GetBid() sdk.Coins { return wbtx.bid } // GetMsgAuctionBidFromTx attempts to retrieve a MsgAuctionBid from an sdk.Tx if