362 lines
11 KiB
Go
362 lines
11 KiB
Go
// Copyright 2023 The go-ethereum Authors
|
|
// This file is part of go-ethereum.
|
|
//
|
|
// go-ethereum is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// go-ethereum is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
package ethtest
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/davecgh/go-spew/spew"
|
|
"github.com/ethereum/go-ethereum/crypto"
|
|
"github.com/ethereum/go-ethereum/eth/protocols/eth"
|
|
"github.com/ethereum/go-ethereum/eth/protocols/snap"
|
|
"github.com/ethereum/go-ethereum/p2p"
|
|
"github.com/ethereum/go-ethereum/p2p/rlpx"
|
|
"github.com/ethereum/go-ethereum/rlp"
|
|
)
|
|
|
|
var (
|
|
pretty = spew.ConfigState{
|
|
Indent: " ",
|
|
DisableCapacities: true,
|
|
DisablePointerAddresses: true,
|
|
SortKeys: true,
|
|
}
|
|
timeout = 2 * time.Second
|
|
)
|
|
|
|
// dial attempts to dial the given node and perform a handshake, returning the
|
|
// created Conn if successful.
|
|
func (s *Suite) dial() (*Conn, error) {
|
|
key, _ := crypto.GenerateKey()
|
|
return s.dialAs(key)
|
|
}
|
|
|
|
// dialAs attempts to dial a given node and perform a handshake using the given
|
|
// private key.
|
|
func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
|
|
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())}
|
|
conn.ourKey = key
|
|
_, err = conn.Handshake(conn.ourKey)
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
conn.caps = []p2p.Cap{
|
|
{Name: "eth", Version: 67},
|
|
{Name: "eth", Version: 68},
|
|
}
|
|
conn.ourHighestProtoVersion = 68
|
|
return &conn, nil
|
|
}
|
|
|
|
// dialSnap creates a connection with snap/1 capability.
|
|
func (s *Suite) dialSnap() (*Conn, error) {
|
|
conn, err := s.dial()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial failed: %v", err)
|
|
}
|
|
conn.caps = append(conn.caps, p2p.Cap{Name: "snap", Version: 1})
|
|
conn.ourHighestSnapProtoVersion = 1
|
|
return conn, nil
|
|
}
|
|
|
|
// Conn represents an individual connection with a peer
|
|
type Conn struct {
|
|
*rlpx.Conn
|
|
ourKey *ecdsa.PrivateKey
|
|
negotiatedProtoVersion uint
|
|
negotiatedSnapProtoVersion uint
|
|
ourHighestProtoVersion uint
|
|
ourHighestSnapProtoVersion uint
|
|
caps []p2p.Cap
|
|
}
|
|
|
|
// Read reads a packet from the connection.
|
|
func (c *Conn) Read() (uint64, []byte, error) {
|
|
c.SetReadDeadline(time.Now().Add(timeout))
|
|
code, data, _, err := c.Conn.Read()
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
return code, data, nil
|
|
}
|
|
|
|
// ReadMsg attempts to read a devp2p message with a specific code.
|
|
func (c *Conn) ReadMsg(proto Proto, code uint64, msg any) error {
|
|
c.SetReadDeadline(time.Now().Add(timeout))
|
|
for {
|
|
got, data, err := c.Read()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if protoOffset(proto)+code == got {
|
|
return rlp.DecodeBytes(data, msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Write writes a eth packet to the connection.
|
|
func (c *Conn) Write(proto Proto, code uint64, msg any) error {
|
|
c.SetWriteDeadline(time.Now().Add(timeout))
|
|
payload, err := rlp.EncodeToBytes(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = c.Conn.Write(protoOffset(proto)+code, payload)
|
|
return err
|
|
}
|
|
|
|
// ReadEth reads an Eth sub-protocol wire message.
|
|
func (c *Conn) ReadEth() (any, error) {
|
|
c.SetReadDeadline(time.Now().Add(timeout))
|
|
for {
|
|
code, data, _, err := c.Conn.Read()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if code == pingMsg {
|
|
c.Write(baseProto, pongMsg, []byte{})
|
|
continue
|
|
}
|
|
if getProto(code) != ethProto {
|
|
// Read until eth message.
|
|
continue
|
|
}
|
|
code -= baseProtoLen
|
|
|
|
var msg any
|
|
switch int(code) {
|
|
case eth.StatusMsg:
|
|
msg = new(eth.StatusPacket)
|
|
case eth.GetBlockHeadersMsg:
|
|
msg = new(eth.GetBlockHeadersPacket)
|
|
case eth.BlockHeadersMsg:
|
|
msg = new(eth.BlockHeadersPacket)
|
|
case eth.GetBlockBodiesMsg:
|
|
msg = new(eth.GetBlockBodiesPacket)
|
|
case eth.BlockBodiesMsg:
|
|
msg = new(eth.BlockBodiesPacket)
|
|
case eth.NewBlockMsg:
|
|
msg = new(eth.NewBlockPacket)
|
|
case eth.NewBlockHashesMsg:
|
|
msg = new(eth.NewBlockHashesPacket)
|
|
case eth.TransactionsMsg:
|
|
msg = new(eth.TransactionsPacket)
|
|
case eth.NewPooledTransactionHashesMsg:
|
|
msg = new(eth.NewPooledTransactionHashesPacket)
|
|
case eth.GetPooledTransactionsMsg:
|
|
msg = new(eth.GetPooledTransactionsPacket)
|
|
case eth.PooledTransactionsMsg:
|
|
msg = new(eth.PooledTransactionsPacket)
|
|
default:
|
|
panic(fmt.Sprintf("unhandled eth msg code %d", code))
|
|
}
|
|
if err := rlp.DecodeBytes(data, msg); err != nil {
|
|
return nil, fmt.Errorf("unable to decode eth msg: %v", err)
|
|
}
|
|
return msg, nil
|
|
}
|
|
}
|
|
|
|
// ReadSnap reads a snap/1 response with the given id from the connection.
|
|
func (c *Conn) ReadSnap() (any, error) {
|
|
c.SetReadDeadline(time.Now().Add(timeout))
|
|
for {
|
|
code, data, _, err := c.Conn.Read()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if getProto(code) != snapProto {
|
|
// Read until snap message.
|
|
continue
|
|
}
|
|
code -= baseProtoLen + ethProtoLen
|
|
|
|
var msg any
|
|
switch int(code) {
|
|
case snap.GetAccountRangeMsg:
|
|
msg = new(snap.GetAccountRangePacket)
|
|
case snap.AccountRangeMsg:
|
|
msg = new(snap.AccountRangePacket)
|
|
case snap.GetStorageRangesMsg:
|
|
msg = new(snap.GetStorageRangesPacket)
|
|
case snap.StorageRangesMsg:
|
|
msg = new(snap.StorageRangesPacket)
|
|
case snap.GetByteCodesMsg:
|
|
msg = new(snap.GetByteCodesPacket)
|
|
case snap.ByteCodesMsg:
|
|
msg = new(snap.ByteCodesPacket)
|
|
case snap.GetTrieNodesMsg:
|
|
msg = new(snap.GetTrieNodesPacket)
|
|
case snap.TrieNodesMsg:
|
|
msg = new(snap.TrieNodesPacket)
|
|
default:
|
|
panic(fmt.Errorf("unhandled snap code: %d", code))
|
|
}
|
|
if err := rlp.DecodeBytes(data, msg); err != nil {
|
|
return nil, fmt.Errorf("could not rlp decode message: %v", err)
|
|
}
|
|
return msg, nil
|
|
}
|
|
}
|
|
|
|
// peer performs both the protocol handshake and the status message
|
|
// exchange with the node in order to peer with it.
|
|
func (c *Conn) peer(chain *Chain, status *eth.StatusPacket) error {
|
|
if err := c.handshake(); err != nil {
|
|
return fmt.Errorf("handshake failed: %v", err)
|
|
}
|
|
if err := c.statusExchange(chain, status); err != nil {
|
|
return fmt.Errorf("status exchange failed: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handshake performs a protocol handshake with the node.
|
|
func (c *Conn) handshake() error {
|
|
// Write hello to client.
|
|
pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
|
|
ourHandshake := &protoHandshake{
|
|
Version: 5,
|
|
Caps: c.caps,
|
|
ID: pub0,
|
|
}
|
|
if err := c.Write(baseProto, handshakeMsg, ourHandshake); err != nil {
|
|
return fmt.Errorf("write to connection failed: %v", err)
|
|
}
|
|
// Read hello from client.
|
|
code, data, err := c.Read()
|
|
if err != nil {
|
|
return fmt.Errorf("erroring reading handshake: %v", err)
|
|
}
|
|
switch code {
|
|
case handshakeMsg:
|
|
msg := new(protoHandshake)
|
|
if err := rlp.DecodeBytes(data, &msg); err != nil {
|
|
return fmt.Errorf("error decoding handshake msg: %v", err)
|
|
}
|
|
// Set snappy if version is at least 5.
|
|
if msg.Version >= 5 {
|
|
c.SetSnappy(true)
|
|
}
|
|
c.negotiateEthProtocol(msg.Caps)
|
|
if c.negotiatedProtoVersion == 0 {
|
|
return fmt.Errorf("could not negotiate eth protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion)
|
|
}
|
|
// If we require snap, verify that it was negotiated.
|
|
if c.ourHighestSnapProtoVersion != c.negotiatedSnapProtoVersion {
|
|
return fmt.Errorf("could not negotiate snap protocol (remote caps: %v, local snap version: %v)", msg.Caps, c.ourHighestSnapProtoVersion)
|
|
}
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("bad handshake: got msg code %d", code)
|
|
}
|
|
}
|
|
|
|
// negotiateEthProtocol sets the Conn's eth protocol version to highest
|
|
// advertised capability from peer.
|
|
func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
|
|
var highestEthVersion uint
|
|
var highestSnapVersion uint
|
|
for _, capability := range caps {
|
|
switch capability.Name {
|
|
case "eth":
|
|
if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
|
|
highestEthVersion = capability.Version
|
|
}
|
|
case "snap":
|
|
if capability.Version > highestSnapVersion && capability.Version <= c.ourHighestSnapProtoVersion {
|
|
highestSnapVersion = capability.Version
|
|
}
|
|
}
|
|
}
|
|
c.negotiatedProtoVersion = highestEthVersion
|
|
c.negotiatedSnapProtoVersion = highestSnapVersion
|
|
}
|
|
|
|
// statusExchange performs a `Status` message exchange with the given node.
|
|
func (c *Conn) statusExchange(chain *Chain, status *eth.StatusPacket) error {
|
|
loop:
|
|
for {
|
|
code, data, err := c.Read()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read from connection: %w", err)
|
|
}
|
|
switch code {
|
|
case eth.StatusMsg + protoOffset(ethProto):
|
|
msg := new(eth.StatusPacket)
|
|
if err := rlp.DecodeBytes(data, &msg); err != nil {
|
|
return fmt.Errorf("error decoding status packet: %w", err)
|
|
}
|
|
if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want {
|
|
return fmt.Errorf("wrong head block in status, want: %#x (block %d) have %#x",
|
|
want, chain.blocks[chain.Len()-1].NumberU64(), have)
|
|
}
|
|
if have, want := msg.TD.Cmp(chain.TD()), 0; have != want {
|
|
return fmt.Errorf("wrong TD in status: have %v want %v", have, want)
|
|
}
|
|
if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) {
|
|
return fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want)
|
|
}
|
|
if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) {
|
|
return fmt.Errorf("wrong protocol version: have %v, want %v", have, want)
|
|
}
|
|
break loop
|
|
case discMsg:
|
|
var msg []p2p.DiscReason
|
|
if rlp.DecodeBytes(data, &msg); len(msg) == 0 {
|
|
return errors.New("invalid disconnect message")
|
|
}
|
|
return fmt.Errorf("disconnect received: %v", pretty.Sdump(msg))
|
|
case pingMsg:
|
|
// TODO (renaynay): in the future, this should be an error
|
|
// (PINGs should not be a response upon fresh connection)
|
|
c.Write(baseProto, pongMsg, nil)
|
|
default:
|
|
return fmt.Errorf("bad status message: code %d", code)
|
|
}
|
|
}
|
|
// make sure eth protocol version is set for negotiation
|
|
if c.negotiatedProtoVersion == 0 {
|
|
return errors.New("eth protocol version must be set in Conn")
|
|
}
|
|
if status == nil {
|
|
// default status message
|
|
status = ð.StatusPacket{
|
|
ProtocolVersion: uint32(c.negotiatedProtoVersion),
|
|
NetworkID: chain.config.ChainID.Uint64(),
|
|
TD: chain.TD(),
|
|
Head: chain.blocks[chain.Len()-1].Hash(),
|
|
Genesis: chain.blocks[0].Hash(),
|
|
ForkID: chain.ForkID(),
|
|
}
|
|
}
|
|
if err := c.Write(ethProto, eth.StatusMsg, status); err != nil {
|
|
return fmt.Errorf("write to connection failed: %v", err)
|
|
}
|
|
return nil
|
|
}
|