diff --git a/abci/abci.go b/abci/abci.go index 0364236..366f577 100644 --- a/abci/abci.go +++ b/abci/abci.go @@ -66,17 +66,20 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { bidTxSize := int64(len(bidTxBz)) if bidTxSize <= req.MaxTxBytes { - bidMsg, err := mempool.GetMsgAuctionBidFromTx(tmpBidTx) + bundledTransactions, err := h.mempool.GetBundledTransactions(tmpBidTx) if err != nil { - // This should never happen, as CheckTx will ensure only valid bids - // enter the mempool, but in case it does, we need to remove the - // transaction from the mempool. + // Some transactions in the bundle may be malformatted or invalid, so + // we remove the bid transaction and try the next top bid. txsToRemove[tmpBidTx] = struct{}{} continue selectBidTxLoop } - for _, refTxRaw := range bidMsg.Transactions { - refTx, err := h.txDecoder(refTxRaw) + // store the bytes of each ref tx as sdk.Tx bytes in order to build a valid proposal + sdkTxBytes := make([][]byte, len(bundledTransactions)) + + // Ensure that the bundled transactions are valid + for index, rawRefTx := range bundledTransactions { + refTx, err := h.mempool.WrapBundleTransaction(rawRefTx) if err != nil { // Malformed bundled transaction, so we remove the bid transaction // and try the next top bid. @@ -84,12 +87,15 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { continue selectBidTxLoop } - if _, err := h.PrepareProposalVerifyTx(cacheCtx, refTx); err != nil { + txBz, err := h.PrepareProposalVerifyTx(cacheCtx, refTx) + if err != nil { // Invalid bundled transaction, so we remove the bid transaction // and try the next top bid. txsToRemove[tmpBidTx] = struct{}{} continue selectBidTxLoop } + + sdkTxBytes[index] = txBz } // At this point, both the bid transaction itself and all the bundled @@ -98,9 +104,9 @@ func (h *ProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHandler { // update the total size selected thus far. totalTxBytes += bidTxSize selectedTxs = append(selectedTxs, bidTxBz) - selectedTxs = append(selectedTxs, bidMsg.Transactions...) + selectedTxs = append(selectedTxs, sdkTxBytes...) - for _, refTxRaw := range bidMsg.Transactions { + for _, refTxRaw := range sdkTxBytes { hash := sha256.Sum256(refTxRaw) txHash := hex.EncodeToString(hash[:]) seenTxs[txHash] = struct{}{} @@ -183,24 +189,41 @@ func (h *ProposalHandler) ProcessProposalHandler() sdk.ProcessProposalHandler { return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} } - msgAuctionBid, err := mempool.GetMsgAuctionBidFromTx(tx) + isAuctionTx, err := h.mempool.IsAuctionTx(tx) if err != nil { return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} } - if msgAuctionBid != nil { + if isAuctionTx { // Only the first transaction can be an auction bid tx if index != 0 { return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} } - // The order of transactions in the block proposal must follow the order of transactions in the bid. - if len(req.Txs) < len(msgAuctionBid.Transactions)+1 { + bundledTransactions, err := h.mempool.GetBundledTransactions(tx) + if err != nil { return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} } - for i, refTxRaw := range msgAuctionBid.Transactions { - if !bytes.Equal(refTxRaw, req.Txs[i+1]) { + // The order of transactions in the block proposal must follow the order of transactions in the bid. + if len(req.Txs) < len(bundledTransactions)+1 { + return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} + } + + for i, refTxRaw := range bundledTransactions { + // Wrap and then encode the bundled transaction to ensure that the underlying + // reference transaction can be processed as an sdk.Tx. + wrappedTx, err := h.mempool.WrapBundleTransaction(refTxRaw) + if err != nil { + return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} + } + + refTxBz, err := h.txEncoder(wrappedTx) + if err != nil { + return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} + } + + if !bytes.Equal(refTxBz, req.Txs[i+1]) { return abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_REJECT} } } diff --git a/abci/abci_test.go b/abci/abci_test.go index 67d630e..0a7cca1 100644 --- a/abci/abci_test.go +++ b/abci/abci_test.go @@ -69,7 +69,8 @@ func (suite *ABCITestSuite) SetupTest() { suite.ctx = testCtx.Ctx // Mempool set up - suite.mempool = mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), suite.encodingConfig.TxConfig.TxEncoder(), 0) + config := mempool.NewDefaultConfig(suite.encodingConfig.TxConfig.TxDecoder()) + suite.mempool = mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), suite.encodingConfig.TxConfig.TxEncoder(), 0, config) suite.txs = make(map[string]struct{}) suite.auctionBidAmount = sdk.NewCoin("foo", sdk.NewInt(1000000000)) suite.minBidIncrement = sdk.NewCoin("foo", sdk.NewInt(1000)) @@ -522,10 +523,10 @@ func (suite *ABCITestSuite) TestPrepareProposal() { auctionTx, err := suite.encodingConfig.TxConfig.TxDecoder()(res.Txs[0]) suite.Require().NoError(err) - msgAuctionBid, err := mempool.GetMsgAuctionBidFromTx(auctionTx) + bidInfo, err := suite.mempool.GetAuctionBidInfo(auctionTx) suite.Require().NoError(err) - for index, tx := range msgAuctionBid.GetTransactions() { + for index, tx := range bidInfo.Transactions { suite.Require().Equal(tx, res.Txs[index+1]) } } diff --git a/mempool/config.go b/mempool/config.go new file mode 100644 index 0000000..1d918b9 --- /dev/null +++ b/mempool/config.go @@ -0,0 +1,135 @@ +package mempool + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type ( + // AuctionBidInfo defines the information about a bid to the auction house. + AuctionBidInfo struct { + Bidder sdk.AccAddress + Bid sdk.Coin + Transactions [][]byte + } + + // Config defines the configuration for processing auction transactions. It is + // a wrapper around all of the functionality that each application chain must implement + // in order for auction processing to work. + Config interface { + // IsAuctionTx defines a function that returns true iff a transaction is an + // auction bid transaction. + IsAuctionTx(tx sdk.Tx) (bool, error) + + // GetTransactionSigners defines a function that returns the signers of a + // bundle transaction i.e. transaction that was included in the auction transaction's bundle. + GetTransactionSigners(tx []byte) (map[string]struct{}, error) + + // WrapBundleTransaction defines a function that wraps a bundle transaction into a sdk.Tx. + WrapBundleTransaction(tx []byte) (sdk.Tx, error) + + // GetBidder defines a function that returns the bidder of an auction transaction transaction. + GetBidder(tx sdk.Tx) (sdk.AccAddress, error) + + // GetBid defines a function that returns the bid of an auction transaction. + GetBid(tx sdk.Tx) (sdk.Coin, error) + + // GetBundledTransactions defines a function that returns the bundled transactions + // that the user wants to execute at the top of the block given an auction transaction. + GetBundledTransactions(tx sdk.Tx) ([][]byte, error) + } + + // DefaultConfig defines a default configuration for processing auction transactions. + DefaultConfig struct { + txDecoder sdk.TxDecoder + } +) + +var _ Config = (*DefaultConfig)(nil) + +// NewDefaultConfig returns a default transaction configuration. +func NewDefaultConfig(txDecoder sdk.TxDecoder) Config { + return &DefaultConfig{ + txDecoder: txDecoder, + } +} + +// NewDefaultIsAuctionTx defines a default function that returns true iff a transaction +// is an auction bid transaction. In the default case, the transaction must contain a single +// MsgAuctionBid message. +func (config *DefaultConfig) IsAuctionTx(tx sdk.Tx) (bool, error) { + msg, err := GetMsgAuctionBidFromTx(tx) + if err != nil { + return false, err + } + + return msg != nil, nil +} + +// GetTransactionSigners defines a default function that returns the signers +// of a transaction. In the default case, each bundle transaction will be an sdk.Tx and the +// signers are the signers of each sdk.Msg in the transaction. +func (config *DefaultConfig) GetTransactionSigners(tx []byte) (map[string]struct{}, error) { + sdkTx, err := config.txDecoder(tx) + if err != nil { + return nil, err + } + + signers := make(map[string]struct{}) + for _, msg := range sdkTx.GetMsgs() { + for _, signer := range msg.GetSigners() { + signers[signer.String()] = struct{}{} + } + } + + return signers, nil +} + +// WrapBundleTransaction defines a default function that wraps a transaction +// that is included in the bundle into a sdk.Tx. In the default case, the transaction +// that is included in the bundle will be the raw bytes of an sdk.Tx so we can just +// decode it. +func (config *DefaultConfig) WrapBundleTransaction(tx []byte) (sdk.Tx, error) { + return config.txDecoder(tx) +} + +// GetBidder defines a default function that returns the bidder of an auction transaction. +// In the default case, the bidder is the address defined in MsgAuctionBid. +func (config *DefaultConfig) GetBidder(tx sdk.Tx) (sdk.AccAddress, error) { + msg, err := GetMsgAuctionBidFromTx(tx) + if err != nil { + return nil, err + } + + bidder, err := sdk.AccAddressFromBech32(msg.Bidder) + if err != nil { + return nil, fmt.Errorf("invalid bidder address (%s): %w", msg.Bidder, err) + } + + return bidder, nil +} + +// GetBid defines a default function that returns the bid of an auction transaction. +// In the default case, the bid is the amount defined in MsgAuctionBid. +func (config *DefaultConfig) GetBid(tx sdk.Tx) (sdk.Coin, error) { + msg, err := GetMsgAuctionBidFromTx(tx) + if err != nil { + return sdk.Coin{}, err + } + + return msg.Bid, nil +} + +// GetBundledTransactions defines a default function that returns the bundled +// transactions that the user wants to execute at the top of the block. In the default case, +// the bundled transactions will be the raw bytes of sdk.Tx's that are included in the +// MsgAuctionBid. +func (config *DefaultConfig) GetBundledTransactions(tx sdk.Tx) ([][]byte, error) { + msg, err := GetMsgAuctionBidFromTx(tx) + if err != nil { + return nil, err + } + + return msg.Transactions, nil +} diff --git a/mempool/mempool.go b/mempool/mempool.go index 9250eb3..0246cc0 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -21,10 +21,10 @@ type AuctionMempool struct { // globalIndex defines the index of all transactions in the mempool. It uses // the SDK's builtin PriorityNonceMempool. Once a bid is selected for top-of-block, // all subsequent transactions in the mempool will be selected from this index. - globalIndex *PriorityNonceMempool[int64] + globalIndex sdkmempool.Mempool // auctionIndex defines an index of auction bids. - auctionIndex *PriorityNonceMempool[string] + auctionIndex sdkmempool.Mempool // txDecoder defines the sdk.Tx decoder that allows us to decode transactions // and construct sdk.Txs from the bundled transactions. @@ -37,6 +37,9 @@ type AuctionMempool struct { // txIndex is a map of all transactions in the mempool. It is used // to quickly check if a transaction is already in the mempool. txIndex map[string]struct{} + + // config defines the transaction configuration for processing auction transactions. + config Config } // AuctionTxPriority returns a TxPriority over auction bid transactions only. It @@ -82,7 +85,7 @@ func AuctionTxPriority() TxPriority[string] { } } -func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx int) *AuctionMempool { +func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx int, config Config) *AuctionMempool { return &AuctionMempool{ globalIndex: NewPriorityMempool( PriorityNonceMempoolConfig[int64]{ @@ -99,6 +102,7 @@ func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx i txDecoder: txDecoder, txEncoder: txEncoder, txIndex: make(map[string]struct{}), + config: config, } } @@ -106,18 +110,18 @@ func NewAuctionMempool(txDecoder sdk.TxDecoder, txEncoder sdk.TxEncoder, maxTx i // auction tx (tx that contains a single MsgAuctionBid), it will also insert the // transaction into the auction index. func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { - msg, err := GetMsgAuctionBidFromTx(tx) + isAuctionTx, err := am.IsAuctionTx(tx) if err != nil { return err } // Insert the transactions into the appropriate index. switch { - case msg == nil: + case !isAuctionTx: if err := am.globalIndex.Insert(ctx, tx); err != nil { return fmt.Errorf("failed to insert tx into global index: %w", err) } - case msg != nil: + case isAuctionTx: if err := am.auctionIndex.Insert(ctx, tx); err != nil { return fmt.Errorf("failed to insert tx into auction index: %w", err) } @@ -137,26 +141,31 @@ func (am *AuctionMempool) Insert(ctx context.Context, tx sdk.Tx) error { // auction tx (tx that contains a single MsgAuctionBid), it will also remove all // referenced transactions from the global mempool. func (am *AuctionMempool) Remove(tx sdk.Tx) error { - msg, err := GetMsgAuctionBidFromTx(tx) + isAuctionTx, err := am.IsAuctionTx(tx) if err != nil { return err } // Remove the transactions from the appropriate index. switch { - case msg == nil: + case !isAuctionTx: am.removeTx(am.globalIndex, tx) - case msg != nil: + case isAuctionTx: am.removeTx(am.auctionIndex, tx) // Remove all referenced transactions from the global mempool. - for _, refRawTx := range msg.GetTransactions() { - refTx, err := am.txDecoder(refRawTx) + bundleTxs, err := am.GetBundledTransactions(tx) + if err != nil { + return err + } + + for _, refTx := range bundleTxs { + wrappedRefTx, err := am.WrapBundleTransaction(refTx) if err != nil { - return fmt.Errorf("failed to decode referenced tx: %w", err) + return err } - am.removeTx(am.globalIndex, refTx) + am.removeTx(am.globalIndex, wrappedRefTx) } } @@ -169,12 +178,12 @@ func (am *AuctionMempool) Remove(tx sdk.Tx) error { // API is used to ensure that searchers are unable to remove valid transactions // from the global mempool. func (am *AuctionMempool) RemoveWithoutRefTx(tx sdk.Tx) error { - msg, err := GetMsgAuctionBidFromTx(tx) + isAuctionTx, err := am.IsAuctionTx(tx) if err != nil { return err } - if msg != nil { + if isAuctionTx { am.removeTx(am.auctionIndex, tx) } @@ -219,19 +228,6 @@ func (am *AuctionMempool) Contains(tx sdk.Tx) (bool, error) { return ok, nil } -// getTxHashStr returns the transaction hash string for a given transaction. -func (am *AuctionMempool) getTxHashStr(tx sdk.Tx) (string, error) { - txBz, err := am.txEncoder(tx) - if err != nil { - return "", fmt.Errorf("failed to encode transaction: %w", err) - } - - txHash := sha256.Sum256(txBz) - txHashStr := hex.EncodeToString(txHash[:]) - - return txHashStr, nil -} - func (am *AuctionMempool) removeTx(mp sdkmempool.Mempool, tx sdk.Tx) { err := mp.Remove(tx) if err != nil && !errors.Is(err, sdkmempool.ErrTxNotFound) { @@ -245,3 +241,16 @@ func (am *AuctionMempool) removeTx(mp sdkmempool.Mempool, tx sdk.Tx) { delete(am.txIndex, txHashStr) } + +// getTxHashStr returns the transaction hash string for a given transaction. +func (am *AuctionMempool) getTxHashStr(tx sdk.Tx) (string, error) { + txBz, err := am.txEncoder(tx) + if err != nil { + return "", fmt.Errorf("failed to encode transaction: %w", err) + } + + txHash := sha256.Sum256(txBz) + txHashStr := hex.EncodeToString(txHash[:]) + + return txHashStr, nil +} diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index a2de753..cd9c10c 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -32,7 +32,8 @@ func TestMempoolTestSuite(t *testing.T) { func (suite *IntegrationTestSuite) SetupTest() { // Mempool setup suite.encCfg = testutils.CreateTestEncodingConfig() - suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), suite.encCfg.TxConfig.TxEncoder(), 0) + config := mempool.NewDefaultConfig(suite.encCfg.TxConfig.TxDecoder()) + suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), suite.encCfg.TxConfig.TxEncoder(), 0, config) suite.ctx = sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger()) // Init accounts diff --git a/mempool/tx.go b/mempool/tx.go new file mode 100644 index 0000000..cae6d26 --- /dev/null +++ b/mempool/tx.go @@ -0,0 +1,90 @@ +package mempool + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// IsAuctionTx returns true if the transaction is a transaction that is attempting to +// bid to the auction. +func (am *AuctionMempool) IsAuctionTx(tx sdk.Tx) (bool, error) { + return am.config.IsAuctionTx(tx) +} + +// GetTransactionSigners returns the signers of the bundle transaction. +func (am *AuctionMempool) GetTransactionSigners(tx []byte) (map[string]struct{}, error) { + return am.config.GetTransactionSigners(tx) +} + +// WrapBundleTransaction wraps a bundle transaction into sdk.Tx transaction. +func (am *AuctionMempool) WrapBundleTransaction(tx []byte) (sdk.Tx, error) { + return am.config.WrapBundleTransaction(tx) +} + +// GetAuctionBidInfo returns the bid info from an auction transaction. +func (am *AuctionMempool) GetAuctionBidInfo(tx sdk.Tx) (AuctionBidInfo, error) { + bidder, err := am.GetBidder(tx) + if err != nil { + return AuctionBidInfo{}, err + } + + bid, err := am.GetBid(tx) + if err != nil { + return AuctionBidInfo{}, err + } + + transactions, err := am.GetBundledTransactions(tx) + if err != nil { + return AuctionBidInfo{}, err + } + + return AuctionBidInfo{ + Bidder: bidder, + Bid: bid, + Transactions: transactions, + }, nil +} + +// GetBidder returns the bidder from an auction transaction. +func (am *AuctionMempool) GetBidder(tx sdk.Tx) (sdk.AccAddress, error) { + if isAuctionTx, err := am.IsAuctionTx(tx); err != nil || !isAuctionTx { + return nil, fmt.Errorf("transaction is not an auction transaction") + } + + return am.config.GetBidder(tx) +} + +// GetBid returns the bid from an auction transaction. +func (am *AuctionMempool) GetBid(tx sdk.Tx) (sdk.Coin, error) { + if isAuctionTx, err := am.IsAuctionTx(tx); err != nil || !isAuctionTx { + return sdk.Coin{}, fmt.Errorf("transaction is not an auction transaction") + } + + return am.config.GetBid(tx) +} + +// GetBundledTransactions returns the transactions that are bundled in an auction transaction. +func (am *AuctionMempool) GetBundledTransactions(tx sdk.Tx) ([][]byte, error) { + if isAuctionTx, err := am.IsAuctionTx(tx); err != nil || !isAuctionTx { + return nil, fmt.Errorf("transaction is not an auction transaction") + } + + return am.config.GetBundledTransactions(tx) +} + +// GetBundleSigners returns all of the signers for each transaction in the bundle. +func (am *AuctionMempool) GetBundleSigners(txs [][]byte) ([]map[string]struct{}, error) { + signers := make([]map[string]struct{}, len(txs)) + + for index, tx := range txs { + txSigners, err := am.GetTransactionSigners(tx) + if err != nil { + return nil, err + } + + signers[index] = txSigners + } + + return signers, nil +} diff --git a/x/builder/ante/ante.go b/x/builder/ante/ante.go index 9058f01..5f9aabc 100644 --- a/x/builder/ante/ante.go +++ b/x/builder/ante/ante.go @@ -51,41 +51,25 @@ func (ad BuilderDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, } } - auctionMsg, err := mempool.GetMsgAuctionBidFromTx(tx) + isAuctionTx, err := ad.mempool.IsAuctionTx(tx) if err != nil { return ctx, err } // Validate the auction bid if one exists. - if auctionMsg != nil { - auctionTx, ok := tx.(TxWithTimeoutHeight) - if !ok { - return ctx, fmt.Errorf("transaction does not implement TxWithTimeoutHeight") + if isAuctionTx { + // Auction transactions must have a timeout set to a valid block height. + if err := ad.HasValidTimeout(ctx, tx); err != nil { + return ctx, err } - timeout := auctionTx.GetTimeoutHeight() - if timeout == 0 { - return ctx, fmt.Errorf("timeout height cannot be zero") - } - - bidder, err := sdk.AccAddressFromBech32(auctionMsg.Bidder) + bidInfo, err := ad.mempool.GetAuctionBidInfo(tx) if err != nil { - return ctx, errors.Wrapf(err, "invalid bidder address (%s)", auctionMsg.Bidder) + return ctx, err } - transactions := make([]sdk.Tx, len(auctionMsg.Transactions)) - for i, tx := range auctionMsg.Transactions { - decodedTx, err := ad.txDecoder(tx) - if err != nil { - return ctx, errors.Wrapf(err, "failed to decode transaction (%s)", tx) - } - - transactions[i] = decodedTx - } - - topBid := sdk.Coin{} - // If the current transaction is the highest bidding transaction, then the highest bid is empty. + topBid := sdk.Coin{} isTopBidTx, err := ad.IsTopBidTx(ctx, tx) if err != nil { return ctx, errors.Wrap(err, "failed to check if current transaction is highest bidding transaction") @@ -99,7 +83,13 @@ func (ad BuilderDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, } } - if err := ad.builderKeeper.ValidateAuctionMsg(ctx, bidder, auctionMsg.Bid, topBid, transactions); err != nil { + // Extract signers from bundle for verification. + signers, err := ad.mempool.GetBundleSigners(bidInfo.Transactions) + if err != nil { + return ctx, errors.Wrap(err, "failed to get bundle signers") + } + + if err := ad.builderKeeper.ValidateBidInfo(ctx, topBid, bidInfo, signers); err != nil { return ctx, errors.Wrap(err, "failed to validate auction bid") } } @@ -114,12 +104,12 @@ func (ad BuilderDecorator) GetTopAuctionBid(ctx sdk.Context) (sdk.Coin, error) { return sdk.Coin{}, nil } - msgAuctionBid, err := mempool.GetMsgAuctionBidFromTx(auctionTx) + bid, err := ad.mempool.GetBid(auctionTx) if err != nil { return sdk.Coin{}, err } - return msgAuctionBid.Bid, nil + return bid, nil } // IsTopBidTx returns true if the transaction inputted is the highest bidding auction transaction in the mempool. @@ -141,3 +131,22 @@ func (ad BuilderDecorator) IsTopBidTx(ctx sdk.Context, tx sdk.Tx) (bool, error) return bytes.Equal(topBidBz, currentTxBz), nil } + +// HasValidTimeout returns true if the transaction has a valid timeout height. +func (ad BuilderDecorator) HasValidTimeout(ctx sdk.Context, tx sdk.Tx) error { + auctionTx, ok := tx.(TxWithTimeoutHeight) + if !ok { + return fmt.Errorf("transaction does not implement TxWithTimeoutHeight") + } + + timeout := auctionTx.GetTimeoutHeight() + if timeout == 0 { + return fmt.Errorf("timeout height cannot be zero") + } + + if timeout < uint64(ctx.BlockHeight()) { + return fmt.Errorf("timeout height cannot be less than the current block height") + } + + return nil +} diff --git a/x/builder/ante/ante_test.go b/x/builder/ante/ante_test.go index 82dc102..134984c 100644 --- a/x/builder/ante/ante_test.go +++ b/x/builder/ante/ante_test.go @@ -244,7 +244,8 @@ func (suite *AnteTestSuite) TestAnteHandler() { suite.Require().NoError(err) // Insert the top bid into the mempool - mempool := mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), suite.encodingConfig.TxConfig.TxEncoder(), 0) + config := mempool.NewDefaultConfig(suite.encodingConfig.TxConfig.TxDecoder()) + mempool := mempool.NewAuctionMempool(suite.encodingConfig.TxConfig.TxDecoder(), suite.encodingConfig.TxConfig.TxEncoder(), 0, config) if insertTopBid { topAuctionTx, err := testutils.CreateAuctionTxWithSigners(suite.encodingConfig.TxConfig, topBidder, topBid, 0, timeout, []testutils.Account{}) suite.Require().NoError(err) diff --git a/x/builder/keeper/auction.go b/x/builder/keeper/auction.go index 0d9612d..ae8c4a9 100644 --- a/x/builder/keeper/auction.go +++ b/x/builder/keeper/auction.go @@ -4,22 +4,23 @@ import ( "fmt" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/skip-mev/pob/mempool" ) -// ValidateAuctionMsg validates that the MsgAuctionBid can be included in the auction. -func (k Keeper) ValidateAuctionMsg(ctx sdk.Context, bidder sdk.AccAddress, bid, highestBid sdk.Coin, transactions []sdk.Tx) error { +// ValidateBidInfo validates that the bid can be included in the auction. +func (k Keeper) ValidateBidInfo(ctx sdk.Context, highestBid sdk.Coin, bidInfo mempool.AuctionBidInfo, signers []map[string]struct{}) error { // Validate the bundle size. maxBundleSize, err := k.GetMaxBundleSize(ctx) if err != nil { return err } - if uint32(len(transactions)) > maxBundleSize { - return fmt.Errorf("bundle size (%d) exceeds max bundle size (%d)", len(transactions), maxBundleSize) + if uint32(len(bidInfo.Transactions)) > maxBundleSize { + return fmt.Errorf("bundle size (%d) exceeds max bundle size (%d)", len(bidInfo.Transactions), maxBundleSize) } // Validate the bid amount. - if err := k.ValidateAuctionBid(ctx, bidder, bid, highestBid); err != nil { + if err := k.ValidateAuctionBid(ctx, bidInfo.Bidder, bidInfo.Bid, highestBid); err != nil { return err } @@ -30,7 +31,7 @@ func (k Keeper) ValidateAuctionMsg(ctx sdk.Context, bidder sdk.AccAddress, bid, } if protectionEnabled { - if err := k.ValidateAuctionBundle(bidder, transactions); err != nil { + if err := k.ValidateAuctionBundle(bidInfo.Bidder, signers); err != nil { return err } } @@ -100,27 +101,19 @@ func (k Keeper) ValidateAuctionBid(ctx sdk.Context, bidder sdk.AccAddress, bid, // 2. valid: [tx1, tx2, tx3, tx4] where tx1 - tx4 are signed by the bidder. // 3. invalid: [tx1, tx2, tx3] where tx1 and tx3 are signed by the bidder and tx2 is signed by some other signer. (possible sandwich attack) // 4. invalid: [tx1, tx2, tx3] where tx1 is signed by the bidder, and tx2 - tx3 are signed by some other signer. (possible front-running attack) -func (k Keeper) ValidateAuctionBundle(bidder sdk.AccAddress, transactions []sdk.Tx) error { - if len(transactions) <= 1 { +func (k Keeper) ValidateAuctionBundle(bidder sdk.AccAddress, bundleSigners []map[string]struct{}) error { + if len(bundleSigners) <= 1 { return nil } // prevSigners is used to track whether the signers of the current transaction overlap. - prevSigners, err := k.getTxSigners(transactions[0]) - if err != nil { - return err - } - seenBidder := prevSigners[bidder.String()] + prevSigners := bundleSigners[0] + _, seenBidder := prevSigners[bidder.String()] // Check that all subsequent transactions are signed by either // 1. the same party as the first transaction // 2. the same party for some arbitrary number of txs and then are all remaining txs are signed by the bidder. - for _, refTx := range transactions[1:] { - txSigners, err := k.getTxSigners(refTx) - if err != nil { - return err - } - + for _, txSigners := range bundleSigners[1:] { // Filter the signers to only those that signed the current transaction. filterSigners(prevSigners, txSigners) @@ -132,7 +125,7 @@ func (k Keeper) ValidateAuctionBundle(bidder sdk.AccAddress, transactions []sdk. } seenBidder = true - prevSigners = map[string]bool{bidder.String(): true} + prevSigners = map[string]struct{}{bidder.String(): {}} filterSigners(prevSigners, txSigners) if len(prevSigners) == 0 { @@ -144,22 +137,8 @@ func (k Keeper) ValidateAuctionBundle(bidder sdk.AccAddress, transactions []sdk. return nil } -// getTxSigners returns the signers of a transaction. -func (k Keeper) getTxSigners(tx sdk.Tx) (map[string]bool, error) { - signers := make(map[string]bool, 0) - for _, msg := range tx.GetMsgs() { - for _, signer := range msg.GetSigners() { - // TODO: check for multi-sig accounts - // https://github.com/skip-mev/pob/issues/14 - signers[signer.String()] = true - } - } - - return signers, nil -} - // filterSigners removes any signers from the currentSigners map that are not in the txSigners map. -func filterSigners(currentSigners, txSigners map[string]bool) { +func filterSigners(currentSigners, txSigners map[string]struct{}) { for signer := range currentSigners { if _, ok := txSigners[signer]; !ok { delete(currentSigners, signer) diff --git a/x/builder/keeper/auction_test.go b/x/builder/keeper/auction_test.go index a14821f..ddea532 100644 --- a/x/builder/keeper/auction_test.go +++ b/x/builder/keeper/auction_test.go @@ -5,6 +5,7 @@ import ( "time" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/skip-mev/pob/mempool" testutils "github.com/skip-mev/pob/testutils" "github.com/skip-mev/pob/x/builder/keeper" buildertypes "github.com/skip-mev/pob/x/builder/types" @@ -181,14 +182,26 @@ func (suite *KeeperTestSuite) TestValidateAuctionMsg() { suite.builderKeeper.SetParams(suite.ctx, params) // Create the bundle of transactions ordered by accounts - bundle := make([]sdk.Tx, 0) + bundle := make([][]byte, 0) for _, acc := range accounts { tx, err := testutils.CreateRandomTx(suite.encCfg.TxConfig, acc, 0, 1, 100) suite.Require().NoError(err) - bundle = append(bundle, tx) + + txBz, err := suite.encCfg.TxConfig.TxEncoder()(tx) + suite.Require().NoError(err) + bundle = append(bundle, txBz) } - err := suite.builderKeeper.ValidateAuctionMsg(suite.ctx, bidder.Address, bid, highestBid, bundle) + bidInfo := mempool.AuctionBidInfo{ + Bidder: bidder.Address, + Bid: bid, + Transactions: bundle, + } + + signers, err := suite.mempool.GetBundleSigners(bundle) + suite.Require().NoError(err) + + err = suite.builderKeeper.ValidateBidInfo(suite.ctx, highestBid, bidInfo, signers) if tc.pass { suite.Require().NoError(err) } else { @@ -290,16 +303,22 @@ func (suite *KeeperTestSuite) TestValidateBundle() { tc.malleate() // Create the bundle of transactions ordered by accounts - bundle := make([]sdk.Tx, 0) + bundle := make([][]byte, 0) for _, acc := range accounts { // Create a random tx tx, err := testutils.CreateRandomTx(suite.encCfg.TxConfig, acc, 0, 1, 1000) suite.Require().NoError(err) - bundle = append(bundle, tx) + + txBz, err := suite.encCfg.TxConfig.TxEncoder()(tx) + suite.Require().NoError(err) + bundle = append(bundle, txBz) } + signers, err := suite.mempool.GetBundleSigners(bundle) + suite.Require().NoError(err) + // Validate the bundle - err := suite.builderKeeper.ValidateAuctionBundle(bidder.Address, bundle) + err = suite.builderKeeper.ValidateAuctionBundle(bidder.Address, signers) if tc.pass { suite.Require().NoError(err) } else { diff --git a/x/builder/keeper/keeper_test.go b/x/builder/keeper/keeper_test.go index 2dcb7d5..1a0cfbf 100644 --- a/x/builder/keeper/keeper_test.go +++ b/x/builder/keeper/keeper_test.go @@ -64,6 +64,7 @@ func (suite *KeeperTestSuite) SetupTest() { err := suite.builderKeeper.SetParams(suite.ctx, types.DefaultParams()) suite.Require().NoError(err) - suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), suite.encCfg.TxConfig.TxEncoder(), 0) + config := mempool.NewDefaultConfig(suite.encCfg.TxConfig.TxDecoder()) + suite.mempool = mempool.NewAuctionMempool(suite.encCfg.TxConfig.TxDecoder(), suite.encCfg.TxConfig.TxEncoder(), 0, config) suite.msgServer = keeper.NewMsgServerImpl(suite.builderKeeper) }