1440f9a37a
The most visible change is event-based dialing, which should be an improvement over the timer-based system that we have at the moment. The dialer gets a chance to compute new tasks whenever peers change or dials complete. This is better than checking peers on a timer because dials happen faster. The dialer can now make more precise decisions about whom to dial based on the peer set and we can test those decisions without actually opening any sockets. Peer management is easier to test because the tests can inject connections at checkpoints (after enc handshake, after protocol handshake). Most of the handshake stuff is now part of the RLPx code. It could be exported or moved to its own package because it is no longer entangled with Server logic.
359 lines
9.4 KiB
Go
359 lines
9.4 KiB
Go
package p2p
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/davecgh/go-spew/spew"
|
|
"github.com/ethereum/go-ethereum/crypto"
|
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
|
"github.com/ethereum/go-ethereum/crypto/sha3"
|
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
"github.com/ethereum/go-ethereum/rlp"
|
|
)
|
|
|
|
func TestSharedSecret(t *testing.T) {
|
|
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
|
pub0 := &prv0.PublicKey
|
|
prv1, _ := crypto.GenerateKey()
|
|
pub1 := &prv1.PublicKey
|
|
|
|
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
|
if !bytes.Equal(ss0, ss1) {
|
|
t.Errorf("dont match :(")
|
|
}
|
|
}
|
|
|
|
func TestEncHandshake(t *testing.T) {
|
|
for i := 0; i < 10; i++ {
|
|
start := time.Now()
|
|
if err := testEncHandshake(nil); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
for i := 0; i < 10; i++ {
|
|
tok := make([]byte, shaLen)
|
|
rand.Reader.Read(tok)
|
|
start := time.Now()
|
|
if err := testEncHandshake(tok); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
}
|
|
|
|
func testEncHandshake(token []byte) error {
|
|
type result struct {
|
|
side string
|
|
id discover.NodeID
|
|
err error
|
|
}
|
|
var (
|
|
prv0, _ = crypto.GenerateKey()
|
|
prv1, _ = crypto.GenerateKey()
|
|
fd0, fd1 = net.Pipe()
|
|
c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
|
|
output = make(chan result)
|
|
)
|
|
|
|
go func() {
|
|
r := result{side: "initiator"}
|
|
defer func() { output <- r }()
|
|
|
|
dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
|
|
r.id, r.err = c0.doEncHandshake(prv0, dest)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id1 := discover.PubkeyID(&prv1.PublicKey)
|
|
if r.id != id1 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
|
|
}
|
|
}()
|
|
go func() {
|
|
r := result{side: "receiver"}
|
|
defer func() { output <- r }()
|
|
|
|
r.id, r.err = c1.doEncHandshake(prv1, nil)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id0 := discover.PubkeyID(&prv0.PublicKey)
|
|
if r.id != id0 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
|
|
}
|
|
}()
|
|
|
|
// wait for results from both sides
|
|
r1, r2 := <-output, <-output
|
|
if r1.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
|
}
|
|
if r2.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
|
}
|
|
|
|
// compare derived secrets
|
|
if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
|
|
return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
|
|
}
|
|
if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
|
|
return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
|
|
}
|
|
if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
|
|
return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
|
|
}
|
|
if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
|
|
return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestProtocolHandshake(t *testing.T) {
|
|
var (
|
|
prv0, _ = crypto.GenerateKey()
|
|
node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
|
|
hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
|
|
|
|
prv1, _ = crypto.GenerateKey()
|
|
node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
|
|
hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
|
|
|
|
fd0, fd1 = net.Pipe()
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
rlpx := newRLPX(fd0)
|
|
remid, err := rlpx.doEncHandshake(prv0, node1)
|
|
if err != nil {
|
|
t.Errorf("dial side enc handshake failed: %v", err)
|
|
return
|
|
}
|
|
if remid != node1.ID {
|
|
t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
|
|
return
|
|
}
|
|
|
|
phs, err := rlpx.doProtoHandshake(hs0)
|
|
if err != nil {
|
|
t.Errorf("dial side proto handshake error: %v", err)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(phs, hs1) {
|
|
t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
|
|
return
|
|
}
|
|
rlpx.close(DiscQuitting)
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
rlpx := newRLPX(fd1)
|
|
remid, err := rlpx.doEncHandshake(prv1, nil)
|
|
if err != nil {
|
|
t.Errorf("listen side enc handshake failed: %v", err)
|
|
return
|
|
}
|
|
if remid != node0.ID {
|
|
t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
|
|
return
|
|
}
|
|
|
|
phs, err := rlpx.doProtoHandshake(hs1)
|
|
if err != nil {
|
|
t.Errorf("listen side proto handshake error: %v", err)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(phs, hs0) {
|
|
t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
|
|
return
|
|
}
|
|
|
|
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
|
|
t.Errorf("error receiving disconnect: %v", err)
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestProtocolHandshakeErrors(t *testing.T) {
|
|
our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
|
|
id := randomID()
|
|
tests := []struct {
|
|
code uint64
|
|
msg interface{}
|
|
err error
|
|
}{
|
|
{
|
|
code: discMsg,
|
|
msg: []DiscReason{DiscQuitting},
|
|
err: DiscQuitting,
|
|
},
|
|
{
|
|
code: 0x989898,
|
|
msg: []byte{1},
|
|
err: errors.New("expected handshake, got 989898"),
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: make([]byte, baseProtocolMaxMsgSize+2),
|
|
err: errors.New("message too big"),
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: []byte{1, 2, 3},
|
|
err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: &protoHandshake{Version: 9944, ID: id},
|
|
err: DiscIncompatibleVersion,
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: &protoHandshake{Version: 3},
|
|
err: DiscInvalidIdentity,
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
p1, p2 := MsgPipe()
|
|
go Send(p1, test.code, test.msg)
|
|
_, err := readProtocolHandshake(p2, our)
|
|
if !reflect.DeepEqual(err, test.err) {
|
|
t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRLPXFrameFake(t *testing.T) {
|
|
buf := new(bytes.Buffer)
|
|
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
|
|
rw := newRLPXFrameRW(buf, secrets{
|
|
AES: crypto.Sha3(),
|
|
MAC: crypto.Sha3(),
|
|
IngressMAC: hash,
|
|
EgressMAC: hash,
|
|
})
|
|
|
|
golden := unhex(`
|
|
00828ddae471818bb0bfa6b551d1cb42
|
|
01010101010101010101010101010101
|
|
ba628a4ba590cb43f7848f41c4382885
|
|
01010101010101010101010101010101
|
|
`)
|
|
|
|
// Check WriteMsg. This puts a message into the buffer.
|
|
if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil {
|
|
t.Fatalf("WriteMsg error: %v", err)
|
|
}
|
|
written := buf.Bytes()
|
|
if !bytes.Equal(written, golden) {
|
|
t.Fatalf("output mismatch:\n got: %x\n want: %x", written, golden)
|
|
}
|
|
|
|
// Check ReadMsg. It reads the message encoded by WriteMsg, which
|
|
// is equivalent to the golden message above.
|
|
msg, err := rw.ReadMsg()
|
|
if err != nil {
|
|
t.Fatalf("ReadMsg error: %v", err)
|
|
}
|
|
if msg.Size != 5 {
|
|
t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5)
|
|
}
|
|
if msg.Code != 8 {
|
|
t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8)
|
|
}
|
|
payload, _ := ioutil.ReadAll(msg.Payload)
|
|
wantPayload := unhex("C401020304")
|
|
if !bytes.Equal(payload, wantPayload) {
|
|
t.Errorf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
|
|
}
|
|
}
|
|
|
|
// fakeHash is a stand-in hash for tests. It ignores everything written to it
// and always reports its own bytes as the digest.
type fakeHash []byte

// Write discards data and reports it as fully written.
func (fh fakeHash) Write(data []byte) (int, error) {
	n := len(data)
	return n, nil
}

// Reset does nothing; there is no state to clear.
func (fh fakeHash) Reset() {
}

// BlockSize always reports zero.
func (fh fakeHash) BlockSize() int {
	return 0
}

// Size reports the number of bytes in the fixed digest.
func (fh fakeHash) Size() int {
	return len(fh)
}

// Sum appends the fixed digest to prefix and returns the result.
func (fh fakeHash) Sum(prefix []byte) []byte {
	digest := []byte(fh)
	return append(prefix, digest...)
}
|
|
|
|
func TestRLPXFrameRW(t *testing.T) {
|
|
var (
|
|
aesSecret = make([]byte, 16)
|
|
macSecret = make([]byte, 16)
|
|
egressMACinit = make([]byte, 32)
|
|
ingressMACinit = make([]byte, 32)
|
|
)
|
|
for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} {
|
|
rand.Read(s)
|
|
}
|
|
conn := new(bytes.Buffer)
|
|
|
|
s1 := secrets{
|
|
AES: aesSecret,
|
|
MAC: macSecret,
|
|
EgressMAC: sha3.NewKeccak256(),
|
|
IngressMAC: sha3.NewKeccak256(),
|
|
}
|
|
s1.EgressMAC.Write(egressMACinit)
|
|
s1.IngressMAC.Write(ingressMACinit)
|
|
rw1 := newRLPXFrameRW(conn, s1)
|
|
|
|
s2 := secrets{
|
|
AES: aesSecret,
|
|
MAC: macSecret,
|
|
EgressMAC: sha3.NewKeccak256(),
|
|
IngressMAC: sha3.NewKeccak256(),
|
|
}
|
|
s2.EgressMAC.Write(ingressMACinit)
|
|
s2.IngressMAC.Write(egressMACinit)
|
|
rw2 := newRLPXFrameRW(conn, s2)
|
|
|
|
// send some messages
|
|
for i := 0; i < 10; i++ {
|
|
// write message into conn buffer
|
|
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
|
|
err := Send(rw1, uint64(i), wmsg)
|
|
if err != nil {
|
|
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
|
|
}
|
|
|
|
// read message that rw1 just wrote
|
|
msg, err := rw2.ReadMsg()
|
|
if err != nil {
|
|
t.Fatalf("ReadMsg error (i=%d): %v", i, err)
|
|
}
|
|
if msg.Code != uint64(i) {
|
|
t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i)
|
|
}
|
|
payload, _ := ioutil.ReadAll(msg.Payload)
|
|
wantPayload, _ := rlp.EncodeToBytes(wmsg)
|
|
if !bytes.Equal(payload, wantPayload) {
|
|
t.Fatalf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
|
|
}
|
|
}
|
|
}
|