diff --git a/cmd/devp2p/internal/ethtest/helpers.go b/cmd/devp2p/internal/ethtest/helpers.go index b57649ade..70ed2d210 100644 --- a/cmd/devp2p/internal/ethtest/helpers.go +++ b/cmd/devp2p/internal/ethtest/helpers.go @@ -63,8 +63,9 @@ func (s *Suite) dial() (*Conn, error) { conn.caps = []p2p.Cap{ {Name: "eth", Version: 66}, {Name: "eth", Version: 67}, + {Name: "eth", Version: 68}, } - conn.ourHighestProtoVersion = 67 + conn.ourHighestProtoVersion = 68 return &conn, nil } @@ -359,6 +360,8 @@ func (s *Suite) waitAnnounce(conn *Conn, blockAnnouncement *NewBlock) error { return nil // ignore tx announcements from previous tests + case *NewPooledTransactionHashes66: + continue case *NewPooledTransactionHashes: continue case *Transactions: diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go index 4497478d7..815353be7 100644 --- a/cmd/devp2p/internal/ethtest/suite.go +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -510,17 +510,18 @@ func (s *Suite) TestNewPooledTxs(t *utesting.T) { } // generate 50 txs - hashMap, _, err := generateTxs(s, 50) + _, txs, err := generateTxs(s, 50) if err != nil { t.Fatalf("failed to generate transactions: %v", err) } - - // create new pooled tx hashes announcement - hashes := make([]common.Hash, 0) - for _, hash := range hashMap { - hashes = append(hashes, hash) + hashes := make([]common.Hash, len(txs)) + types := make([]byte, len(txs)) + sizes := make([]uint32, len(txs)) + for i, tx := range txs { + hashes[i] = tx.Hash() + types[i] = tx.Type() + sizes[i] = uint32(tx.Size()) } - announce := NewPooledTransactionHashes(hashes) // send announcement conn, err := s.dial() @@ -531,7 +532,13 @@ func (s *Suite) TestNewPooledTxs(t *utesting.T) { if err = conn.peer(s.chain, nil); err != nil { t.Fatalf("peering failed: %v", err) } - if err = conn.Write(announce); err != nil { + + var ann Message = NewPooledTransactionHashes{Types: types, Sizes: sizes, Hashes: hashes} + if conn.negotiatedProtoVersion < eth.ETH68 { + ann = NewPooledTransactionHashes66(hashes) + } + err = conn.Write(ann) + if err != nil { t.Fatalf("failed to write to connection: %v", err) } @@ -546,6 +553,8 @@ func (s *Suite) TestNewPooledTxs(t *utesting.T) { return // ignore propagated txs from previous tests + case *NewPooledTransactionHashes66: + continue case *NewPooledTransactionHashes: continue case *Transactions: diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go index baa55bd49..bf3a4b7f0 100644 --- a/cmd/devp2p/internal/ethtest/transaction.go +++ b/cmd/devp2p/internal/ethtest/transaction.go @@ -95,7 +95,7 @@ func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction } } return fmt.Errorf("missing transaction: got %v missing %v", recTxs, tx.Hash()) - case *NewPooledTransactionHashes: + case *NewPooledTransactionHashes66: txHashes := *msg // if you receive an old tx propagation, read from connection again if len(txHashes) == 1 && prevTx != nil { @@ -110,6 +110,34 @@ func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction } } return fmt.Errorf("missing transaction announcement: got %v missing %v", txHashes, tx.Hash()) + case *NewPooledTransactionHashes: + txHashes := msg.Hashes + if len(txHashes) != len(msg.Sizes) { + return fmt.Errorf("invalid msg size lengths: hashes: %v sizes: %v", len(txHashes), len(msg.Sizes)) + } + if len(txHashes) != len(msg.Types) { + return fmt.Errorf("invalid msg type lengths: hashes: %v types: %v", len(txHashes), len(msg.Types)) + } + // if you receive an old tx propagation, read from connection again + if len(txHashes) == 1 && prevTx != nil { + if txHashes[0] == prevTx.Hash() { + continue + } + } + for index, gotHash := range txHashes { + if gotHash == tx.Hash() { + if msg.Sizes[index] != uint32(tx.Size()) { + return fmt.Errorf("invalid tx size: got %v want %v", msg.Sizes[index], tx.Size()) + } + if msg.Types[index] != tx.Type() { + return fmt.Errorf("invalid tx type: got %v want %v", msg.Types[index], tx.Type()) + } + // Ok + return nil + } + } + return fmt.Errorf("missing transaction announcement: got %v missing %v", txHashes, tx.Hash()) + default: return fmt.Errorf("unexpected message in sendSuccessfulTx: %s", pretty.Sdump(msg)) } @@ -201,8 +229,10 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, txs []*types.Transaction for _, tx := range *msg { recvHashes = append(recvHashes, tx.Hash()) } - case *NewPooledTransactionHashes: + case *NewPooledTransactionHashes66: recvHashes = append(recvHashes, *msg...) + case *NewPooledTransactionHashes: + recvHashes = append(recvHashes, msg.Hashes...) default: if !strings.Contains(pretty.Sdump(msg), "i/o timeout") { return fmt.Errorf("unexpected message while waiting to receive txs: %s", pretty.Sdump(msg)) @@ -246,11 +276,16 @@ func checkMaliciousTxPropagation(s *Suite, txs []*types.Transaction, conn *Conn) if len(badTxs) > 0 { return fmt.Errorf("received %d bad txs: \n%v", len(badTxs), badTxs) } - case *NewPooledTransactionHashes: + case *NewPooledTransactionHashes66: badTxs, _ := compareReceivedTxs(*msg, txs) if len(badTxs) > 0 { return fmt.Errorf("received %d bad txs: \n%v", len(badTxs), badTxs) } + case *NewPooledTransactionHashes: + badTxs, _ := compareReceivedTxs(msg.Hashes, txs) + if len(badTxs) > 0 { + return fmt.Errorf("received %d bad txs: \n%v", len(badTxs), badTxs) + } case *Error: // Transaction should not be announced -> wait for timeout return nil diff --git a/cmd/devp2p/internal/ethtest/types.go b/cmd/devp2p/internal/ethtest/types.go index fd5251d16..3c7b6dbcf 100644 --- a/cmd/devp2p/internal/ethtest/types.go +++ b/cmd/devp2p/internal/ethtest/types.go @@ -126,8 +126,14 @@ type NewBlock eth.NewBlockPacket func (msg NewBlock) Code() int { return 23 } func (msg NewBlock) ReqID() uint64 { return 0 } +// NewPooledTransactionHashes66 is the network packet for the tx hash propagation message. +type NewPooledTransactionHashes66 eth.NewPooledTransactionHashesPacket66 + +func (msg NewPooledTransactionHashes66) Code() int { return 24 } +func (msg NewPooledTransactionHashes66) ReqID() uint64 { return 0 } + // NewPooledTransactionHashes is the network packet for the tx hash propagation message. -type NewPooledTransactionHashes eth.NewPooledTransactionHashesPacket66 +type NewPooledTransactionHashes eth.NewPooledTransactionHashesPacket68 func (msg NewPooledTransactionHashes) Code() int { return 24 } func (msg NewPooledTransactionHashes) ReqID() uint64 { return 0 } @@ -202,8 +208,13 @@ func (c *Conn) Read() Message { msg = new(NewBlockHashes) case (Transactions{}).Code(): msg = new(Transactions) - case (NewPooledTransactionHashes{}).Code(): - msg = new(NewPooledTransactionHashes) + case (NewPooledTransactionHashes66{}).Code(): + // Try decoding to eth68 + ethMsg := new(NewPooledTransactionHashes) + if err := rlp.DecodeBytes(rawData, ethMsg); err == nil { + return ethMsg + } + msg = new(NewPooledTransactionHashes66) case (GetPooledTransactions{}.Code()): ethMsg := new(eth.GetPooledTransactionsPacket66) if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {