cmd/devp2p: refactor eth test suite (#22843)

This PR refactors the eth test suite to make it more readable and
easier to use. Some notable differences:

- A new file helpers.go stores all of the methods used between
  both eth66 and eth65 and below tests, as well as methods shared
  among many test functions.
- suite.go now contains all of the test functions for both eth65
  tests and eth66 tests.
- The utesting.T object doesn't get passed through to other helper methods,
  but is instead only used within the scope of the test function,
  whereas helper methods return errors, so only the test function
  itself can fatal out in the case of an error.
- The full test suite now only takes 13.5 seconds to run.
This commit is contained in:
rene 2021-05-25 23:09:11 +02:00 committed by GitHub
parent 6c7d6cf886
commit 49bde05a55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1506 additions and 1461 deletions

View File

@ -54,10 +54,24 @@ func (c *Chain) Len() int {
return len(c.blocks) return len(c.blocks)
} }
// TD calculates the total difficulty of the chain. // TD calculates the total difficulty of the chain at the
func (c *Chain) TD(height int) *big.Int { // TODO later on channge scheme so that the height is included in range // chain head.
func (c *Chain) TD() *big.Int {
sum := big.NewInt(0) sum := big.NewInt(0)
for _, block := range c.blocks[:height] { for _, block := range c.blocks[:c.Len()] {
sum.Add(sum, block.Difficulty())
}
return sum
}
// TotalDifficultyAt calculates the total difficulty of the chain
// at the given block height.
func (c *Chain) TotalDifficultyAt(height int) *big.Int {
sum := big.NewInt(0)
if height >= c.Len() {
return sum
}
for _, block := range c.blocks[:height+1] {
sum.Add(sum, block.Difficulty()) sum.Add(sum, block.Difficulty())
} }
return sum return sum

View File

@ -1,521 +0,0 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package ethtest
import (
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p"
)
// Is_66 checks if the node supports the eth66 protocol version,
// and if not, exists the test suite
func (s *Suite) Is_66(t *utesting.T) {
conn := s.dial66(t)
conn.handshake(t)
if conn.negotiatedProtoVersion < 66 {
t.Fail()
}
}
// TestStatus_66 attempts to connect to the given node and exchange
// a status message with it on the eth66 protocol, and then check to
// make sure the chain head is correct.
func (s *Suite) TestStatus_66(t *utesting.T) {
conn := s.dial66(t)
defer conn.Close()
// get protoHandshake
conn.handshake(t)
// get status
switch msg := conn.statusExchange66(t, s.chain).(type) {
case *Status:
status := *msg
if status.ProtocolVersion != uint32(66) {
t.Fatalf("mismatch in version: wanted 66, got %d", status.ProtocolVersion)
}
t.Logf("got status message: %s", pretty.Sdump(msg))
default:
t.Fatalf("unexpected: %s", pretty.Sdump(msg))
}
}
// TestGetBlockHeaders_66 tests whether the given node can respond to
// an eth66 `GetBlockHeaders` request and that the response is accurate.
func (s *Suite) TestGetBlockHeaders_66(t *utesting.T) {
conn := s.setupConnection66(t)
defer conn.Close()
// get block headers
req := &eth.GetBlockHeadersPacket66{
RequestId: 3,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Hash: s.chain.blocks[1].Hash(),
},
Amount: 2,
Skip: 1,
Reverse: false,
},
}
// write message
headers, err := s.getBlockHeaders66(conn, req, req.RequestId)
if err != nil {
t.Fatalf("could not get block headers: %v", err)
}
// check for correct headers
if !headersMatch(t, s.chain, headers) {
t.Fatal("received wrong header(s)")
}
}
// TestSimultaneousRequests_66 sends two simultaneous `GetBlockHeader` requests
// with different request IDs and checks to make sure the node responds with the correct
// headers per request.
func (s *Suite) TestSimultaneousRequests_66(t *utesting.T) {
// create two connections
conn := s.setupConnection66(t)
defer conn.Close()
// create two requests
req1 := &eth.GetBlockHeadersPacket66{
RequestId: 111,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Hash: s.chain.blocks[1].Hash(),
},
Amount: 2,
Skip: 1,
Reverse: false,
},
}
req2 := &eth.GetBlockHeadersPacket66{
RequestId: 222,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Hash: s.chain.blocks[1].Hash(),
},
Amount: 4,
Skip: 1,
Reverse: false,
},
}
// write first request
if err := conn.write66(req1, GetBlockHeaders{}.Code()); err != nil {
t.Fatalf("failed to write to connection: %v", err)
}
// write second request
if err := conn.write66(req2, GetBlockHeaders{}.Code()); err != nil {
t.Fatalf("failed to write to connection: %v", err)
}
// wait for responses
headers1, err := s.waitForBlockHeadersResponse66(conn, req1.RequestId)
if err != nil {
t.Fatalf("error while waiting for block headers: %v", err)
}
headers2, err := s.waitForBlockHeadersResponse66(conn, req2.RequestId)
if err != nil {
t.Fatalf("error while waiting for block headers: %v", err)
}
// check headers of both responses
if !headersMatch(t, s.chain, headers1) {
t.Fatalf("wrong header(s) in response to req1: got %v", headers1)
}
if !headersMatch(t, s.chain, headers2) {
t.Fatalf("wrong header(s) in response to req2: got %v", headers2)
}
}
// TestBroadcast_66 tests whether a block announcement is correctly
// propagated to the given node's peer(s) on the eth66 protocol.
func (s *Suite) TestBroadcast_66(t *utesting.T) {
s.sendNextBlock66(t)
}
// TestGetBlockBodies_66 tests whether the given node can respond to
// a `GetBlockBodies` request and that the response is accurate over
// the eth66 protocol.
func (s *Suite) TestGetBlockBodies_66(t *utesting.T) {
conn := s.setupConnection66(t)
defer conn.Close()
// create block bodies request
id := uint64(55)
req := &eth.GetBlockBodiesPacket66{
RequestId: id,
GetBlockBodiesPacket: eth.GetBlockBodiesPacket{
s.chain.blocks[54].Hash(),
s.chain.blocks[75].Hash(),
},
}
if err := conn.write66(req, GetBlockBodies{}.Code()); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
reqID, msg := conn.readAndServe66(s.chain, timeout)
switch msg := msg.(type) {
case BlockBodies:
if reqID != req.RequestId {
t.Fatalf("request ID mismatch: wanted %d, got %d", req.RequestId, reqID)
}
t.Logf("received %d block bodies", len(msg))
default:
t.Fatalf("unexpected: %s", pretty.Sdump(msg))
}
}
// TestLargeAnnounce_66 tests the announcement mechanism with a large block.
func (s *Suite) TestLargeAnnounce_66(t *utesting.T) {
nextBlock := len(s.chain.blocks)
blocks := []*NewBlock{
{
Block: largeBlock(),
TD: s.fullChain.TD(nextBlock + 1),
},
{
Block: s.fullChain.blocks[nextBlock],
TD: largeNumber(2),
},
{
Block: largeBlock(),
TD: largeNumber(2),
},
{
Block: s.fullChain.blocks[nextBlock],
TD: s.fullChain.TD(nextBlock + 1),
},
}
for i, blockAnnouncement := range blocks[0:3] {
t.Logf("Testing malicious announcement: %v\n", i)
sendConn := s.setupConnection66(t)
if err := sendConn.Write(blockAnnouncement); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
// Invalid announcement, check that peer disconnected
switch msg := sendConn.ReadAndServe(s.chain, time.Second*8).(type) {
case *Disconnect:
case *Error:
break
default:
t.Fatalf("unexpected: %s wanted disconnect", pretty.Sdump(msg))
}
sendConn.Close()
}
// Test the last block as a valid block
s.sendNextBlock66(t)
}
func (s *Suite) TestOldAnnounce_66(t *utesting.T) {
sendConn, recvConn := s.setupConnection66(t), s.setupConnection66(t)
defer sendConn.Close()
defer recvConn.Close()
s.oldAnnounce(t, sendConn, recvConn)
}
// TestMaliciousHandshake_66 tries to send malicious data during the handshake.
func (s *Suite) TestMaliciousHandshake_66(t *utesting.T) {
conn := s.dial66(t)
defer conn.Close()
// write hello to client
pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:]
handshakes := []*Hello{
{
Version: 5,
Caps: []p2p.Cap{
{Name: largeString(2), Version: 66},
},
ID: pub0,
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: "eth", Version: 64},
{Name: "eth", Version: 65},
{Name: "eth", Version: 66},
},
ID: append(pub0, byte(0)),
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: "eth", Version: 64},
{Name: "eth", Version: 65},
{Name: "eth", Version: 66},
},
ID: append(pub0, pub0...),
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: "eth", Version: 64},
{Name: "eth", Version: 65},
{Name: "eth", Version: 66},
},
ID: largeBuffer(2),
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: largeString(2), Version: 66},
},
ID: largeBuffer(2),
},
}
for i, handshake := range handshakes {
t.Logf("Testing malicious handshake %v\n", i)
// Init the handshake
if err := conn.Write(handshake); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
// check that the peer disconnected
timeout := 20 * time.Second
// Discard one hello
for i := 0; i < 2; i++ {
switch msg := conn.ReadAndServe(s.chain, timeout).(type) {
case *Disconnect:
case *Error:
case *Hello:
// Hello's are sent concurrently, so ignore them
continue
default:
t.Fatalf("unexpected: %s", pretty.Sdump(msg))
}
}
// Dial for the next round
conn = s.dial66(t)
}
}
// TestMaliciousStatus_66 sends a status package with a large total difficulty.
func (s *Suite) TestMaliciousStatus_66(t *utesting.T) {
conn := s.dial66(t)
defer conn.Close()
// get protoHandshake
conn.handshake(t)
status := &Status{
ProtocolVersion: uint32(66),
NetworkID: s.chain.chainConfig.ChainID.Uint64(),
TD: largeNumber(2),
Head: s.chain.blocks[s.chain.Len()-1].Hash(),
Genesis: s.chain.blocks[0].Hash(),
ForkID: s.chain.ForkID(),
}
// get status
switch msg := conn.statusExchange(t, s.chain, status).(type) {
case *Status:
t.Logf("%+v\n", msg)
default:
t.Fatalf("expected status, got: %#v ", msg)
}
// wait for disconnect
switch msg := conn.ReadAndServe(s.chain, timeout).(type) {
case *Disconnect:
case *Error:
return
default:
t.Fatalf("expected disconnect, got: %s", pretty.Sdump(msg))
}
}
func (s *Suite) TestTransaction_66(t *utesting.T) {
tests := []*types.Transaction{
getNextTxFromChain(t, s),
unknownTx(t, s),
}
for i, tx := range tests {
t.Logf("Testing tx propagation: %v\n", i)
sendSuccessfulTx66(t, s, tx)
}
}
func (s *Suite) TestMaliciousTx_66(t *utesting.T) {
badTxs := []*types.Transaction{
getOldTxFromChain(t, s),
invalidNonceTx(t, s),
hugeAmount(t, s),
hugeGasPrice(t, s),
hugeData(t, s),
}
sendConn := s.setupConnection66(t)
defer sendConn.Close()
// set up receiving connection before sending txs to make sure
// no announcements are missed
recvConn := s.setupConnection66(t)
defer recvConn.Close()
for i, tx := range badTxs {
t.Logf("Testing malicious tx propagation: %v\n", i)
if err := sendConn.Write(&Transactions{tx}); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
}
// check to make sure bad txs aren't propagated
waitForTxPropagation(t, s, badTxs, recvConn)
}
// TestZeroRequestID_66 checks that a request ID of zero is still handled
// by the node.
func (s *Suite) TestZeroRequestID_66(t *utesting.T) {
conn := s.setupConnection66(t)
defer conn.Close()
req := &eth.GetBlockHeadersPacket66{
RequestId: 0,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Number: 0,
},
Amount: 2,
},
}
headers, err := s.getBlockHeaders66(conn, req, req.RequestId)
if err != nil {
t.Fatalf("could not get block headers: %v", err)
}
if !headersMatch(t, s.chain, headers) {
t.Fatal("received wrong header(s)")
}
}
// TestSameRequestID_66 sends two requests with the same request ID
// concurrently to a single node.
func (s *Suite) TestSameRequestID_66(t *utesting.T) {
conn := s.setupConnection66(t)
// create two requests with the same request ID
reqID := uint64(1234)
request1 := &eth.GetBlockHeadersPacket66{
RequestId: reqID,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Number: 1,
},
Amount: 2,
},
}
request2 := &eth.GetBlockHeadersPacket66{
RequestId: reqID,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Number: 33,
},
Amount: 2,
},
}
// write the first request
err := conn.write66(request1, GetBlockHeaders{}.Code())
if err != nil {
t.Fatalf("could not write to connection: %v", err)
}
// perform second request
headers2, err := s.getBlockHeaders66(conn, request2, reqID)
if err != nil {
t.Fatalf("could not get block headers: %v", err)
return
}
// wait for response to first request
headers1, err := s.waitForBlockHeadersResponse66(conn, reqID)
if err != nil {
t.Fatalf("could not get BlockHeaders response: %v", err)
}
// check if headers match
if !headersMatch(t, s.chain, headers1) || !headersMatch(t, s.chain, headers2) {
t.Fatal("received wrong header(s)")
}
}
// TestLargeTxRequest_66 tests whether a node can fulfill a large GetPooledTransactions
// request.
func (s *Suite) TestLargeTxRequest_66(t *utesting.T) {
// send the next block to ensure the node is no longer syncing and is able to accept
// txs
s.sendNextBlock66(t)
// send 2000 transactions to the node
hashMap, txs := generateTxs(t, s, 2000)
sendConn := s.setupConnection66(t)
defer sendConn.Close()
sendMultipleSuccessfulTxs(t, s, sendConn, txs)
// set up connection to receive to ensure node is peered with the receiving connection
// before tx request is sent
recvConn := s.setupConnection66(t)
defer recvConn.Close()
// create and send pooled tx request
hashes := make([]common.Hash, 0)
for _, hash := range hashMap {
hashes = append(hashes, hash)
}
getTxReq := &eth.GetPooledTransactionsPacket66{
RequestId: 1234,
GetPooledTransactionsPacket: hashes,
}
if err := recvConn.write66(getTxReq, GetPooledTransactions{}.Code()); err != nil {
t.Fatalf("could not write to conn: %v", err)
}
// check that all received transactions match those that were sent to node
switch msg := recvConn.waitForResponse(s.chain, timeout, getTxReq.RequestId).(type) {
case PooledTransactions:
for _, gotTx := range msg {
if _, exists := hashMap[gotTx.Hash()]; !exists {
t.Fatalf("unexpected tx received: %v", gotTx.Hash())
}
}
default:
t.Fatalf("unexpected %s", pretty.Sdump(msg))
}
}
// TestNewPooledTxs_66 tests whether a node will do a GetPooledTransactions
// request upon receiving a NewPooledTransactionHashes announcement.
func (s *Suite) TestNewPooledTxs_66(t *utesting.T) {
// send the next block to ensure the node is no longer syncing and is able to accept
// txs
s.sendNextBlock66(t)
// generate 50 txs
hashMap, _ := generateTxs(t, s, 50)
// create new pooled tx hashes announcement
hashes := make([]common.Hash, 0)
for _, hash := range hashMap {
hashes = append(hashes, hash)
}
announce := NewPooledTransactionHashes(hashes)
// send announcement
conn := s.setupConnection66(t)
defer conn.Close()
if err := conn.Write(announce); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
// wait for GetPooledTxs request
for {
_, msg := conn.readAndServe66(s.chain, timeout)
switch msg := msg.(type) {
case GetPooledTransactions:
if len(msg) != len(hashes) {
t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg))
}
return
case *NewPooledTransactionHashes, *NewBlock, *NewBlockHashes:
// ignore propagated txs and blocks from old tests
continue
default:
t.Fatalf("unexpected %s", pretty.Sdump(msg))
}
}
}

