ethclient: fix tx sender cache miss detection (#23877)

This fixes a bug in TransactionSender where it would return the
zero address for transactions where the sender address wasn't
cached already.

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Lee Bousfield 2021-11-17 07:44:41 -06:00 committed by GitHub
parent fa96718512
commit 16341e0563
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 41 deletions

View File

@ -233,6 +233,8 @@ func (ec *Client) TransactionSender(ctx context.Context, tx *types.Transaction,
if err == nil { if err == nil {
return sender, nil return sender, nil
} }
// It was not found in cache, ask the server.
var meta struct { var meta struct {
Hash common.Hash Hash common.Hash
From common.Address From common.Address

View File

@ -187,9 +187,34 @@ var (
testBalance = big.NewInt(2e15) testBalance = big.NewInt(2e15)
) )
var genesis = &core.Genesis{
Config: params.AllEthashProtocolChanges,
Alloc: core.GenesisAlloc{testAddr: {Balance: testBalance}},
ExtraData: []byte("test genesis"),
Timestamp: 9000,
BaseFee: big.NewInt(params.InitialBaseFee),
}
var testTx1 = types.MustSignNewTx(testKey, types.LatestSigner(genesis.Config), &types.LegacyTx{
Nonce: 0,
Value: big.NewInt(12),
GasPrice: big.NewInt(params.InitialBaseFee),
Gas: params.TxGas,
To: &common.Address{2},
})
var testTx2 = types.MustSignNewTx(testKey, types.LatestSigner(genesis.Config), &types.LegacyTx{
Nonce: 1,
Value: big.NewInt(8),
GasPrice: big.NewInt(params.InitialBaseFee),
Gas: params.TxGas,
To: &common.Address{2},
})
func newTestBackend(t *testing.T) (*node.Node, []*types.Block) { func newTestBackend(t *testing.T) (*node.Node, []*types.Block) {
// Generate test chain. // Generate test chain.
genesis, blocks := generateTestChain() blocks := generateTestChain()
// Create node // Create node
n, err := node.New(&node.Config{}) n, err := node.New(&node.Config{})
if err != nil { if err != nil {
@ -212,25 +237,22 @@ func newTestBackend(t *testing.T) (*node.Node, []*types.Block) {
return n, blocks return n, blocks
} }
func generateTestChain() (*core.Genesis, []*types.Block) { func generateTestChain() []*types.Block {
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
config := params.AllEthashProtocolChanges
genesis := &core.Genesis{
Config: config,
Alloc: core.GenesisAlloc{testAddr: {Balance: testBalance}},
ExtraData: []byte("test genesis"),
Timestamp: 9000,
BaseFee: big.NewInt(params.InitialBaseFee),
}
generate := func(i int, g *core.BlockGen) { generate := func(i int, g *core.BlockGen) {
g.OffsetTime(5) g.OffsetTime(5)
g.SetExtra([]byte("test")) g.SetExtra([]byte("test"))
if i == 1 {
// Test transactions are included in block #2.
g.AddTx(testTx1)
g.AddTx(testTx2)
}
} }
gblock := genesis.ToBlock(db) gblock := genesis.ToBlock(db)
engine := ethash.NewFaker() engine := ethash.NewFaker()
blocks, _ := core.GenerateChain(config, gblock, engine, db, 1, generate) blocks, _ := core.GenerateChain(genesis.Config, gblock, engine, db, 2, generate)
blocks = append([]*types.Block{gblock}, blocks...) blocks = append([]*types.Block{gblock}, blocks...)
return genesis, blocks return blocks
} }
func TestEthClient(t *testing.T) { func TestEthClient(t *testing.T) {
@ -242,30 +264,33 @@ func TestEthClient(t *testing.T) {
tests := map[string]struct { tests := map[string]struct {
test func(t *testing.T) test func(t *testing.T)
}{ }{
"TestHeader": { "Header": {
func(t *testing.T) { testHeader(t, chain, client) }, func(t *testing.T) { testHeader(t, chain, client) },
}, },
"TestBalanceAt": { "BalanceAt": {
func(t *testing.T) { testBalanceAt(t, client) }, func(t *testing.T) { testBalanceAt(t, client) },
}, },
"TestTxInBlockInterrupted": { "TxInBlockInterrupted": {
func(t *testing.T) { testTransactionInBlockInterrupted(t, client) }, func(t *testing.T) { testTransactionInBlockInterrupted(t, client) },
}, },
"TestChainID": { "ChainID": {
func(t *testing.T) { testChainID(t, client) }, func(t *testing.T) { testChainID(t, client) },
}, },
"TestGetBlock": { "GetBlock": {
func(t *testing.T) { testGetBlock(t, client) }, func(t *testing.T) { testGetBlock(t, client) },
}, },
"TestStatusFunctions": { "StatusFunctions": {
func(t *testing.T) { testStatusFunctions(t, client) }, func(t *testing.T) { testStatusFunctions(t, client) },
}, },
"TestCallContract": { "CallContract": {
func(t *testing.T) { testCallContract(t, client) }, func(t *testing.T) { testCallContract(t, client) },
}, },
"TestAtFunctions": { "AtFunctions": {
func(t *testing.T) { testAtFunctions(t, client) }, func(t *testing.T) { testAtFunctions(t, client) },
}, },
"TransactionSender": {
func(t *testing.T) { testTransactionSender(t, client) },
},
} }
t.Parallel() t.Parallel()
@ -321,6 +346,11 @@ func testBalanceAt(t *testing.T, client *rpc.Client) {
want *big.Int want *big.Int
wantErr error wantErr error
}{ }{
"valid_account_genesis": {
account: testAddr,
block: big.NewInt(0),
want: testBalance,
},
"valid_account": { "valid_account": {
account: testAddr, account: testAddr,
block: big.NewInt(1), block: big.NewInt(1),
@ -358,23 +388,25 @@ func testBalanceAt(t *testing.T, client *rpc.Client) {
func testTransactionInBlockInterrupted(t *testing.T, client *rpc.Client) { func testTransactionInBlockInterrupted(t *testing.T, client *rpc.Client) {
ec := NewClient(client) ec := NewClient(client)
// Get current block by number // Get current block by number.
block, err := ec.BlockByNumber(context.Background(), nil) block, err := ec.BlockByNumber(context.Background(), nil)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
// Test tx in block interupted
// Test tx in block interupted.
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
tx, err := ec.TransactionInBlock(ctx, block.Hash(), 1) tx, err := ec.TransactionInBlock(ctx, block.Hash(), 0)
if tx != nil { if tx != nil {
t.Fatal("transaction should be nil") t.Fatal("transaction should be nil")
} }
if err == nil || err == ethereum.NotFound { if err == nil || err == ethereum.NotFound {
t.Fatal("error should not be nil/notfound") t.Fatal("error should not be nil/notfound")
} }
// Test tx in block not found
if _, err := ec.TransactionInBlock(context.Background(), block.Hash(), 1); err != ethereum.NotFound { // Test tx in block not found.
if _, err := ec.TransactionInBlock(context.Background(), block.Hash(), 20); err != ethereum.NotFound {
t.Fatal("error should be ethereum.NotFound") t.Fatal("error should be ethereum.NotFound")
} }
} }
@ -392,12 +424,13 @@ func testChainID(t *testing.T, client *rpc.Client) {
func testGetBlock(t *testing.T, client *rpc.Client) { func testGetBlock(t *testing.T, client *rpc.Client) {
ec := NewClient(client) ec := NewClient(client)
// Get current block number // Get current block number
blockNumber, err := ec.BlockNumber(context.Background()) blockNumber, err := ec.BlockNumber(context.Background())
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if blockNumber != 1 { if blockNumber != 2 {
t.Fatalf("BlockNumber returned wrong number: %d", blockNumber) t.Fatalf("BlockNumber returned wrong number: %d", blockNumber)
} }
// Get current block by number // Get current block by number
@ -445,6 +478,7 @@ func testStatusFunctions(t *testing.T, client *rpc.Client) {
if progress != nil { if progress != nil {
t.Fatalf("unexpected progress: %v", progress) t.Fatalf("unexpected progress: %v", progress)
} }
// NetworkID // NetworkID
networkID, err := ec.NetworkID(context.Background()) networkID, err := ec.NetworkID(context.Background())
if err != nil { if err != nil {
@ -453,20 +487,22 @@ func testStatusFunctions(t *testing.T, client *rpc.Client) {
if networkID.Cmp(big.NewInt(0)) != 0 { if networkID.Cmp(big.NewInt(0)) != 0 {
t.Fatalf("unexpected networkID: %v", networkID) t.Fatalf("unexpected networkID: %v", networkID)
} }
// SuggestGasPrice (should suggest 1 Gwei)
// SuggestGasPrice
gasPrice, err := ec.SuggestGasPrice(context.Background()) gasPrice, err := ec.SuggestGasPrice(context.Background())
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if gasPrice.Cmp(big.NewInt(1875000000)) != 0 { // 1 gwei tip + 0.875 basefee after a 1 gwei fee empty block if gasPrice.Cmp(big.NewInt(1000000000)) != 0 {
t.Fatalf("unexpected gas price: %v", gasPrice) t.Fatalf("unexpected gas price: %v", gasPrice)
} }
// SuggestGasTipCap (should suggest 1 Gwei)
// SuggestGasTipCap
gasTipCap, err := ec.SuggestGasTipCap(context.Background()) gasTipCap, err := ec.SuggestGasTipCap(context.Background())
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if gasTipCap.Cmp(big.NewInt(1000000000)) != 0 { if gasTipCap.Cmp(big.NewInt(234375000)) != 0 {
t.Fatalf("unexpected gas tip cap: %v", gasTipCap) t.Fatalf("unexpected gas tip cap: %v", gasTipCap)
} }
} }
@ -500,9 +536,11 @@ func testCallContract(t *testing.T, client *rpc.Client) {
func testAtFunctions(t *testing.T, client *rpc.Client) { func testAtFunctions(t *testing.T, client *rpc.Client) {
ec := NewClient(client) ec := NewClient(client)
// send a transaction for some interesting pending status // send a transaction for some interesting pending status
sendTransaction(ec) sendTransaction(ec)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// Check pending transaction count // Check pending transaction count
pending, err := ec.PendingTransactionCount(context.Background()) pending, err := ec.PendingTransactionCount(context.Background())
if err != nil { if err != nil {
@ -561,23 +599,66 @@ func testAtFunctions(t *testing.T, client *rpc.Client) {
} }
} }
func testTransactionSender(t *testing.T, client *rpc.Client) {
ec := NewClient(client)
ctx := context.Background()
// Retrieve testTx1 via RPC.
block2, err := ec.HeaderByNumber(ctx, big.NewInt(2))
if err != nil {
t.Fatal("can't get block 1:", err)
}
tx1, err := ec.TransactionInBlock(ctx, block2.Hash(), 0)
if err != nil {
t.Fatal("can't get tx:", err)
}
if tx1.Hash() != testTx1.Hash() {
t.Fatalf("wrong tx hash %v, want %v", tx1.Hash(), testTx1.Hash())
}
// The sender address is cached in tx1, so no additional RPC should be required in
// TransactionSender. Ensure the server is not asked by canceling the context here.
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
sender1, err := ec.TransactionSender(canceledCtx, tx1, block2.Hash(), 0)
if err != nil {
t.Fatal(err)
}
if sender1 != testAddr {
t.Fatal("wrong sender:", sender1)
}
// Now try to get the sender of testTx2, which was not fetched through RPC.
// TransactionSender should query the server here.
sender2, err := ec.TransactionSender(ctx, testTx2, block2.Hash(), 1)
if err != nil {
t.Fatal(err)
}
if sender2 != testAddr {
t.Fatal("wrong sender:", sender2)
}
}
func sendTransaction(ec *Client) error { func sendTransaction(ec *Client) error {
// Retrieve chainID
chainID, err := ec.ChainID(context.Background()) chainID, err := ec.ChainID(context.Background())
if err != nil { if err != nil {
return err return err
} }
// Create transaction nonce, err := ec.PendingNonceAt(context.Background(), testAddr)
tx := types.NewTransaction(0, common.Address{1}, big.NewInt(1), 22000, big.NewInt(params.InitialBaseFee), nil) if err != nil {
return err
}
signer := types.LatestSignerForChainID(chainID) signer := types.LatestSignerForChainID(chainID)
signature, err := crypto.Sign(signer.Hash(tx).Bytes(), testKey) tx, err := types.SignNewTx(testKey, signer, &types.LegacyTx{
Nonce: nonce,
To: &common.Address{2},
Value: big.NewInt(1),
Gas: 22000,
GasPrice: big.NewInt(params.InitialBaseFee),
})
if err != nil { if err != nil {
return err return err
} }
signedTx, err := tx.WithSignature(signer, signature) return ec.SendTransaction(context.Background(), tx)
if err != nil {
return err
}
// Send transaction
return ec.SendTransaction(context.Background(), signedTx)
} }

View File

@ -45,7 +45,7 @@ func (s *senderFromServer) Equal(other types.Signer) bool {
} }
func (s *senderFromServer) Sender(tx *types.Transaction) (common.Address, error) { func (s *senderFromServer) Sender(tx *types.Transaction) (common.Address, error) {
if s.blockhash == (common.Hash{}) { if s.addr == (common.Address{}) {
return common.Address{}, errNotCached return common.Address{}, errNotCached
} }
return s.addr, nil return s.addr, nil