b3c058a9e4
As of this commit, p2p will disconnect nodes directly after the encryption handshake if too many peer connections are active. Errors in the protocol handshake packet are now handled more politely by sending a disconnect packet before closing the connection.
172 lines
4.1 KiB
Go
172 lines
4.1 KiB
Go
package p2p
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"net"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ethereum/go-ethereum/crypto"
|
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
)
|
|
|
|
func TestSharedSecret(t *testing.T) {
|
|
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
|
pub0 := &prv0.PublicKey
|
|
prv1, _ := crypto.GenerateKey()
|
|
pub1 := &prv1.PublicKey
|
|
|
|
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
|
if !bytes.Equal(ss0, ss1) {
|
|
t.Errorf("dont match :(")
|
|
}
|
|
}
|
|
|
|
func TestEncHandshake(t *testing.T) {
|
|
for i := 0; i < 20; i++ {
|
|
start := time.Now()
|
|
if err := testEncHandshake(nil); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
|
|
for i := 0; i < 20; i++ {
|
|
tok := make([]byte, shaLen)
|
|
rand.Reader.Read(tok)
|
|
start := time.Now()
|
|
if err := testEncHandshake(tok); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
}
|
|
|
|
func testEncHandshake(token []byte) error {
|
|
type result struct {
|
|
side string
|
|
s secrets
|
|
err error
|
|
}
|
|
var (
|
|
prv0, _ = crypto.GenerateKey()
|
|
prv1, _ = crypto.GenerateKey()
|
|
rw0, rw1 = net.Pipe()
|
|
output = make(chan result)
|
|
)
|
|
|
|
go func() {
|
|
r := result{side: "initiator"}
|
|
defer func() { output <- r }()
|
|
|
|
pub1s := discover.PubkeyID(&prv1.PublicKey)
|
|
r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id1 := discover.PubkeyID(&prv1.PublicKey)
|
|
if r.s.RemoteID != id1 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
|
|
}
|
|
}()
|
|
go func() {
|
|
r := result{side: "receiver"}
|
|
defer func() { output <- r }()
|
|
|
|
r.s, r.err = receiverEncHandshake(rw1, prv1, token)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id0 := discover.PubkeyID(&prv0.PublicKey)
|
|
if r.s.RemoteID != id0 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
|
|
}
|
|
}()
|
|
|
|
// wait for results from both sides
|
|
r1, r2 := <-output, <-output
|
|
|
|
if r1.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
|
}
|
|
if r2.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
|
}
|
|
|
|
// don't compare remote node IDs
|
|
r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
|
|
// flip MACs on one of them so they compare equal
|
|
r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
|
|
if !reflect.DeepEqual(r1.s, r2.s) {
|
|
return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSetupConn(t *testing.T) {
|
|
prv0, _ := crypto.GenerateKey()
|
|
prv1, _ := crypto.GenerateKey()
|
|
node0 := &discover.Node{
|
|
ID: discover.PubkeyID(&prv0.PublicKey),
|
|
IP: net.IP{1, 2, 3, 4},
|
|
TCPPort: 33,
|
|
}
|
|
node1 := &discover.Node{
|
|
ID: discover.PubkeyID(&prv1.PublicKey),
|
|
IP: net.IP{5, 6, 7, 8},
|
|
TCPPort: 44,
|
|
}
|
|
hs0 := &protoHandshake{
|
|
Version: baseProtocolVersion,
|
|
ID: node0.ID,
|
|
Caps: []Cap{{"a", 0}, {"b", 2}},
|
|
}
|
|
hs1 := &protoHandshake{
|
|
Version: baseProtocolVersion,
|
|
ID: node1.ID,
|
|
Caps: []Cap{{"c", 1}, {"d", 3}},
|
|
}
|
|
fd0, fd1 := net.Pipe()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
conn0, err := setupConn(fd0, prv0, hs0, node1, false)
|
|
if err != nil {
|
|
t.Errorf("outbound side error: %v", err)
|
|
return
|
|
}
|
|
if conn0.ID != node1.ID {
|
|
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
|
|
}
|
|
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
|
|
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
|
|
}
|
|
}()
|
|
|
|
conn1, err := setupConn(fd1, prv1, hs1, nil, false)
|
|
if err != nil {
|
|
t.Fatalf("inbound side error: %v", err)
|
|
}
|
|
if conn1.ID != node0.ID {
|
|
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
|
|
}
|
|
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
|
|
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
|
|
}
|
|
|
|
<-done
|
|
}
|