View File

@ -1,333 +0,0 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package ethtest
import (
"fmt"
"reflect"
"time"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/stretchr/testify/assert"
)
func (c *Conn) statusExchange66(t *utesting.T, chain *Chain) Message {
status := &Status{
ProtocolVersion: uint32(66),
NetworkID: chain.chainConfig.ChainID.Uint64(),
TD: chain.TD(chain.Len()),
Head: chain.blocks[chain.Len()-1].Hash(),
Genesis: chain.blocks[0].Hash(),
ForkID: chain.ForkID(),
}
return c.statusExchange(t, chain, status)
}
func (s *Suite) dial66(t *utesting.T) *Conn {
conn, err := s.dial()
if err != nil {
t.Fatalf("could not dial: %v", err)
}
conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66})
conn.ourHighestProtoVersion = 66
return conn
}
func (c *Conn) write66(req eth.Packet, code int) error {
payload, err := rlp.EncodeToBytes(req)
if err != nil {
return err
}
_, err = c.Conn.Write(uint64(code), payload)
return err
}
func (c *Conn) read66() (uint64, Message) {
code, rawData, _, err := c.Conn.Read()
if err != nil {
return 0, errorf("could not read from connection: %v", err)
}
var msg Message
switch int(code) {
case (Hello{}).Code():
msg = new(Hello)
case (Ping{}).Code():
msg = new(Ping)
case (Pong{}).Code():
msg = new(Pong)
case (Disconnect{}).Code():
msg = new(Disconnect)
case (Status{}).Code():
msg = new(Status)
case (GetBlockHeaders{}).Code():
ethMsg := new(eth.GetBlockHeadersPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket)
case (BlockHeaders{}).Code():
ethMsg := new(eth.BlockHeadersPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket)
case (GetBlockBodies{}).Code():
ethMsg := new(eth.GetBlockBodiesPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket)
case (BlockBodies{}).Code():
ethMsg := new(eth.BlockBodiesPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket)
case (NewBlock{}).Code():
msg = new(NewBlock)
case (NewBlockHashes{}).Code():
msg = new(NewBlockHashes)
case (Transactions{}).Code():
msg = new(Transactions)
case (NewPooledTransactionHashes{}).Code():
msg = new(NewPooledTransactionHashes)
case (GetPooledTransactions{}.Code()):
ethMsg := new(eth.GetPooledTransactionsPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket)
case (PooledTransactions{}.Code()):
ethMsg := new(eth.PooledTransactionsPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket)
default:
msg = errorf("invalid message code: %d", code)
}
if msg != nil {
if err := rlp.DecodeBytes(rawData, msg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return 0, msg
}
return 0, errorf("invalid message: %s", string(rawData))
}
func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message {
for {
id, msg := c.readAndServe66(chain, timeout)
if id == requestID {
return msg
}
}
}
// ReadAndServe serves GetBlockHeaders requests while waiting
// on another message from the node.
func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
start := time.Now()
for time.Since(start) < timeout {
c.SetReadDeadline(time.Now().Add(10 * time.Second))
reqID, msg := c.read66()
switch msg := msg.(type) {
case *Ping:
c.Write(&Pong{})
case *GetBlockHeaders:
headers, err := chain.GetHeaders(*msg)
if err != nil {
return 0, errorf("could not get headers for inbound header request: %v", err)
}
resp := &eth.BlockHeadersPacket66{
RequestId: reqID,
BlockHeadersPacket: eth.BlockHeadersPacket(headers),
}
if err := c.write66(resp, BlockHeaders{}.Code()); err != nil {
return 0, errorf("could not write to connection: %v", err)
}
default:
return reqID, msg
}
}
return 0, errorf("no message received within %v", timeout)
}
func (s *Suite) setupConnection66(t *utesting.T) *Conn {
// create conn
sendConn := s.dial66(t)
sendConn.handshake(t)
sendConn.statusExchange66(t, s.chain)
return sendConn
}
func (s *Suite) testAnnounce66(t *utesting.T, sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) {
// Announce the block.
if err := sendConn.Write(blockAnnouncement); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
s.waitAnnounce66(t, receiveConn, blockAnnouncement)
}
func (s *Suite) waitAnnounce66(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) {
for {
_, msg := conn.readAndServe66(s.chain, timeout)
switch msg := msg.(type) {
case *NewBlock:
t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block))
assert.Equal(t,
blockAnnouncement.Block.Header(), msg.Block.Header(),
"wrong block header in announcement",
)
assert.Equal(t,
blockAnnouncement.TD, msg.TD,
"wrong TD in announcement",
)
return
case *NewBlockHashes:
blockHashes := *msg
t.Logf("received NewBlockHashes message: %s", pretty.Sdump(blockHashes))
assert.Equal(t, blockAnnouncement.Block.Hash(), blockHashes[0].Hash,
"wrong block hash in announcement",
)
return
case *NewPooledTransactionHashes:
// ignore old txs being propagated
continue
default:
t.Fatalf("unexpected: %s", pretty.Sdump(msg))
}
}
}
// waitForBlock66 waits for confirmation from the client that it has
// imported the given block.
func (c *Conn) waitForBlock66(block *types.Block) error {
defer c.SetReadDeadline(time.Time{})
c.SetReadDeadline(time.Now().Add(20 * time.Second))
// note: if the node has not yet imported the block, it will respond
// to the GetBlockHeaders request with an empty BlockHeaders response,
// so the GetBlockHeaders request must be sent again until the BlockHeaders
// response contains the desired header.
for {
req := eth.GetBlockHeadersPacket66{
RequestId: 54,
GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
Origin: eth.HashOrNumber{
Hash: block.Hash(),
},
Amount: 1,
},
}
if err := c.write66(req, GetBlockHeaders{}.Code()); err != nil {
return err
}
reqID, msg := c.read66()
// check message
switch msg := msg.(type) {
case BlockHeaders:
// check request ID
if reqID != req.RequestId {
return fmt.Errorf("request ID mismatch: wanted %d, got %d", req.RequestId, reqID)
}
for _, header := range msg {
if header.Number.Uint64() == block.NumberU64() {
return nil
}
}
time.Sleep(100 * time.Millisecond)
case *NewPooledTransactionHashes:
// ignore old announcements
continue
default:
return fmt.Errorf("invalid message: %s", pretty.Sdump(msg))
}
}
}
func sendSuccessfulTx66(t *utesting.T, s *Suite, tx *types.Transaction) {
sendConn := s.setupConnection66(t)
defer sendConn.Close()
sendSuccessfulTxWithConn(t, s, tx, sendConn)
}
// waitForBlockHeadersResponse66 waits for a BlockHeaders message with the given expected request ID
func (s *Suite) waitForBlockHeadersResponse66(conn *Conn, expectedID uint64) (BlockHeaders, error) {
reqID, msg := conn.readAndServe66(s.chain, timeout)
switch msg := msg.(type) {
case BlockHeaders:
if reqID != expectedID {
return nil, fmt.Errorf("request ID mismatch: wanted %d, got %d", expectedID, reqID)
}
return msg, nil
default:
return nil, fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
}
}
func (s *Suite) getBlockHeaders66(conn *Conn, req eth.Packet, expectedID uint64) (BlockHeaders, error) {
if err := conn.write66(req, GetBlockHeaders{}.Code()); err != nil {
return nil, fmt.Errorf("could not write to connection: %v", err)
}
return s.waitForBlockHeadersResponse66(conn, expectedID)
}
func headersMatch(t *utesting.T, chain *Chain, headers BlockHeaders) bool {
mismatched := 0
for _, header := range headers {
num := header.Number.Uint64()
t.Logf("received header (%d): %s", num, pretty.Sdump(header.Hash()))
if !reflect.DeepEqual(chain.blocks[int(num)].Header(), header) {
mismatched += 1
t.Logf("received wrong header: %v", pretty.Sdump(header))
}
}
return mismatched == 0
}
func (s *Suite) sendNextBlock66(t *utesting.T) {
sendConn, receiveConn := s.setupConnection66(t), s.setupConnection66(t)
defer sendConn.Close()
defer receiveConn.Close()
// create new block announcement
nextBlock := len(s.chain.blocks)
blockAnnouncement := &NewBlock{
Block: s.fullChain.blocks[nextBlock],
TD: s.fullChain.TD(nextBlock + 1),
}
// send announcement and wait for node to request the header
s.testAnnounce66(t, sendConn, receiveConn, blockAnnouncement)
// wait for client to update its chain
if err := receiveConn.waitForBlock66(s.fullChain.blocks[nextBlock]); err != nil {
t.Fatal(err)
}
// update test suite chain
s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock])
}

