diff --git a/block/base/mempool.go b/block/base/mempool.go index 17cf56d..331fda5 100644 --- a/block/base/mempool.go +++ b/block/base/mempool.go @@ -42,49 +42,6 @@ type ( } ) -// DefaultTxPriority returns a default implementation of the TxPriority. It prioritizes -// transactions by their fee. -func DefaultTxPriority() TxPriority[string] { - return TxPriority[string]{ - GetTxPriority: func(goCtx context.Context, tx sdk.Tx) string { - feeTx, ok := tx.(sdk.FeeTx) - if !ok { - return "" - } - - return feeTx.GetFee().String() - }, - Compare: func(a, b string) int { - aCoins, _ := sdk.ParseCoinsNormalized(a) - bCoins, _ := sdk.ParseCoinsNormalized(b) - - switch { - case aCoins == nil && bCoins == nil: - return 0 - - case aCoins == nil: - return -1 - - case bCoins == nil: - return 1 - - default: - switch { - case aCoins.IsAllGT(bCoins): - return 1 - - case aCoins.IsAllLT(bCoins): - return -1 - - default: - return 0 - } - } - }, - MinValue: "", - } -} - // NewMempool returns a new Mempool. func NewMempool[C comparable](txPriority TxPriority[C], txEncoder sdk.TxEncoder, extractor signer_extraction.Adapter, maxTx int) *Mempool[C] { return &Mempool[C]{ diff --git a/block/base/tx_priority.go b/block/base/tx_priority.go new file mode 100644 index 0000000..570dcf3 --- /dev/null +++ b/block/base/tx_priority.go @@ -0,0 +1,175 @@ +package base + +import ( + "context" + "fmt" + "strconv" + "strings" + + "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type Coins map[string]math.Int + +// DefaultTxPriority returns a default implementation of the TxPriority. It prioritizes +// transactions by their fee. +func DefaultTxPriority() TxPriority[string] { + return TxPriority[string]{ + GetTxPriority: func(goCtx context.Context, tx sdk.Tx) string { + feeTx, ok := tx.(sdk.FeeTx) + if !ok { + return "" + } + + return coinsToString(feeTx.GetFee()) + }, + Compare: func(a, b string) int { + aCoins, _ := coinsFromString(a) + bCoins, _ := coinsFromString(b) + + switch { + case aCoins.Greater(bCoins): + return 1 + + case bCoins.Greater(aCoins): + return -1 + + default: + return 0 + } + }, + MinValue: "", + } +} + +func coinsToString(coins sdk.Coins) string { + // sort the coins by denomination + coins.Sort() + + // to avoid dealing with regex, etc. we use a , to separate denominations from amounts + // e.g. 10000,stake,10000,atom + coinString := "" + for i, coin := range coins { + coinString += coin.Amount.String() + "," + coin.Denom + if i != len(coins)-1 { + coinString += "," + } + } + + return coinString +} + +// coinsFromString converts a string of coins to a sdk.Coins object. +func coinsFromString(coinsString string) (Coins, error) { + // if its empty string (zero value), we return nil + if coinsString == "" { + return nil, nil + } + + // split the string by commas + coinStrings := strings.Split(coinsString, ",") + + // if the length is odd, then the given string is invalid + if len(coinStrings)%2 != 0 { + return nil, fmt.Errorf("invalid coins string: %s", coinsString) + } + + coins := make(Coins, len(coinsString)/2) + for i := 0; i < len(coinStrings); i += 2 { + // split the string by pipe + amount, ok := intFromString(coinStrings[i]) + if !ok { + return nil, fmt.Errorf("invalid amount: %s, denom: %s", coinStrings[i], coinStrings[i+1]) + } + + coins[coinStrings[i+1]] = amount + } + + return coins, nil +} + +func intFromString(str string) (math.Int, bool) { + // first attempt to get int64 from the string + int64Val, err := strconv.ParseInt(str, 10, 64) + if err == nil { + return math.NewInt(int64Val), true + } + + // if we can't get an int64, then get raw math.Int + return math.NewIntFromString(str) +} + +// Greater returns true if lhs is strictly greater than rhs, and false otherwise. Notice, lhs / rhs must be comparable, +// specifically, they must have the exact same denoms, otherwise, they aren't comparable. +func (lhs Coins) Greater(rhs Coins) bool { + // if a or b is nil, then return whether a is non-nil + if lhs == nil || rhs == nil { + return lhs != nil + } + + // for each of a's denoms, check if b has the same denom + if len(lhs) != len(rhs) { + return false + } + + // for each of a's denoms, check if a is greater + for denom, aAmount := range lhs { + // b does not have the corresponding denom, a is not greater + bAmount, ok := rhs[denom] + if !ok { + return false + } + + // a is not greater than b + if !aAmount.GT(bAmount) { + return false + } + } + + return true +} + +// DeprecatedTxPriority serves the same purpose as DefaultTxPriority, however, it is significantly slower- on the order of +// 6-10x slower. +func DeprecatedTxPriority() TxPriority[string] { + return TxPriority[string]{ + GetTxPriority: func(goCtx context.Context, tx sdk.Tx) string { + feeTx, ok := tx.(sdk.FeeTx) + if !ok { + return "" + } + + return feeTx.GetFee().String() + }, + Compare: func(a, b string) int { + aCoins, _ := sdk.ParseCoinsNormalized(a) + bCoins, _ := sdk.ParseCoinsNormalized(b) + + switch { + case aCoins == nil && bCoins == nil: + return 0 + + case aCoins == nil: + return -1 + + case bCoins == nil: + return 1 + + default: + switch { + case aCoins.IsAllGT(bCoins): + return 1 + + case aCoins.IsAllLT(bCoins): + return -1 + + default: + return 0 + } + } + }, + MinValue: "", + } +} diff --git a/block/base/tx_priority_test.go b/block/base/tx_priority_test.go new file mode 100644 index 0000000..6506828 --- /dev/null +++ b/block/base/tx_priority_test.go @@ -0,0 +1,223 @@ +package base_test + +import ( + "fmt" + "math/rand" + "testing" + + "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + "github.com/skip-mev/block-sdk/block/base" + "github.com/skip-mev/block-sdk/testutils" +) + +const maxUint64 = "18446744073709551616" // value is 2^64 + +func TestTxPriority(t *testing.T) { + acct := testutils.RandomAccounts(rand.New(rand.NewSource(1)), 1) + txc := testutils.CreateTestEncodingConfig().TxConfig + + type testCase struct { + name string + priority base.TxPriority[string] + } + + testCases := []testCase{ + { + "DeprecatedTxPriority", + base.DeprecatedTxPriority(), + }, + { + "DefaultTxPriority", + base.DefaultTxPriority(), + }, + } + t.Run("test getting a tx priority: DefaultTxPriority", func(t *testing.T) { + priority := base.DefaultTxPriority() + + tx, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1))) + require.NoError(t, err) + + require.Equal(t, "1,stake", priority.GetTxPriority(nil, tx)) + }) + + t.Run("test with amt that is not uint64", func(t *testing.T) { + priority := base.DefaultTxPriority() + + amt, ok := math.NewIntFromString(maxUint64) + require.True(t, ok) + tx, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", amt)) + require.NoError(t, err) + + require.Equal(t, maxUint64+",stake", priority.GetTxPriority(nil, tx)) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1))) + require.NoError(t, err) + + require.Equal(t, 1, priority.Compare(priority.GetTxPriority(nil, tx), priority.GetTxPriority(nil, tx2))) + }) + + t.Run("test invalid priorities", func(t *testing.T) { + priority := base.DefaultTxPriority() + + invalidAmount := "a,b" + invalidCoins := "1,stake,2" + + require.Equal(t, 0, priority.Compare(invalidAmount, "")) + require.Equal(t, 0, priority.Compare("", invalidCoins)) + }) + + for _, tc := range testCases { + t.Run(fmt.Sprintf("test with non-tx: %s", tc.name), func(t *testing.T) { + require.Equal(t, "", tc.priority.GetTxPriority(nil, nil)) + }) + + t.Run(fmt.Sprintf("test tx with no fee: %s", tc.name), func(t *testing.T) { + tx, err := testutils.CreateTx(txc, acct[0], 0, 0, nil) + require.NoError(t, err) + + require.Equal(t, "", tc.priority.GetTxPriority(nil, tx)) + }) + + t.Run(fmt.Sprintf("test comparing two tx priorities: %s", tc.name), func(t *testing.T) { + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1))) + require.NoError(t, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(2))) + require.NoError(t, err) + + priority1 := tc.priority.GetTxPriority(nil, tx1) + priority2 := tc.priority.GetTxPriority(nil, tx2) + + require.Equal(t, -1, tc.priority.Compare(priority1, priority2)) + require.Equal(t, 1, tc.priority.Compare(priority2, priority1)) + require.Equal(t, 0, tc.priority.Compare(priority2, priority2)) + }) + + t.Run(fmt.Sprintf("test comparing two tx priorities with nil: %s", tc.name), func(t *testing.T) { + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1))) + require.NoError(t, err) + + priority1 := tc.priority.GetTxPriority(nil, tx1) + + require.Equal(t, 1, tc.priority.Compare(priority1, "")) + require.Equal(t, -1, tc.priority.Compare("", priority1)) + require.Equal(t, 0, tc.priority.Compare("", "")) + }) + + t.Run(fmt.Sprintf("test with multiple fee coins: %s", tc.name), func(t *testing.T) { + tx, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1)), sdk.NewCoin("atom", math.NewInt(2))) + require.NoError(t, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(2)), sdk.NewCoin("atom", math.NewInt(3))) + require.NoError(t, err) + + priority1 := tc.priority.GetTxPriority(nil, tx) + priority2 := tc.priority.GetTxPriority(nil, tx2) + + require.Equal(t, -1, tc.priority.Compare(priority1, priority2)) + require.Equal(t, 1, tc.priority.Compare(priority2, priority1)) + require.Equal(t, 0, tc.priority.Compare(priority2, priority2)) + }) + + t.Run(fmt.Sprintf("test with multiple different fee coins: %s", tc.name), func(t *testing.T) { + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1)), sdk.NewCoin("atom", math.NewInt(2)), sdk.NewCoin("btc", math.NewInt(3))) + require.NoError(t, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(2)), sdk.NewCoin("eth", math.NewInt(3))) + require.NoError(t, err) + + priority1 := tc.priority.GetTxPriority(nil, tx1) + priority2 := tc.priority.GetTxPriority(nil, tx2) + + require.Equal(t, 0, tc.priority.Compare(priority1, priority2)) + + tx3, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(2)), sdk.NewCoin("osmo", math.NewInt(3)), sdk.NewCoin("btc", math.NewInt(3))) + require.NoError(t, err) + + priority3 := tc.priority.GetTxPriority(nil, tx3) + + require.Equal(t, 0, tc.priority.Compare(priority3, priority1)) + }) + + t.Run(fmt.Sprintf("one is nil, and the other isn't: %s", tc.name), func(t *testing.T) { + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1)), sdk.NewCoin("atom", math.NewInt(2)), sdk.NewCoin("btc", math.NewInt(3))) + require.NoError(t, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil) + require.NoError(t, err) + + priority1 := tc.priority.GetTxPriority(nil, tx1) + priority2 := tc.priority.GetTxPriority(nil, tx2) + + require.Equal(t, 1, tc.priority.Compare(priority1, priority2)) + require.Equal(t, -1, tc.priority.Compare(priority2, priority1)) + require.Equal(t, 0, tc.priority.Compare(priority2, priority2)) + }) + + t.Run(fmt.Sprintf("incorrectly ordered fee tokens: %s", tc.name), func(t *testing.T) { + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1)), sdk.NewCoin("atom", math.NewInt(2)), sdk.NewCoin("btc", math.NewInt(3))) + require.NoError(t, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("atom", math.NewInt(2)), sdk.NewCoin("stake", math.NewInt(1)), sdk.NewCoin("btc", math.NewInt(3))) + require.NoError(t, err) + + priority1 := tc.priority.GetTxPriority(nil, tx1) + priority2 := tc.priority.GetTxPriority(nil, tx2) + + require.Equal(t, tc.priority.Compare(priority1, priority2), 0) + }) + + t.Run(fmt.Sprintf("IBC tokens: %s", tc.name), func(t *testing.T) { + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("ibc/7F1D3FCF4AE79E1554D670D1AD949A9BA4E4A3C76C63093E17E446A46061A7A2", math.NewInt(1))) + require.NoError(t, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("ibc/7F1D3FCF4AE79E1554D670D1AD949A9BA4E4A3C76C63093E17E446A46061A7A2", math.NewInt(2))) + require.NoError(t, err) + + priority1 := tc.priority.GetTxPriority(nil, tx1) + priority2 := tc.priority.GetTxPriority(nil, tx2) + + require.Equal(t, -1, tc.priority.Compare(priority1, priority2)) + require.Equal(t, 1, tc.priority.Compare(priority2, priority1)) + require.Equal(t, 0, tc.priority.Compare(priority2, priority2)) + }) + } +} + +func BenchmarkDefaultTxPriority(b *testing.B) { + acct := testutils.RandomAccounts(rand.New(rand.NewSource(1)), 1) + txc := testutils.CreateTestEncodingConfig().TxConfig + + priority := base.DefaultTxPriority() + + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1)), sdk.NewCoin("atom", math.NewInt(2)), sdk.NewCoin("btc", math.NewInt(3))) + require.NoError(b, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(2)), sdk.NewCoin("eth", math.NewInt(3)), sdk.NewCoin("btc", math.NewInt(4))) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + priority.Compare(priority.GetTxPriority(nil, tx1), priority.GetTxPriority(nil, tx2)) + } +} + +func BenchmarkDeprecatedTxPriority(b *testing.B) { + // ignore setup + acct := testutils.RandomAccounts(rand.New(rand.NewSource(1)), 1) + txc := testutils.CreateTestEncodingConfig().TxConfig + + priority := base.DeprecatedTxPriority() + + tx1, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(1)), sdk.NewCoin("atom", math.NewInt(2)), sdk.NewCoin("btc", math.NewInt(3))) + require.NoError(b, err) + + tx2, err := testutils.CreateTx(txc, acct[0], 0, 0, nil, sdk.NewCoin("stake", math.NewInt(2)), sdk.NewCoin("eth", math.NewInt(3)), sdk.NewCoin("btc", math.NewInt(4))) + require.NoError(b, err) + // start timer + for i := 0; i < b.N; i++ { + priority.Compare(priority.GetTxPriority(nil, tx1), priority.GetTxPriority(nil, tx2)) + } +} diff --git a/lanes/base/mempool_test.go b/lanes/base/mempool_test.go index 3667a1c..6913c97 100644 --- a/lanes/base/mempool_test.go +++ b/lanes/base/mempool_test.go @@ -9,84 +9,7 @@ import ( testutils "github.com/skip-mev/block-sdk/testutils" ) -func (s *BaseTestSuite) TestGetTxPriority() { - txPriority := base.DefaultTxPriority() - - s.Run("should be able to get the priority off a normal transaction with fees", func() { - tx, err := testutils.CreateRandomTx( - s.encodingConfig.TxConfig, - s.accounts[0], - 0, - 0, - 0, - 0, - sdk.NewCoin(s.gasTokenDenom, math.NewInt(100)), - ) - s.Require().NoError(err) - - priority := txPriority.GetTxPriority(sdk.Context{}, tx) - s.Require().Equal(sdk.NewCoin(s.gasTokenDenom, math.NewInt(100)).String(), priority) - }) - - s.Run("should not get a priority when the transaction does not have a fee", func() { - tx, err := testutils.CreateRandomTx( - s.encodingConfig.TxConfig, - s.accounts[0], - 0, - 0, - 0, - 0, - ) - s.Require().NoError(err) - - priority := txPriority.GetTxPriority(sdk.Context{}, tx) - s.Require().Equal("", priority) - }) - - s.Run("should get a priority when the gas token is different", func() { - tx, err := testutils.CreateRandomTx( - s.encodingConfig.TxConfig, - s.accounts[0], - 0, - 0, - 0, - 0, - sdk.NewCoin("random", math.NewInt(100)), - ) - s.Require().NoError(err) - - priority := txPriority.GetTxPriority(sdk.Context{}, tx) - s.Require().Equal(sdk.NewCoin("random", math.NewInt(100)).String(), priority) - }) -} - func (s *BaseTestSuite) TestCompareTxPriority() { - txPriority := base.DefaultTxPriority() - - s.Run("should return 0 when both priorities are nil", func() { - a := sdk.NewCoin(s.gasTokenDenom, math.NewInt(0)).String() - b := sdk.NewCoin(s.gasTokenDenom, math.NewInt(0)).String() - s.Require().Equal(0, txPriority.Compare(a, b)) - }) - - s.Run("should return 1 when the first priority is greater", func() { - a := sdk.NewCoin(s.gasTokenDenom, math.NewInt(100)).String() - b := sdk.NewCoin(s.gasTokenDenom, math.NewInt(1)).String() - s.Require().Equal(1, txPriority.Compare(a, b)) - }) - - s.Run("should return -1 when the second priority is greater", func() { - a := sdk.NewCoin(s.gasTokenDenom, math.NewInt(1)).String() - b := sdk.NewCoin(s.gasTokenDenom, math.NewInt(100)).String() - s.Require().Equal(-1, txPriority.Compare(a, b)) - }) - - s.Run("should return 0 when both priorities are equal", func() { - a := sdk.NewCoin(s.gasTokenDenom, math.NewInt(100)).String() - b := sdk.NewCoin(s.gasTokenDenom, math.NewInt(100)).String() - s.Require().Equal(0, txPriority.Compare(a, b)) - }) - lane := s.initLane(math.LegacyOneDec(), nil) s.Run("should return -1 when signers are the same but the first tx has a higher sequence", func() { diff --git a/lanes/base/tx_info_test.go b/lanes/base/tx_info_test.go index f661e93..77d86a0 100644 --- a/lanes/base/tx_info_test.go +++ b/lanes/base/tx_info_test.go @@ -39,11 +39,6 @@ func (s *BaseTestSuite) TestGetTxInfo() { s.Require().Equal(signer.Address.String(), txInfo.Signers[0].Signer.String()) s.Require().Equal(nonce, txInfo.Signers[0].Sequence) - // Verify the priority - actualfee, err := sdk.ParseCoinsNormalized(txInfo.Priority.(string)) - s.Require().NoError(err) - s.Require().Equal(fee, actualfee) - // Verify the gas limit s.Require().Equal(gasLimit, txInfo.GasLimit) @@ -82,11 +77,6 @@ func (s *BaseTestSuite) TestGetTxInfo() { s.Require().Equal(signer.Address.String(), txInfo.Signers[0].Signer.String()) s.Require().Equal(nonce, txInfo.Signers[0].Sequence) - // Verify the priority - actualfee, err := sdk.ParseCoinsNormalized(txInfo.Priority.(string)) - s.Require().NoError(err) - s.Require().Equal(fee, actualfee) - // Verify the bytes txBz, err := s.encodingConfig.TxConfig.TxEncoder()(tx) s.Require().NoError(err)