914e57e49b
With the introduction of static/trusted nodes, the peer count can go above MaxPeers. Update the capacity check to handle this. While here, decouple the trusted nodes check from the handshake by passing a function instead.
173 lines
4.2 KiB
Go
173 lines
4.2 KiB
Go
package p2p
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"net"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ethereum/go-ethereum/crypto"
|
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
)
|
|
|
|
func TestSharedSecret(t *testing.T) {
|
|
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
|
pub0 := &prv0.PublicKey
|
|
prv1, _ := crypto.GenerateKey()
|
|
pub1 := &prv1.PublicKey
|
|
|
|
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
|
if !bytes.Equal(ss0, ss1) {
|
|
t.Errorf("dont match :(")
|
|
}
|
|
}
|
|
|
|
func TestEncHandshake(t *testing.T) {
|
|
for i := 0; i < 20; i++ {
|
|
start := time.Now()
|
|
if err := testEncHandshake(nil); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
|
|
for i := 0; i < 20; i++ {
|
|
tok := make([]byte, shaLen)
|
|
rand.Reader.Read(tok)
|
|
start := time.Now()
|
|
if err := testEncHandshake(tok); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
}
|
|
|
|
func testEncHandshake(token []byte) error {
|
|
type result struct {
|
|
side string
|
|
s secrets
|
|
err error
|
|
}
|
|
var (
|
|
prv0, _ = crypto.GenerateKey()
|
|
prv1, _ = crypto.GenerateKey()
|
|
rw0, rw1 = net.Pipe()
|
|
output = make(chan result)
|
|
)
|
|
|
|
go func() {
|
|
r := result{side: "initiator"}
|
|
defer func() { output <- r }()
|
|
|
|
pub1s := discover.PubkeyID(&prv1.PublicKey)
|
|
r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id1 := discover.PubkeyID(&prv1.PublicKey)
|
|
if r.s.RemoteID != id1 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
|
|
}
|
|
}()
|
|
go func() {
|
|
r := result{side: "receiver"}
|
|
defer func() { output <- r }()
|
|
|
|
r.s, r.err = receiverEncHandshake(rw1, prv1, token)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id0 := discover.PubkeyID(&prv0.PublicKey)
|
|
if r.s.RemoteID != id0 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
|
|
}
|
|
}()
|
|
|
|
// wait for results from both sides
|
|
r1, r2 := <-output, <-output
|
|
|
|
if r1.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
|
}
|
|
if r2.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
|
}
|
|
|
|
// don't compare remote node IDs
|
|
r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
|
|
// flip MACs on one of them so they compare equal
|
|
r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
|
|
if !reflect.DeepEqual(r1.s, r2.s) {
|
|
return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSetupConn(t *testing.T) {
|
|
prv0, _ := crypto.GenerateKey()
|
|
prv1, _ := crypto.GenerateKey()
|
|
node0 := &discover.Node{
|
|
ID: discover.PubkeyID(&prv0.PublicKey),
|
|
IP: net.IP{1, 2, 3, 4},
|
|
TCP: 33,
|
|
}
|
|
node1 := &discover.Node{
|
|
ID: discover.PubkeyID(&prv1.PublicKey),
|
|
IP: net.IP{5, 6, 7, 8},
|
|
TCP: 44,
|
|
}
|
|
hs0 := &protoHandshake{
|
|
Version: baseProtocolVersion,
|
|
ID: node0.ID,
|
|
Caps: []Cap{{"a", 0}, {"b", 2}},
|
|
}
|
|
hs1 := &protoHandshake{
|
|
Version: baseProtocolVersion,
|
|
ID: node1.ID,
|
|
Caps: []Cap{{"c", 1}, {"d", 3}},
|
|
}
|
|
fd0, fd1 := net.Pipe()
|
|
|
|
done := make(chan struct{})
|
|
keepalways := func(discover.NodeID) bool { return true }
|
|
go func() {
|
|
defer close(done)
|
|
conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways)
|
|
if err != nil {
|
|
t.Errorf("outbound side error: %v", err)
|
|
return
|
|
}
|
|
if conn0.ID != node1.ID {
|
|
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
|
|
}
|
|
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
|
|
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
|
|
}
|
|
}()
|
|
|
|
conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways)
|
|
if err != nil {
|
|
t.Fatalf("inbound side error: %v", err)
|
|
}
|
|
if conn1.ID != node0.ID {
|
|
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
|
|
}
|
|
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
|
|
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
|
|
}
|
|
|
|
<-done
|
|
}
|