View File

@ -0,0 +1,635 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package ethtest
import (
"fmt"
"net"
"reflect"
"strings"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/rlpx"
)
var (
pretty = spew.ConfigState{
Indent: " ",
DisableCapacities: true,
DisablePointerAddresses: true,
SortKeys: true,
}
timeout = 20 * time.Second
)
// Is_66 checks if the node supports the eth66 protocol version,
// and if not, exists the test suite
func (s *Suite) Is_66(t *utesting.T) {
conn, err := s.dial66()
if err != nil {
t.Fatalf("dial failed: %v", err)
}
if err := conn.handshake(); err != nil {
t.Fatalf("handshake failed: %v", err)
}
if conn.negotiatedProtoVersion < 66 {
t.Fail()
}
}
// dial attempts to dial the given node and perform a handshake,
// returning the created Conn if successful.
func (s *Suite) dial() (*Conn, error) {
// dial
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
if err != nil {
return nil, err
}
conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())}
// do encHandshake
conn.ourKey, _ = crypto.GenerateKey()
_, err = conn.Handshake(conn.ourKey)
if err != nil {
conn.Close()
return nil, err
}
// set default p2p capabilities
conn.caps = []p2p.Cap{
{Name: "eth", Version: 64},
{Name: "eth", Version: 65},
}
conn.ourHighestProtoVersion = 65
return &conn, nil
}
// dial66 attempts to dial the given node and perform a handshake,
// returning the created Conn with additional eth66 capabilities if
// successful
func (s *Suite) dial66() (*Conn, error) {
conn, err := s.dial()
if err != nil {
return nil, fmt.Errorf("dial failed: %v", err)
}
conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66})
conn.ourHighestProtoVersion = 66
return conn, nil
}
// peer performs both the protocol handshake and the status message
// exchange with the node in order to peer with it.
func (c *Conn) peer(chain *Chain, status *Status) error {
if err := c.handshake(); err != nil {
return fmt.Errorf("handshake failed: %v", err)
}
if _, err := c.statusExchange(chain, status); err != nil {
return fmt.Errorf("status exchange failed: %v", err)
}
return nil
}
// handshake performs a protocol handshake with the node.
func (c *Conn) handshake() error {
defer c.SetDeadline(time.Time{})
c.SetDeadline(time.Now().Add(10 * time.Second))
// write hello to client
pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
ourHandshake := &Hello{
Version: 5,
Caps: c.caps,
ID: pub0,
}
if err := c.Write(ourHandshake); err != nil {
return fmt.Errorf("write to connection failed: %v", err)
}
// read hello from client
switch msg := c.Read().(type) {
case *Hello:
// set snappy if version is at least 5
if msg.Version >= 5 {
c.SetSnappy(true)
}
c.negotiateEthProtocol(msg.Caps)
if c.negotiatedProtoVersion == 0 {
return fmt.Errorf("unexpected eth protocol version")
}
return nil
default:
return fmt.Errorf("bad handshake: %#v", msg)
}
}
// negotiateEthProtocol sets the Conn's eth protocol version to highest
// advertised capability from peer.
func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
var highestEthVersion uint
for _, capability := range caps {
if capability.Name != "eth" {
continue
}
if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
highestEthVersion = capability.Version
}
}
c.negotiatedProtoVersion = highestEthVersion
}
// statusExchange performs a `Status` message exchange with the given node.
func (c *Conn) statusExchange(chain *Chain, status *Status) (Message, error) {
defer c.SetDeadline(time.Time{})
c.SetDeadline(time.Now().Add(20 * time.Second))
// read status message from client
var message Message
loop:
for {
switch msg := c.Read().(type) {
case *Status:
if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want {
return nil, fmt.Errorf("wrong head block in status, want: %#x (block %d) have %#x",
want, chain.blocks[chain.Len()-1].NumberU64(), have)
}
if have, want := msg.TD.Cmp(chain.TD()), 0; have != want {
return nil, fmt.Errorf("wrong TD in status: have %v want %v", have, want)
}
if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) {
return nil, fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want)
}
if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) {
return nil, fmt.Errorf("wrong protocol version: have %v, want %v", have, want)
}
message = msg
break loop
case *Disconnect:
return nil, fmt.Errorf("disconnect received: %v", msg.Reason)
case *Ping:
c.Write(&Pong{}) // TODO (renaynay): in the future, this should be an error
// (PINGs should not be a response upon fresh connection)
default:
return nil, fmt.Errorf("bad status message: %s", pretty.Sdump(msg))
}
}
// make sure eth protocol version is set for negotiation
if c.negotiatedProtoVersion == 0 {
return nil, fmt.Errorf("eth protocol version must be set in Conn")
}
if status == nil {
// default status message
status = &Status{
ProtocolVersion: uint32(c.negotiatedProtoVersion),
NetworkID: chain.chainConfig.ChainID.Uint64(),
TD: chain.TD(),
Head: chain.blocks[chain.Len()-1].Hash(),
Genesis: chain.blocks[0].Hash(),
ForkID: chain.ForkID(),
}
}
if err := c.Write(status); err != nil {
return nil, fmt.Errorf("write to connection failed: %v", err)
}
return message, nil
}
// createSendAndRecvConns creates two connections, one for sending messages to the
// node, and one for receiving messages from the node.
func (s *Suite) createSendAndRecvConns(isEth66 bool) (*Conn, *Conn, error) {
var (
sendConn *Conn
recvConn *Conn
err error
)
if isEth66 {
sendConn, err = s.dial66()
if err != nil {
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
recvConn, err = s.dial66()
if err != nil {
sendConn.Close()
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
} else {
sendConn, err = s.dial()
if err != nil {
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
recvConn, err = s.dial()
if err != nil {
sendConn.Close()
return nil, nil, fmt.Errorf("dial failed: %v", err)
}
}
return sendConn, recvConn, nil
}
// readAndServe serves GetBlockHeaders requests while waiting
// on another message from the node.
func (c *Conn) readAndServe(chain *Chain, timeout time.Duration) Message {
start := time.Now()
for time.Since(start) < timeout {
c.SetReadDeadline(time.Now().Add(5 * time.Second))
switch msg := c.Read().(type) {
case *Ping:
c.Write(&Pong{})
case *GetBlockHeaders:
req := *msg
headers, err := chain.GetHeaders(req)
if err != nil {
return errorf("could not get headers for inbound header request: %v", err)
}
if err := c.Write(headers); err != nil {
return errorf("could not write to connection: %v", err)
}
default:
return msg
}
}
return errorf("no message received within %v", timeout)
}
// readAndServe66 serves eth66 GetBlockHeaders requests while waiting
// on another message from the node.
func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
start := time.Now()
for time.Since(start) < timeout {
c.SetReadDeadline(time.Now().Add(10 * time.Second))
reqID, msg := c.Read66()
switch msg := msg.(type) {
case *Ping:
c.Write(&Pong{})
case *GetBlockHeaders:
headers, err := chain.GetHeaders(*msg)
if err != nil {
return 0, errorf("could not get headers for inbound header request: %v", err)
}
resp := &eth.BlockHeadersPacket66{
RequestId: reqID,
BlockHeadersPacket: eth.BlockHeadersPacket(headers),
}
if err := c.Write66(resp, BlockHeaders{}.Code()); err != nil {
return 0, errorf("could not write to connection: %v", err)
}
default:
return reqID, msg
}
}
return 0, errorf("no message received within %v", timeout)
}
// headersRequest executes the given `GetBlockHeaders` request.
func (c *Conn) headersRequest(request *GetBlockHeaders, chain *Chain, isEth66 bool, reqID uint64) (BlockHeaders, error) {
defer c.SetReadDeadline(time.Time{})
c.SetReadDeadline(time.Now().Add(20 * time.Second))
// if on eth66 connection, perform eth66 GetBlockHeaders request
if isEth66 {
return getBlockHeaders66(chain, c, request, reqID)
}
if err := c.Write(request); err != nil {
return nil, err
}
switch msg := c.readAndServe(chain, timeout).(type) {
case *BlockHeaders:
return *msg, nil
default:
return nil, fmt.Errorf("invalid message: %s", pretty.Sdump(msg))
}
}
// getBlockHeaders66 executes the given `GetBlockHeaders` request over the eth66 protocol.
func getBlockHeaders66(chain *Chain, conn *Conn, request *GetBlockHeaders, id uint64) (BlockHeaders, error) {
// write request
packet := eth.GetBlockHeadersPacket(*request)
req := &eth.GetBlockHeadersPacket66{
RequestId: id,
GetBlockHeadersPacket: &packet,
}
if err := conn.Write66(req, GetBlockHeaders{}.Code()); err != nil {
return nil, fmt.Errorf("could not write to connection: %v", err)
}
// wait for response
msg := conn.waitForResponse(chain, timeout, req.RequestId)
headers, ok := msg.(BlockHeaders)
if !ok {
return nil, fmt.Errorf("unexpected message received: %s", pretty.Sdump(msg))
}
return headers, nil
}
// headersMatch returns whether the received headers match the given request
func headersMatch(expected BlockHeaders, headers BlockHeaders) bool {
return reflect.DeepEqual(expected, headers)
}
// waitForResponse reads from the connection until a response with the expected
// request ID is received.
func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message {
for {
id, msg := c.readAndServe66(chain, timeout)
if id == requestID {
return msg
}
}
}
// sendNextBlock broadcasts the next block in the chain and waits
// for the node to propagate the block and import it into its chain.
func (s *Suite) sendNextBlock(isEth66 bool) error {
// set up sending and receiving connections
sendConn, recvConn, err := s.createSendAndRecvConns(isEth66)
if err != nil {
return err
}
defer sendConn.Close()
defer recvConn.Close()
if err = sendConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
if err = recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
// create new block announcement
nextBlock := s.fullChain.blocks[s.chain.Len()]
blockAnnouncement := &NewBlock{
Block: nextBlock,
TD: s.fullChain.TotalDifficultyAt(s.chain.Len()),
}
// send announcement and wait for node to request the header
if err = s.testAnnounce(sendConn, recvConn, blockAnnouncement); err != nil {
return fmt.Errorf("failed to announce block: %v", err)
}
// wait for client to update its chain
if err = s.waitForBlockImport(recvConn, nextBlock, isEth66); err != nil {
return fmt.Errorf("failed to receive confirmation of block import: %v", err)
}
// update test suite chain
s.chain.blocks = append(s.chain.blocks, nextBlock)
return nil
}
// testAnnounce writes a block announcement to the node and waits for the node
// to propagate it.
func (s *Suite) testAnnounce(sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) error {
if err := sendConn.Write(blockAnnouncement); err != nil {
return fmt.Errorf("could not write to connection: %v", err)
}
return s.waitAnnounce(receiveConn, blockAnnouncement)
}
// waitAnnounce waits for a NewBlock or NewBlockHashes announcement from the node.
func (s *Suite) waitAnnounce(conn *Conn, blockAnnouncement *NewBlock) error {
for {
switch msg := conn.readAndServe(s.chain, timeout).(type) {
case *NewBlock:
if !reflect.DeepEqual(blockAnnouncement.Block.Header(), msg.Block.Header()) {
return fmt.Errorf("wrong header in block announcement: \nexpected %v "+
"\ngot %v", blockAnnouncement.Block.Header(), msg.Block.Header())
}
if !reflect.DeepEqual(blockAnnouncement.TD, msg.TD) {
return fmt.Errorf("wrong TD in announcement: expected %v, got %v", blockAnnouncement.TD, msg.TD)
}
return nil
case *NewBlockHashes:
hashes := *msg
if blockAnnouncement.Block.Hash() != hashes[0].Hash {
return fmt.Errorf("wrong block hash in announcement: expected %v, got %v", blockAnnouncement.Block.Hash(), hashes[0].Hash)
}
return nil
case *NewPooledTransactionHashes:
// ignore tx announcements from previous tests
continue
default:
return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
}
}
}
func (s *Suite) waitForBlockImport(conn *Conn, block *types.Block, isEth66 bool) error {
defer conn.SetReadDeadline(time.Time{})
conn.SetReadDeadline(time.Now().Add(20 * time.Second))
// create request
req := &GetBlockHeaders{
Origin: eth.HashOrNumber{
Hash: block.Hash(),
},
Amount: 1,
}
// loop until BlockHeaders response contains desired block, confirming the
// node imported the block
for {
var (
headers BlockHeaders
err error
)
if isEth66 {
requestID := uint64(54)
headers, err = conn.headersRequest(req, s.chain, eth66, requestID)
} else {
headers, err = conn.headersRequest(req, s.chain, eth65, 0)
}
if err != nil {
return fmt.Errorf("GetBlockHeader request failed: %v", err)
}
// if headers response is empty, node hasn't imported block yet, try again
if len(headers) == 0 {
time.Sleep(100 * time.Millisecond)
continue
}
if !reflect.DeepEqual(block.Header(), headers[0]) {
return fmt.Errorf("wrong header returned: wanted %v, got %v", block.Header(), headers[0])
}
return nil
}
}
func (s *Suite) oldAnnounce(isEth66 bool) error {
sendConn, receiveConn, err := s.createSendAndRecvConns(isEth66)
if err != nil {
return err
}
defer sendConn.Close()
defer receiveConn.Close()
if err := sendConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
if err := receiveConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
// create old block announcement
oldBlockAnnounce := &NewBlock{
Block: s.chain.blocks[len(s.chain.blocks)/2],
TD: s.chain.blocks[len(s.chain.blocks)/2].Difficulty(),
}
if err := sendConn.Write(oldBlockAnnounce); err != nil {
return fmt.Errorf("could not write to connection: %v", err)
}
// wait to see if the announcement is propagated
switch msg := receiveConn.readAndServe(s.chain, time.Second*8).(type) {
case *NewBlock:
block := *msg
if block.Block.Hash() == oldBlockAnnounce.Block.Hash() {
return fmt.Errorf("unexpected: block propagated: %s", pretty.Sdump(msg))
}
case *NewBlockHashes:
hashes := *msg
for _, hash := range hashes {
if hash.Hash == oldBlockAnnounce.Block.Hash() {
return fmt.Errorf("unexpected: block announced: %s", pretty.Sdump(msg))
}
}
case *Error:
errMsg := *msg
// check to make sure error is timeout (propagation didn't come through == test successful)
if !strings.Contains(errMsg.String(), "timeout") {
return fmt.Errorf("unexpected error: %v", pretty.Sdump(msg))
}
default:
return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
}
return nil
}
func (s *Suite) maliciousHandshakes(t *utesting.T, isEth66 bool) error {
var (
conn *Conn
err error
)
if isEth66 {
conn, err = s.dial66()
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
} else {
conn, err = s.dial()
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
}
defer conn.Close()
// write hello to client
pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:]
handshakes := []*Hello{
{
Version: 5,
Caps: []p2p.Cap{
{Name: largeString(2), Version: 64},
},
ID: pub0,
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: "eth", Version: 64},
{Name: "eth", Version: 65},
},
ID: append(pub0, byte(0)),
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: "eth", Version: 64},
{Name: "eth", Version: 65},
},
ID: append(pub0, pub0...),
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: "eth", Version: 64},
{Name: "eth", Version: 65},
},
ID: largeBuffer(2),
},
{
Version: 5,
Caps: []p2p.Cap{
{Name: largeString(2), Version: 64},
},
ID: largeBuffer(2),
},
}
for i, handshake := range handshakes {
t.Logf("Testing malicious handshake %v\n", i)
if err := conn.Write(handshake); err != nil {
return fmt.Errorf("could not write to connection: %v", err)
}
// check that the peer disconnected
for i := 0; i < 2; i++ {
switch msg := conn.readAndServe(s.chain, 20*time.Second).(type) {
case *Disconnect:
case *Error:
case *Hello:
// Discard one hello as Hello's are sent concurrently
continue
default:
return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
}
}
// dial for the next round
if isEth66 {
conn, err = s.dial66()
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
} else {
conn, err = s.dial()
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
}
}
return nil
}
func (s *Suite) maliciousStatus(conn *Conn) error {
if err := conn.handshake(); err != nil {
return fmt.Errorf("handshake failed: %v", err)
}
status := &Status{
ProtocolVersion: uint32(conn.negotiatedProtoVersion),
NetworkID: s.chain.chainConfig.ChainID.Uint64(),
TD: largeNumber(2),
Head: s.chain.blocks[s.chain.Len()-1].Hash(),
Genesis: s.chain.blocks[0].Hash(),
ForkID: s.chain.ForkID(),
}
// get status
msg, err := conn.statusExchange(s.chain, status)
if err != nil {
return fmt.Errorf("status exchange failed: %v", err)
}
switch msg := msg.(type) {
case *Status:
default:
return fmt.Errorf("expected status, got: %#v ", msg)
}
// wait for disconnect
switch msg := conn.readAndServe(s.chain, timeout).(type) {
case *Disconnect:
return nil
case *Error:
return nil
default:
return fmt.Errorf("expected disconnect, got: %s", pretty.Sdump(msg))
}
}

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@
package ethtest package ethtest
import ( import (
"fmt"
"math/big" "math/big"
"strings" "strings"
"time" "time"
@ -31,58 +32,171 @@ import (
//var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7") //var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7")
var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
func sendSuccessfulTx(t *utesting.T, s *Suite, tx *types.Transaction) { func (s *Suite) sendSuccessfulTxs(t *utesting.T, isEth66 bool) error {
sendConn := s.setupConnection(t) tests := []*types.Transaction{
defer sendConn.Close() getNextTxFromChain(s),
sendSuccessfulTxWithConn(t, s, tx, sendConn) unknownTx(s),
}
for i, tx := range tests {
if tx == nil {
return fmt.Errorf("could not find tx to send")
}
t.Logf("Testing tx propagation %d: sending tx %v %v %v\n", i, tx.Hash().String(), tx.GasPrice(), tx.Gas())
// get previous tx if exists for reference in case of old tx propagation
var prevTx *types.Transaction
if i != 0 {
prevTx = tests[i-1]
}
// write tx to connection
if err := sendSuccessfulTx(s, tx, prevTx, isEth66); err != nil {
return fmt.Errorf("send successful tx test failed: %v", err)
}
}
return nil
} }
func sendSuccessfulTxWithConn(t *utesting.T, s *Suite, tx *types.Transaction, sendConn *Conn) { func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction, isEth66 bool) error {
t.Logf("sending tx: %v %v %v\n", tx.Hash().String(), tx.GasPrice(), tx.Gas()) sendConn, recvConn, err := s.createSendAndRecvConns(isEth66)
if err != nil {
return err
}
defer sendConn.Close()
defer recvConn.Close()
if err = sendConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
// Send the transaction // Send the transaction
if err := sendConn.Write(&Transactions{tx}); err != nil { if err = sendConn.Write(&Transactions{tx}); err != nil {
t.Fatal(err) return fmt.Errorf("failed to write to connection: %v", err)
}
// peer receiving connection to node
if err = recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
} }
// update last nonce seen // update last nonce seen
nonce = tx.Nonce() nonce = tx.Nonce()
recvConn := s.setupConnection(t)
// Wait for the transaction announcement // Wait for the transaction announcement
switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) { for {
switch msg := recvConn.readAndServe(s.chain, timeout).(type) {
case *Transactions: case *Transactions:
recTxs := *msg recTxs := *msg
// if you receive an old tx propagation, read from connection again
if len(recTxs) == 1 && prevTx != nil {
if recTxs[0] == prevTx {
continue
}
}
for _, gotTx := range recTxs { for _, gotTx := range recTxs {
if gotTx.Hash() == tx.Hash() { if gotTx.Hash() == tx.Hash() {
// Ok // Ok
return return nil
} }
} }
t.Fatalf("missing transaction: got %v missing %v", recTxs, tx.Hash()) return fmt.Errorf("missing transaction: got %v missing %v", recTxs, tx.Hash())
case *NewPooledTransactionHashes: case *NewPooledTransactionHashes:
txHashes := *msg txHashes := *msg
// if you receive an old tx propagation, read from connection again
if len(txHashes) == 1 && prevTx != nil {
if txHashes[0] == prevTx.Hash() {
continue
}
}
for _, gotHash := range txHashes { for _, gotHash := range txHashes {
if gotHash == tx.Hash() { if gotHash == tx.Hash() {
return // Ok
return nil
} }
} }
t.Fatalf("missing transaction announcement: got %v missing %v", txHashes, tx.Hash()) return fmt.Errorf("missing transaction announcement: got %v missing %v", txHashes, tx.Hash())
default: default:
t.Fatalf("unexpected message in sendSuccessfulTx: %s", pretty.Sdump(msg)) return fmt.Errorf("unexpected message in sendSuccessfulTx: %s", pretty.Sdump(msg))
} }
}
}
func (s *Suite) sendMaliciousTxs(t *utesting.T, isEth66 bool) error {
badTxs := []*types.Transaction{
getOldTxFromChain(s),
invalidNonceTx(s),
hugeAmount(s),
hugeGasPrice(s),
hugeData(s),
}
// setup receiving connection before sending malicious txs
var (
recvConn *Conn
err error
)
if isEth66 {
recvConn, err = s.dial66()
} else {
recvConn, err = s.dial()
}
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
defer recvConn.Close()
if err = recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
for i, tx := range badTxs {
t.Logf("Testing malicious tx propagation: %v\n", i)
if err = sendMaliciousTx(s, tx, isEth66); err != nil {
return fmt.Errorf("malicious tx test failed:\ntx: %v\nerror: %v", tx, err)
}
}
// check to make sure bad txs aren't propagated
return checkMaliciousTxPropagation(s, badTxs, recvConn)
}
func sendMaliciousTx(s *Suite, tx *types.Transaction, isEth66 bool) error {
// setup connection
var (
conn *Conn
err error
)
if isEth66 {
conn, err = s.dial66()
} else {
conn, err = s.dial()
}
if err != nil {
return fmt.Errorf("dial failed: %v", err)
}
defer conn.Close()
if err = conn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
// write malicious tx
if err = conn.Write(&Transactions{tx}); err != nil {
return fmt.Errorf("failed to write to connection: %v", err)
}
return nil
} }
var nonce = uint64(99) var nonce = uint64(99)
func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*types.Transaction) { // sendMultipleSuccessfulTxs sends the given transactions to the node and
// expects the node to accept and propagate them.
func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, txs []*types.Transaction) error {
txMsg := Transactions(txs) txMsg := Transactions(txs)
t.Logf("sending %d txs\n", len(txs)) t.Logf("sending %d txs\n", len(txs))
recvConn := s.setupConnection(t) sendConn, recvConn, err := s.createSendAndRecvConns(true)
if err != nil {
return err
}
defer sendConn.Close()
defer recvConn.Close() defer recvConn.Close()
if err = sendConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
if err = recvConn.peer(s.chain, nil); err != nil {
return fmt.Errorf("peering failed: %v", err)
}
// Send the transactions // Send the transactions
if err := sendConn.Write(&txMsg); err != nil { if err = sendConn.Write(&txMsg); err != nil {
t.Fatal(err) return fmt.Errorf("failed to write message to connection: %v", err)
} }
// update nonce // update nonce
nonce = txs[len(txs)-1].Nonce() nonce = txs[len(txs)-1].Nonce()
@ -90,7 +204,7 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t
recvHashes := make([]common.Hash, 0) recvHashes := make([]common.Hash, 0)
// all txs should be announced within 3 announcements // all txs should be announced within 3 announcements
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) { switch msg := recvConn.readAndServe(s.chain, timeout).(type) {
case *Transactions: case *Transactions:
for _, tx := range *msg { for _, tx := range *msg {
recvHashes = append(recvHashes, tx.Hash()) recvHashes = append(recvHashes, tx.Hash())
@ -99,7 +213,7 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t
recvHashes = append(recvHashes, *msg...) recvHashes = append(recvHashes, *msg...)
default: default:
if !strings.Contains(pretty.Sdump(msg), "i/o timeout") { if !strings.Contains(pretty.Sdump(msg), "i/o timeout") {
t.Fatalf("unexpected message while waiting to receive txs: %s", pretty.Sdump(msg)) return fmt.Errorf("unexpected message while waiting to receive txs: %s", pretty.Sdump(msg))
} }
} }
// break once all 2000 txs have been received // break once all 2000 txs have been received
@ -112,7 +226,7 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t
continue continue
} else { } else {
t.Logf("successfully received all %d txs", len(txs)) t.Logf("successfully received all %d txs", len(txs))
return return nil
} }
} }
} }
@ -121,13 +235,15 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t
for _, missing := range missingTxs { for _, missing := range missingTxs {
t.Logf("missing tx: %v", missing.Hash()) t.Logf("missing tx: %v", missing.Hash())
} }
t.Fatalf("missing %d txs", len(missingTxs)) return fmt.Errorf("missing %d txs", len(missingTxs))
} }
return nil
} }
func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, recvConn *Conn) { // checkMaliciousTxPropagation checks whether the given malicious transactions were
// Wait for another transaction announcement // propagated by the node.
switch msg := recvConn.ReadAndServe(s.chain, time.Second*8).(type) { func checkMaliciousTxPropagation(s *Suite, txs []*types.Transaction, conn *Conn) error {
switch msg := conn.readAndServe(s.chain, time.Second*8).(type) {
case *Transactions: case *Transactions:
// check to see if any of the failing txs were in the announcement // check to see if any of the failing txs were in the announcement
recvTxs := make([]common.Hash, len(*msg)) recvTxs := make([]common.Hash, len(*msg))
@ -136,25 +252,20 @@ func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, rec
} }
badTxs, _ := compareReceivedTxs(recvTxs, txs) badTxs, _ := compareReceivedTxs(recvTxs, txs)
if len(badTxs) > 0 { if len(badTxs) > 0 {
for _, tx := range badTxs { return fmt.Errorf("received %d bad txs: \n%v", len(badTxs), badTxs)
t.Logf("received bad tx: %v", tx)
}
t.Fatalf("received %d bad txs", len(badTxs))
} }
case *NewPooledTransactionHashes: case *NewPooledTransactionHashes:
badTxs, _ := compareReceivedTxs(*msg, txs) badTxs, _ := compareReceivedTxs(*msg, txs)
if len(badTxs) > 0 { if len(badTxs) > 0 {
for _, tx := range badTxs { return fmt.Errorf("received %d bad txs: \n%v", len(badTxs), badTxs)
t.Logf("received bad tx: %v", tx)
}
t.Fatalf("received %d bad txs", len(badTxs))
} }
case *Error: case *Error:
// Transaction should not be announced -> wait for timeout // Transaction should not be announced -> wait for timeout
return return nil
default: default:
t.Fatalf("unexpected message in sendFailingTx: %s", pretty.Sdump(msg)) return fmt.Errorf("unexpected message in sendFailingTx: %s", pretty.Sdump(msg))
} }
return nil
} }
// compareReceivedTxs compares the received set of txs against the given set of txs, // compareReceivedTxs compares the received set of txs against the given set of txs,
@ -180,118 +291,129 @@ func compareReceivedTxs(recvTxs []common.Hash, txs []*types.Transaction) (presen
return present, missing return present, missing
} }
func unknownTx(t *utesting.T, s *Suite) *types.Transaction { func unknownTx(s *Suite) *types.Transaction {
tx := getNextTxFromChain(t, s) tx := getNextTxFromChain(s)
if tx == nil {
return nil
}
var to common.Address var to common.Address
if tx.To() != nil { if tx.To() != nil {
to = *tx.To() to = *tx.To()
} }
txNew := types.NewTransaction(tx.Nonce()+1, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data()) txNew := types.NewTransaction(tx.Nonce()+1, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data())
return signWithFaucet(t, s.chain.chainConfig, txNew) return signWithFaucet(s.chain.chainConfig, txNew)
} }
func getNextTxFromChain(t *utesting.T, s *Suite) *types.Transaction { func getNextTxFromChain(s *Suite) *types.Transaction {
// Get a new transaction // Get a new transaction
var tx *types.Transaction
for _, blocks := range s.fullChain.blocks[s.chain.Len():] { for _, blocks := range s.fullChain.blocks[s.chain.Len():] {
txs := blocks.Transactions() txs := blocks.Transactions()
if txs.Len() != 0 { if txs.Len() != 0 {
tx = txs[0] return txs[0]
break
} }
} }
if tx == nil { return nil
t.Fatal("could not find transaction")
}
return tx
} }
func generateTxs(t *utesting.T, s *Suite, numTxs int) (map[common.Hash]common.Hash, []*types.Transaction) { func generateTxs(s *Suite, numTxs int) (map[common.Hash]common.Hash, []*types.Transaction, error) {
txHashMap := make(map[common.Hash]common.Hash, numTxs) txHashMap := make(map[common.Hash]common.Hash, numTxs)
txs := make([]*types.Transaction, numTxs) txs := make([]*types.Transaction, numTxs)
nextTx := getNextTxFromChain(t, s) nextTx := getNextTxFromChain(s)
if nextTx == nil {
return nil, nil, fmt.Errorf("failed to get the next transaction")
}
gas := nextTx.Gas() gas := nextTx.Gas()
nonce = nonce + 1 nonce = nonce + 1
// generate txs // generate txs
for i := 0; i < numTxs; i++ { for i := 0; i < numTxs; i++ {
tx := generateTx(t, s.chain.chainConfig, nonce, gas) tx := generateTx(s.chain.chainConfig, nonce, gas)
if tx == nil {
return nil, nil, fmt.Errorf("failed to get the next transaction")
}
txHashMap[tx.Hash()] = tx.Hash() txHashMap[tx.Hash()] = tx.Hash()
txs[i] = tx txs[i] = tx
nonce = nonce + 1 nonce = nonce + 1
} }
return txHashMap, txs return txHashMap, txs, nil
} }
func generateTx(t *utesting.T, chainConfig *params.ChainConfig, nonce uint64, gas uint64) *types.Transaction { func generateTx(chainConfig *params.ChainConfig, nonce uint64, gas uint64) *types.Transaction {
var to common.Address var to common.Address
tx := types.NewTransaction(nonce, to, big.NewInt(1), gas, big.NewInt(1), []byte{}) tx := types.NewTransaction(nonce, to, big.NewInt(1), gas, big.NewInt(1), []byte{})
return signWithFaucet(t, chainConfig, tx) return signWithFaucet(chainConfig, tx)
} }
func getOldTxFromChain(t *utesting.T, s *Suite) *types.Transaction { func getOldTxFromChain(s *Suite) *types.Transaction {
var tx *types.Transaction
for _, blocks := range s.fullChain.blocks[:s.chain.Len()-1] { for _, blocks := range s.fullChain.blocks[:s.chain.Len()-1] {
txs := blocks.Transactions() txs := blocks.Transactions()
if txs.Len() != 0 { if txs.Len() != 0 {
tx = txs[0] return txs[0]
break
} }
} }
if tx == nil { return nil
t.Fatal("could not find transaction")
}
return tx
} }
func invalidNonceTx(t *utesting.T, s *Suite) *types.Transaction { func invalidNonceTx(s *Suite) *types.Transaction {
tx := getNextTxFromChain(t, s) tx := getNextTxFromChain(s)
if tx == nil {
return nil
}
var to common.Address var to common.Address
if tx.To() != nil { if tx.To() != nil {
to = *tx.To() to = *tx.To()
} }
txNew := types.NewTransaction(tx.Nonce()-2, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data()) txNew := types.NewTransaction(tx.Nonce()-2, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data())
return signWithFaucet(t, s.chain.chainConfig, txNew) return signWithFaucet(s.chain.chainConfig, txNew)
} }
func hugeAmount(t *utesting.T, s *Suite) *types.Transaction { func hugeAmount(s *Suite) *types.Transaction {
tx := getNextTxFromChain(t, s) tx := getNextTxFromChain(s)
if tx == nil {
return nil
}
amount := largeNumber(2) amount := largeNumber(2)
var to common.Address var to common.Address
if tx.To() != nil { if tx.To() != nil {
to = *tx.To() to = *tx.To()
} }
txNew := types.NewTransaction(tx.Nonce(), to, amount, tx.Gas(), tx.GasPrice(), tx.Data()) txNew := types.NewTransaction(tx.Nonce(), to, amount, tx.Gas(), tx.GasPrice(), tx.Data())
return signWithFaucet(t, s.chain.chainConfig, txNew) return signWithFaucet(s.chain.chainConfig, txNew)
} }
func hugeGasPrice(t *utesting.T, s *Suite) *types.Transaction { func hugeGasPrice(s *Suite) *types.Transaction {
tx := getNextTxFromChain(t, s) tx := getNextTxFromChain(s)
if tx == nil {
return nil
}
gasPrice := largeNumber(2) gasPrice := largeNumber(2)
var to common.Address var to common.Address
if tx.To() != nil { if tx.To() != nil {
to = *tx.To() to = *tx.To()
} }
txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), gasPrice, tx.Data()) txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), gasPrice, tx.Data())
return signWithFaucet(t, s.chain.chainConfig, txNew) return signWithFaucet(s.chain.chainConfig, txNew)
} }
func hugeData(t *utesting.T, s *Suite) *types.Transaction { func hugeData(s *Suite) *types.Transaction {
tx := getNextTxFromChain(t, s) tx := getNextTxFromChain(s)
if tx == nil {
return nil
}
var to common.Address var to common.Address
if tx.To() != nil { if tx.To() != nil {
to = *tx.To() to = *tx.To()
} }
txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), tx.GasPrice(), largeBuffer(2)) txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), tx.GasPrice(), largeBuffer(2))
return signWithFaucet(t, s.chain.chainConfig, txNew) return signWithFaucet(s.chain.chainConfig, txNew)
} }
func signWithFaucet(t *utesting.T, chainConfig *params.ChainConfig, tx *types.Transaction) *types.Transaction { func signWithFaucet(chainConfig *params.ChainConfig, tx *types.Transaction) *types.Transaction {
signer := types.LatestSigner(chainConfig) signer := types.LatestSigner(chainConfig)
signedTx, err := types.SignTx(tx, signer, faucetKey) signedTx, err := types.SignTx(tx, signer, faucetKey)
if err != nil { if err != nil {
t.Fatalf("could not sign tx: %v\n", err) return nil
} }
return signedTx return signedTx
} }

