eb0e7b1b81
...and make it a top-level function instead. The original idea behind having EncodeMsg in the interface was that implementations might be able to encode RLP data to their underlying writer directly instead of buffering the encoded data. The encoder will buffer anyway, so that doesn't matter anymore. Given the recent problems with EncodeMsg (copy-pasted implementation bug) I'd rather implement once, correctly.
302 lines
7.1 KiB
Go
302 lines
7.1 KiB
Go
package p2p
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/hex"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var discard = Protocol{
|
|
Name: "discard",
|
|
Length: 1,
|
|
Run: func(p *Peer, rw MsgReadWriter) error {
|
|
for {
|
|
msg, err := rw.ReadMsg()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err = msg.Discard(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
},
|
|
}
|
|
|
|
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
|
|
conn1, conn2 := net.Pipe()
|
|
peer := newPeer(conn1, protos, nil)
|
|
peer.ourID = &peerId{}
|
|
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
|
errc := make(chan error, 1)
|
|
go func() {
|
|
_, err := peer.loop()
|
|
errc <- err
|
|
}()
|
|
return conn2, peer, errc
|
|
}
|
|
|
|
func TestPeerProtoReadMsg(t *testing.T) {
|
|
defer testlog(t).detach()
|
|
|
|
done := make(chan struct{})
|
|
proto := Protocol{
|
|
Name: "a",
|
|
Length: 5,
|
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
|
msg, err := rw.ReadMsg()
|
|
if err != nil {
|
|
t.Errorf("read error: %v", err)
|
|
}
|
|
if msg.Code != 2 {
|
|
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
|
|
}
|
|
data, err := ioutil.ReadAll(msg.Payload)
|
|
if err != nil {
|
|
t.Errorf("payload read error: %v", err)
|
|
}
|
|
expdata, _ := hex.DecodeString("0183303030")
|
|
if !bytes.Equal(expdata, data) {
|
|
t.Errorf("incorrect msg data %x", data)
|
|
}
|
|
close(done)
|
|
return nil
|
|
},
|
|
}
|
|
|
|
net, peer, errc := testPeer([]Protocol{proto})
|
|
defer net.Close()
|
|
peer.startSubprotocols([]Cap{proto.cap()})
|
|
|
|
writeMsg(net, NewMsg(18, 1, "000"))
|
|
select {
|
|
case <-done:
|
|
case err := <-errc:
|
|
t.Errorf("peer returned: %v", err)
|
|
case <-time.After(2 * time.Second):
|
|
t.Errorf("receive timeout")
|
|
}
|
|
}
|
|
|
|
func TestPeerProtoReadLargeMsg(t *testing.T) {
|
|
defer testlog(t).detach()
|
|
|
|
msgsize := uint32(10 * 1024 * 1024)
|
|
done := make(chan struct{})
|
|
proto := Protocol{
|
|
Name: "a",
|
|
Length: 5,
|
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
|
msg, err := rw.ReadMsg()
|
|
if err != nil {
|
|
t.Errorf("read error: %v", err)
|
|
}
|
|
if msg.Size != msgsize+4 {
|
|
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
|
|
}
|
|
msg.Discard()
|
|
close(done)
|
|
return nil
|
|
},
|
|
}
|
|
|
|
net, peer, errc := testPeer([]Protocol{proto})
|
|
defer net.Close()
|
|
peer.startSubprotocols([]Cap{proto.cap()})
|
|
|
|
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
|
|
select {
|
|
case <-done:
|
|
case err := <-errc:
|
|
t.Errorf("peer returned: %v", err)
|
|
case <-time.After(2 * time.Second):
|
|
t.Errorf("receive timeout")
|
|
}
|
|
}
|
|
|
|
func TestPeerProtoEncodeMsg(t *testing.T) {
|
|
defer testlog(t).detach()
|
|
|
|
proto := Protocol{
|
|
Name: "a",
|
|
Length: 2,
|
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
|
if err := EncodeMsg(rw, 2); err == nil {
|
|
t.Error("expected error for out-of-range msg code, got nil")
|
|
}
|
|
if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
|
|
t.Errorf("write error: %v", err)
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
net, peer, _ := testPeer([]Protocol{proto})
|
|
defer net.Close()
|
|
peer.startSubprotocols([]Cap{proto.cap()})
|
|
|
|
bufr := bufio.NewReader(net)
|
|
msg, err := readMsg(bufr)
|
|
if err != nil {
|
|
t.Errorf("read error: %v", err)
|
|
}
|
|
if msg.Code != 17 {
|
|
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
|
|
}
|
|
var data []string
|
|
if err := msg.Decode(&data); err != nil {
|
|
t.Errorf("payload decode error: %v", err)
|
|
}
|
|
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
|
|
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
|
|
}
|
|
}
|
|
|
|
func TestPeerWrite(t *testing.T) {
|
|
defer testlog(t).detach()
|
|
|
|
net, peer, peerErr := testPeer([]Protocol{discard})
|
|
defer net.Close()
|
|
peer.startSubprotocols([]Cap{discard.cap()})
|
|
|
|
// test write errors
|
|
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
|
t.Errorf("expected error for unknown protocol, got nil")
|
|
}
|
|
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
|
|
t.Errorf("expected error for out-of-range msg code, got nil")
|
|
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
|
|
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
|
|
}
|
|
|
|
// setup for reading the message on the other end
|
|
read := make(chan struct{})
|
|
go func() {
|
|
bufr := bufio.NewReader(net)
|
|
msg, err := readMsg(bufr)
|
|
if err != nil {
|
|
t.Errorf("read error: %v", err)
|
|
} else if msg.Code != 16 {
|
|
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
|
|
}
|
|
msg.Discard()
|
|
close(read)
|
|
}()
|
|
|
|
// test succcessful write
|
|
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
|
|
t.Errorf("expect no error for known protocol: %v", err)
|
|
}
|
|
select {
|
|
case <-read:
|
|
case err := <-peerErr:
|
|
t.Fatalf("peer stopped: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestPeerActivity(t *testing.T) {
|
|
// shorten inactivityTimeout while this test is running
|
|
oldT := inactivityTimeout
|
|
defer func() { inactivityTimeout = oldT }()
|
|
inactivityTimeout = 20 * time.Millisecond
|
|
|
|
net, peer, peerErr := testPeer([]Protocol{discard})
|
|
defer net.Close()
|
|
peer.startSubprotocols([]Cap{discard.cap()})
|
|
|
|
sub := peer.activity.Subscribe(time.Time{})
|
|
defer sub.Unsubscribe()
|
|
|
|
for i := 0; i < 6; i++ {
|
|
writeMsg(net, NewMsg(16))
|
|
select {
|
|
case <-sub.Chan():
|
|
case <-time.After(inactivityTimeout / 2):
|
|
t.Fatal("no event within ", inactivityTimeout/2)
|
|
case err := <-peerErr:
|
|
t.Fatal("peer error", err)
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-time.After(inactivityTimeout * 2):
|
|
case <-sub.Chan():
|
|
t.Fatal("got activity event while connection was inactive")
|
|
case err := <-peerErr:
|
|
t.Fatal("peer error", err)
|
|
}
|
|
}
|
|
|
|
func TestNewPeer(t *testing.T) {
|
|
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
|
id := &peerId{}
|
|
p := NewPeer(id, caps)
|
|
if !reflect.DeepEqual(p.Caps(), caps) {
|
|
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
|
|
}
|
|
if p.Identity() != id {
|
|
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
|
|
}
|
|
// Should not hang.
|
|
p.Disconnect(DiscAlreadyConnected)
|
|
}
|
|
|
|
func TestEOFSignal(t *testing.T) {
|
|
rb := make([]byte, 10)
|
|
|
|
// empty reader
|
|
eof := make(chan struct{}, 1)
|
|
sig := &eofSignal{new(bytes.Buffer), 0, eof}
|
|
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
}
|
|
select {
|
|
case <-eof:
|
|
default:
|
|
t.Error("EOF chan not signaled")
|
|
}
|
|
|
|
// count before error
|
|
eof = make(chan struct{}, 1)
|
|
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
|
|
if n, err := sig.Read(rb); n != 8 || err != nil {
|
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
}
|
|
select {
|
|
case <-eof:
|
|
default:
|
|
t.Error("EOF chan not signaled")
|
|
}
|
|
|
|
// error before count
|
|
eof = make(chan struct{}, 1)
|
|
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
|
|
if n, err := sig.Read(rb); n != 4 || err != nil {
|
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
}
|
|
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
}
|
|
select {
|
|
case <-eof:
|
|
default:
|
|
t.Error("EOF chan not signaled")
|
|
}
|
|
|
|
// no signal if neither occurs
|
|
eof = make(chan struct{}, 1)
|
|
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
|
|
if n, err := sig.Read(rb); n != 10 || err != nil {
|
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
}
|
|
select {
|
|
case <-eof:
|
|
t.Error("unexpected EOF signal")
|
|
default:
|
|
}
|
|
}
|