As of this commit, p2p will disconnect nodes directly after the encryption handshake if too many peer connections are active. Errors in the protocol handshake packet are now handled more politely by sending a disconnect packet before closing the connection.
		
			
				
	
	
		
			172 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package p2p
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/rand"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"reflect"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/ethereum/go-ethereum/crypto"
 | |
| 	"github.com/ethereum/go-ethereum/crypto/ecies"
 | |
| 	"github.com/ethereum/go-ethereum/p2p/discover"
 | |
| )
 | |
| 
 | |
| func TestSharedSecret(t *testing.T) {
 | |
| 	prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
 | |
| 	pub0 := &prv0.PublicKey
 | |
| 	prv1, _ := crypto.GenerateKey()
 | |
| 	pub1 := &prv1.PublicKey
 | |
| 
 | |
| 	ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
 | |
| 	if !bytes.Equal(ss0, ss1) {
 | |
| 		t.Errorf("dont match :(")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestEncHandshake(t *testing.T) {
 | |
| 	for i := 0; i < 20; i++ {
 | |
| 		start := time.Now()
 | |
| 		if err := testEncHandshake(nil); err != nil {
 | |
| 			t.Fatalf("i=%d %v", i, err)
 | |
| 		}
 | |
| 		t.Logf("(without token) %d %v\n", i+1, time.Since(start))
 | |
| 	}
 | |
| 
 | |
| 	for i := 0; i < 20; i++ {
 | |
| 		tok := make([]byte, shaLen)
 | |
| 		rand.Reader.Read(tok)
 | |
| 		start := time.Now()
 | |
| 		if err := testEncHandshake(tok); err != nil {
 | |
| 			t.Fatalf("i=%d %v", i, err)
 | |
| 		}
 | |
| 		t.Logf("(with token) %d %v\n", i+1, time.Since(start))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func testEncHandshake(token []byte) error {
 | |
| 	type result struct {
 | |
| 		side string
 | |
| 		s    secrets
 | |
| 		err  error
 | |
| 	}
 | |
| 	var (
 | |
| 		prv0, _  = crypto.GenerateKey()
 | |
| 		prv1, _  = crypto.GenerateKey()
 | |
| 		rw0, rw1 = net.Pipe()
 | |
| 		output   = make(chan result)
 | |
| 	)
 | |
| 
 | |
| 	go func() {
 | |
| 		r := result{side: "initiator"}
 | |
| 		defer func() { output <- r }()
 | |
| 
 | |
| 		pub1s := discover.PubkeyID(&prv1.PublicKey)
 | |
| 		r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
 | |
| 		if r.err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		id1 := discover.PubkeyID(&prv1.PublicKey)
 | |
| 		if r.s.RemoteID != id1 {
 | |
| 			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
 | |
| 		}
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		r := result{side: "receiver"}
 | |
| 		defer func() { output <- r }()
 | |
| 
 | |
| 		r.s, r.err = receiverEncHandshake(rw1, prv1, token)
 | |
| 		if r.err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		id0 := discover.PubkeyID(&prv0.PublicKey)
 | |
| 		if r.s.RemoteID != id0 {
 | |
| 			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// wait for results from both sides
 | |
| 	r1, r2 := <-output, <-output
 | |
| 
 | |
| 	if r1.err != nil {
 | |
| 		return fmt.Errorf("%s side error: %v", r1.side, r1.err)
 | |
| 	}
 | |
| 	if r2.err != nil {
 | |
| 		return fmt.Errorf("%s side error: %v", r2.side, r2.err)
 | |
| 	}
 | |
| 
 | |
| 	// don't compare remote node IDs
 | |
| 	r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
 | |
| 	// flip MACs on one of them so they compare equal
 | |
| 	r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
 | |
| 	if !reflect.DeepEqual(r1.s, r2.s) {
 | |
| 		return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func TestSetupConn(t *testing.T) {
 | |
| 	prv0, _ := crypto.GenerateKey()
 | |
| 	prv1, _ := crypto.GenerateKey()
 | |
| 	node0 := &discover.Node{
 | |
| 		ID:      discover.PubkeyID(&prv0.PublicKey),
 | |
| 		IP:      net.IP{1, 2, 3, 4},
 | |
| 		TCPPort: 33,
 | |
| 	}
 | |
| 	node1 := &discover.Node{
 | |
| 		ID:      discover.PubkeyID(&prv1.PublicKey),
 | |
| 		IP:      net.IP{5, 6, 7, 8},
 | |
| 		TCPPort: 44,
 | |
| 	}
 | |
| 	hs0 := &protoHandshake{
 | |
| 		Version: baseProtocolVersion,
 | |
| 		ID:      node0.ID,
 | |
| 		Caps:    []Cap{{"a", 0}, {"b", 2}},
 | |
| 	}
 | |
| 	hs1 := &protoHandshake{
 | |
| 		Version: baseProtocolVersion,
 | |
| 		ID:      node1.ID,
 | |
| 		Caps:    []Cap{{"c", 1}, {"d", 3}},
 | |
| 	}
 | |
| 	fd0, fd1 := net.Pipe()
 | |
| 
 | |
| 	done := make(chan struct{})
 | |
| 	go func() {
 | |
| 		defer close(done)
 | |
| 		conn0, err := setupConn(fd0, prv0, hs0, node1, false)
 | |
| 		if err != nil {
 | |
| 			t.Errorf("outbound side error: %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 		if conn0.ID != node1.ID {
 | |
| 			t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
 | |
| 		}
 | |
| 		if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
 | |
| 			t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	conn1, err := setupConn(fd1, prv1, hs1, nil, false)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("inbound side error: %v", err)
 | |
| 	}
 | |
| 	if conn1.ID != node0.ID {
 | |
| 		t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
 | |
| 	}
 | |
| 	if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
 | |
| 		t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
 | |
| 	}
 | |
| 
 | |
| 	<-done
 | |
| }
 |