View File

@ -19,13 +19,8 @@ package ethtest
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"reflect"
"time"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/rlpx" "github.com/ethereum/go-ethereum/p2p/rlpx"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
@ -137,6 +132,7 @@ type Conn struct {
caps []p2p.Cap caps []p2p.Cap
} }
// Read reads an eth packet from the connection.
func (c *Conn) Read() Message { func (c *Conn) Read() Message {
code, rawData, _, err := c.Conn.Read() code, rawData, _, err := c.Conn.Read()
if err != nil { if err != nil {
@ -185,32 +181,83 @@ func (c *Conn) Read() Message {
return msg return msg
} }
// ReadAndServe serves GetBlockHeaders requests while waiting // Read66 reads an eth66 packet from the connection.
// on another message from the node. func (c *Conn) Read66() (uint64, Message) {
func (c *Conn) ReadAndServe(chain *Chain, timeout time.Duration) Message { code, rawData, _, err := c.Conn.Read()
start := time.Now()
for time.Since(start) < timeout {
c.SetReadDeadline(time.Now().Add(5 * time.Second))
switch msg := c.Read().(type) {
case *Ping:
c.Write(&Pong{})
case *GetBlockHeaders:
req := *msg
headers, err := chain.GetHeaders(req)
if err != nil { if err != nil {
return errorf("could not get headers for inbound header request: %v", err) return 0, errorf("could not read from connection: %v", err)
} }
if err := c.Write(headers); err != nil { var msg Message
return errorf("could not write to connection: %v", err) switch int(code) {
case (Hello{}).Code():
msg = new(Hello)
case (Ping{}).Code():
msg = new(Ping)
case (Pong{}).Code():
msg = new(Pong)
case (Disconnect{}).Code():
msg = new(Disconnect)
case (Status{}).Code():
msg = new(Status)
case (GetBlockHeaders{}).Code():
ethMsg := new(eth.GetBlockHeadersPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
} }
return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket)
case (BlockHeaders{}).Code():
ethMsg := new(eth.BlockHeadersPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket)
case (GetBlockBodies{}).Code():
ethMsg := new(eth.GetBlockBodiesPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket)
case (BlockBodies{}).Code():
ethMsg := new(eth.BlockBodiesPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket)
case (NewBlock{}).Code():
msg = new(NewBlock)
case (NewBlockHashes{}).Code():
msg = new(NewBlockHashes)
case (Transactions{}).Code():
msg = new(Transactions)
case (NewPooledTransactionHashes{}).Code():
msg = new(NewPooledTransactionHashes)
case (GetPooledTransactions{}.Code()):
ethMsg := new(eth.GetPooledTransactionsPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket)
case (PooledTransactions{}.Code()):
ethMsg := new(eth.PooledTransactionsPacket66)
if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
}
return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket)
default: default:
return msg msg = errorf("invalid message code: %d", code)
} }
if msg != nil {
if err := rlp.DecodeBytes(rawData, msg); err != nil {
return 0, errorf("could not rlp decode message: %v", err)
} }
return errorf("no message received within %v", timeout) return 0, msg
}
return 0, errorf("invalid message: %s", string(rawData))
} }
// Write writes a eth packet to the connection.
func (c *Conn) Write(msg Message) error { func (c *Conn) Write(msg Message) error {
// check if message is eth protocol message // check if message is eth protocol message
var ( var (
@ -225,135 +272,12 @@ func (c *Conn) Write(msg Message) error {
return err return err
} }
// handshake checks to make sure a `HELLO` is received. // Write66 writes an eth66 packet to the connection.
func (c *Conn) handshake(t *utesting.T) Message { func (c *Conn) Write66(req eth.Packet, code int) error {
defer c.SetDeadline(time.Time{}) payload, err := rlp.EncodeToBytes(req)
c.SetDeadline(time.Now().Add(10 * time.Second)) if err != nil {
// write hello to client
pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
ourHandshake := &Hello{
Version: 5,
Caps: c.caps,
ID: pub0,
}
if err := c.Write(ourHandshake); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
// read hello from client
switch msg := c.Read().(type) {
case *Hello:
// set snappy if version is at least 5
if msg.Version >= 5 {
c.SetSnappy(true)
}
c.negotiateEthProtocol(msg.Caps)
if c.negotiatedProtoVersion == 0 {
t.Fatalf("unexpected eth protocol version")
}
return msg
default:
t.Fatalf("bad handshake: %#v", msg)
return nil
}
}
// negotiateEthProtocol sets the Conn's eth protocol version
// to highest advertised capability from peer
func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
var highestEthVersion uint
for _, capability := range caps {
if capability.Name != "eth" {
continue
}
if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
highestEthVersion = capability.Version
}
}
c.negotiatedProtoVersion = highestEthVersion
}
// statusExchange performs a `Status` message exchange with the given
// node.
func (c *Conn) statusExchange(t *utesting.T, chain *Chain, status *Status) Message {
defer c.SetDeadline(time.Time{})
c.SetDeadline(time.Now().Add(20 * time.Second))
// read status message from client
var message Message
loop:
for {
switch msg := c.Read().(type) {
case *Status:
if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want {
t.Fatalf("wrong head block in status, want: %#x (block %d) have %#x",
want, chain.blocks[chain.Len()-1].NumberU64(), have)
}
if have, want := msg.TD.Cmp(chain.TD(chain.Len())), 0; have != want {
t.Fatalf("wrong TD in status: have %v want %v", have, want)
}
if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) {
t.Fatalf("wrong fork ID in status: have %v, want %v", have, want)
}
message = msg
break loop
case *Disconnect:
t.Fatalf("disconnect received: %v", msg.Reason)
case *Ping:
c.Write(&Pong{}) // TODO (renaynay): in the future, this should be an error
// (PINGs should not be a response upon fresh connection)
default:
t.Fatalf("bad status message: %s", pretty.Sdump(msg))
}
}
// make sure eth protocol version is set for negotiation
if c.negotiatedProtoVersion == 0 {
t.Fatalf("eth protocol version must be set in Conn")
}
if status == nil {
// write status message to client
status = &Status{
ProtocolVersion: uint32(c.negotiatedProtoVersion),
NetworkID: chain.chainConfig.ChainID.Uint64(),
TD: chain.TD(chain.Len()),
Head: chain.blocks[chain.Len()-1].Hash(),
Genesis: chain.blocks[0].Hash(),
ForkID: chain.ForkID(),
}
}
if err := c.Write(status); err != nil {
t.Fatalf("could not write to connection: %v", err)
}
return message
}
// waitForBlock waits for confirmation from the client that it has
// imported the given block.
func (c *Conn) waitForBlock(block *types.Block) error {
defer c.SetReadDeadline(time.Time{})
c.SetReadDeadline(time.Now().Add(20 * time.Second))
// note: if the node has not yet imported the block, it will respond
// to the GetBlockHeaders request with an empty BlockHeaders response,
// so the GetBlockHeaders request must be sent again until the BlockHeaders
// response contains the desired header.
for {
req := &GetBlockHeaders{Origin: eth.HashOrNumber{Hash: block.Hash()}, Amount: 1}
if err := c.Write(req); err != nil {
return err return err
} }
switch msg := c.Read().(type) { _, err = c.Conn.Write(uint64(code), payload)
case *BlockHeaders: return err
for _, header := range *msg {
if header.Number.Uint64() == block.NumberU64() {
return nil
}
}
time.Sleep(100 * time.Millisecond)
default:
return fmt.Errorf("invalid message: %s", pretty.Sdump(msg))
}
}
} }