This change mostly affects how information is passed around: instead of using many function parameters and return values, the entire state is put in a struct and that struct is passed along. It also restores derivation of the ECDHE shared secret, which I accidentally deleted in a previous refactoring.
		
			
				
	
	
		
			172 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package p2p
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/rand"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"reflect"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/ethereum/go-ethereum/crypto"
 | |
| 	"github.com/ethereum/go-ethereum/crypto/ecies"
 | |
| 	"github.com/ethereum/go-ethereum/p2p/discover"
 | |
| )
 | |
| 
 | |
| func TestSharedSecret(t *testing.T) {
 | |
| 	prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
 | |
| 	pub0 := &prv0.PublicKey
 | |
| 	prv1, _ := crypto.GenerateKey()
 | |
| 	pub1 := &prv1.PublicKey
 | |
| 
 | |
| 	ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 	t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
 | |
| 	if !bytes.Equal(ss0, ss1) {
 | |
| 		t.Errorf("dont match :(")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestEncHandshake(t *testing.T) {
 | |
| 	for i := 0; i < 20; i++ {
 | |
| 		start := time.Now()
 | |
| 		if err := testEncHandshake(nil); err != nil {
 | |
| 			t.Fatalf("i=%d %v", i, err)
 | |
| 		}
 | |
| 		t.Logf("(without token) %d %v\n", i+1, time.Since(start))
 | |
| 	}
 | |
| 
 | |
| 	for i := 0; i < 20; i++ {
 | |
| 		tok := make([]byte, shaLen)
 | |
| 		rand.Reader.Read(tok)
 | |
| 		start := time.Now()
 | |
| 		if err := testEncHandshake(tok); err != nil {
 | |
| 			t.Fatalf("i=%d %v", i, err)
 | |
| 		}
 | |
| 		t.Logf("(with token) %d %v\n", i+1, time.Since(start))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func testEncHandshake(token []byte) error {
 | |
| 	type result struct {
 | |
| 		side string
 | |
| 		s    secrets
 | |
| 		err  error
 | |
| 	}
 | |
| 	var (
 | |
| 		prv0, _  = crypto.GenerateKey()
 | |
| 		prv1, _  = crypto.GenerateKey()
 | |
| 		rw0, rw1 = net.Pipe()
 | |
| 		output   = make(chan result)
 | |
| 	)
 | |
| 
 | |
| 	go func() {
 | |
| 		r := result{side: "initiator"}
 | |
| 		defer func() { output <- r }()
 | |
| 
 | |
| 		pub1s := discover.PubkeyID(&prv1.PublicKey)
 | |
| 		r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
 | |
| 		if r.err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		id1 := discover.PubkeyID(&prv1.PublicKey)
 | |
| 		if r.s.RemoteID != id1 {
 | |
| 			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
 | |
| 		}
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		r := result{side: "receiver"}
 | |
| 		defer func() { output <- r }()
 | |
| 
 | |
| 		r.s, r.err = receiverEncHandshake(rw1, prv1, token)
 | |
| 		if r.err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		id0 := discover.PubkeyID(&prv0.PublicKey)
 | |
| 		if r.s.RemoteID != id0 {
 | |
| 			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// wait for results from both sides
 | |
| 	r1, r2 := <-output, <-output
 | |
| 
 | |
| 	if r1.err != nil {
 | |
| 		return fmt.Errorf("%s side error: %v", r1.side, r1.err)
 | |
| 	}
 | |
| 	if r2.err != nil {
 | |
| 		return fmt.Errorf("%s side error: %v", r2.side, r2.err)
 | |
| 	}
 | |
| 
 | |
| 	// don't compare remote node IDs
 | |
| 	r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
 | |
| 	// flip MACs on one of them so they compare equal
 | |
| 	r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
 | |
| 	if !reflect.DeepEqual(r1.s, r2.s) {
 | |
| 		return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func TestSetupConn(t *testing.T) {
 | |
| 	prv0, _ := crypto.GenerateKey()
 | |
| 	prv1, _ := crypto.GenerateKey()
 | |
| 	node0 := &discover.Node{
 | |
| 		ID:      discover.PubkeyID(&prv0.PublicKey),
 | |
| 		IP:      net.IP{1, 2, 3, 4},
 | |
| 		TCPPort: 33,
 | |
| 	}
 | |
| 	node1 := &discover.Node{
 | |
| 		ID:      discover.PubkeyID(&prv1.PublicKey),
 | |
| 		IP:      net.IP{5, 6, 7, 8},
 | |
| 		TCPPort: 44,
 | |
| 	}
 | |
| 	hs0 := &protoHandshake{
 | |
| 		Version: baseProtocolVersion,
 | |
| 		ID:      node0.ID,
 | |
| 		Caps:    []Cap{{"a", 0}, {"b", 2}},
 | |
| 	}
 | |
| 	hs1 := &protoHandshake{
 | |
| 		Version: baseProtocolVersion,
 | |
| 		ID:      node1.ID,
 | |
| 		Caps:    []Cap{{"c", 1}, {"d", 3}},
 | |
| 	}
 | |
| 	fd0, fd1 := net.Pipe()
 | |
| 
 | |
| 	done := make(chan struct{})
 | |
| 	go func() {
 | |
| 		defer close(done)
 | |
| 		conn0, err := setupConn(fd0, prv0, hs0, node1)
 | |
| 		if err != nil {
 | |
| 			t.Errorf("outbound side error: %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 		if conn0.ID != node1.ID {
 | |
| 			t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
 | |
| 		}
 | |
| 		if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
 | |
| 			t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	conn1, err := setupConn(fd1, prv1, hs1, nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("inbound side error: %v", err)
 | |
| 	}
 | |
| 	if conn1.ID != node0.ID {
 | |
| 		t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
 | |
| 	}
 | |
| 	if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
 | |
| 		t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
 | |
| 	}
 | |
| 
 | |
| 	<-done
 | |
| }
 |