p2p: rework protocol API
This commit is contained in:
parent
8cf9ed0ea5
commit
f38052c499
@ -1,275 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
// "fmt"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Connection struct {
|
|
||||||
conn net.Conn
|
|
||||||
// conn NetworkConnection
|
|
||||||
timeout time.Duration
|
|
||||||
in chan []byte
|
|
||||||
out chan []byte
|
|
||||||
err chan *PeerError
|
|
||||||
closingIn chan chan bool
|
|
||||||
closingOut chan chan bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// const readBufferLength = 2 //for testing
|
|
||||||
|
|
||||||
const readBufferLength = 1440
|
|
||||||
const partialsQueueSize = 10
|
|
||||||
const maxPendingQueueSize = 1
|
|
||||||
const defaultTimeout = 500
|
|
||||||
|
|
||||||
var magicToken = []byte{34, 64, 8, 145}
|
|
||||||
|
|
||||||
func (self *Connection) Open() {
|
|
||||||
go self.startRead()
|
|
||||||
go self.startWrite()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Close() {
|
|
||||||
self.closeIn()
|
|
||||||
self.closeOut()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) closeIn() {
|
|
||||||
errc := make(chan bool)
|
|
||||||
self.closingIn <- errc
|
|
||||||
<-errc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) closeOut() {
|
|
||||||
errc := make(chan bool)
|
|
||||||
self.closingOut <- errc
|
|
||||||
<-errc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
|
|
||||||
return &Connection{
|
|
||||||
conn: conn,
|
|
||||||
timeout: defaultTimeout,
|
|
||||||
in: make(chan []byte),
|
|
||||||
out: make(chan []byte),
|
|
||||||
err: errchan,
|
|
||||||
closingIn: make(chan chan bool, 1),
|
|
||||||
closingOut: make(chan chan bool, 1),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Read() <-chan []byte {
|
|
||||||
return self.in
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Write() chan<- []byte {
|
|
||||||
return self.out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Error() <-chan *PeerError {
|
|
||||||
return self.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) startRead() {
|
|
||||||
payloads := make(chan []byte)
|
|
||||||
done := make(chan *PeerError)
|
|
||||||
pending := [][]byte{}
|
|
||||||
var head []byte
|
|
||||||
var wait time.Duration // initally 0 (no delay)
|
|
||||||
read := time.After(wait * time.Millisecond)
|
|
||||||
|
|
||||||
for {
|
|
||||||
// if pending empty, nil channel blocks
|
|
||||||
var in chan []byte
|
|
||||||
if len(pending) > 0 {
|
|
||||||
in = self.in // enable send case
|
|
||||||
head = pending[0]
|
|
||||||
} else {
|
|
||||||
in = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-read:
|
|
||||||
go self.read(payloads, done)
|
|
||||||
case err := <-done:
|
|
||||||
if err == nil { // no error but nothing to read
|
|
||||||
if len(pending) < maxPendingQueueSize {
|
|
||||||
wait = 100
|
|
||||||
} else if wait == 0 {
|
|
||||||
wait = 100
|
|
||||||
} else {
|
|
||||||
wait = 2 * wait
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.err <- err // report error
|
|
||||||
wait = 100
|
|
||||||
}
|
|
||||||
read = time.After(wait * time.Millisecond)
|
|
||||||
case payload := <-payloads:
|
|
||||||
pending = append(pending, payload)
|
|
||||||
if len(pending) < maxPendingQueueSize {
|
|
||||||
wait = 0
|
|
||||||
} else {
|
|
||||||
wait = 100
|
|
||||||
}
|
|
||||||
read = time.After(wait * time.Millisecond)
|
|
||||||
case in <- head:
|
|
||||||
pending = pending[1:]
|
|
||||||
case errc := <-self.closingIn:
|
|
||||||
errc <- true
|
|
||||||
close(self.in)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) startWrite() {
|
|
||||||
pending := [][]byte{}
|
|
||||||
done := make(chan *PeerError)
|
|
||||||
writing := false
|
|
||||||
for {
|
|
||||||
if len(pending) > 0 && !writing {
|
|
||||||
writing = true
|
|
||||||
go self.write(pending[0], done)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case payload := <-self.out:
|
|
||||||
pending = append(pending, payload)
|
|
||||||
case err := <-done:
|
|
||||||
if err == nil {
|
|
||||||
pending = pending[1:]
|
|
||||||
writing = false
|
|
||||||
} else {
|
|
||||||
self.err <- err // report error
|
|
||||||
}
|
|
||||||
case errc := <-self.closingOut:
|
|
||||||
errc <- true
|
|
||||||
close(self.out)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func pack(payload []byte) (packet []byte) {
|
|
||||||
length := ethutil.NumberToBytes(uint32(len(payload)), 32)
|
|
||||||
// return error if too long?
|
|
||||||
// Write magic token and payload length (first 8 bytes)
|
|
||||||
packet = append(magicToken, length...)
|
|
||||||
packet = append(packet, payload...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func avoidPanic(done chan *PeerError) {
|
|
||||||
if rec := recover(); rec != nil {
|
|
||||||
err := NewPeerError(MiscError, " %v", rec)
|
|
||||||
logger.Debugln(err)
|
|
||||||
done <- err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) write(payload []byte, done chan *PeerError) {
|
|
||||||
defer avoidPanic(done)
|
|
||||||
var err *PeerError
|
|
||||||
_, ok := self.conn.Write(pack(payload))
|
|
||||||
if ok != nil {
|
|
||||||
err = NewPeerError(WriteError, " %v", ok)
|
|
||||||
logger.Debugln(err)
|
|
||||||
}
|
|
||||||
done <- err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
|
|
||||||
//defer avoidPanic(done)
|
|
||||||
|
|
||||||
partials := make(chan []byte, partialsQueueSize)
|
|
||||||
errc := make(chan *PeerError)
|
|
||||||
go self.readPartials(partials, errc)
|
|
||||||
|
|
||||||
packet := []byte{}
|
|
||||||
length := 8
|
|
||||||
start := true
|
|
||||||
var err *PeerError
|
|
||||||
out:
|
|
||||||
for {
|
|
||||||
// appends partials read via connection until packet is
|
|
||||||
// - either parseable (>=8bytes)
|
|
||||||
// - or complete (payload fully consumed)
|
|
||||||
for len(packet) < length {
|
|
||||||
partial, ok := <-partials
|
|
||||||
if !ok { // partials channel is closed
|
|
||||||
err = <-errc
|
|
||||||
if err == nil && len(packet) > 0 {
|
|
||||||
if start {
|
|
||||||
err = NewPeerError(PacketTooShort, "%v", packet)
|
|
||||||
} else {
|
|
||||||
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break out
|
|
||||||
}
|
|
||||||
packet = append(packet, partial...)
|
|
||||||
}
|
|
||||||
if start {
|
|
||||||
// at least 8 bytes read, can validate packet
|
|
||||||
if bytes.Compare(magicToken, packet[:4]) != 0 {
|
|
||||||
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
|
|
||||||
break
|
|
||||||
}
|
|
||||||
length = int(ethutil.BytesToNumber(packet[4:8]))
|
|
||||||
packet = packet[8:]
|
|
||||||
|
|
||||||
if length > 0 {
|
|
||||||
start = false // now consuming payload
|
|
||||||
} else { //penalize peer but read on
|
|
||||||
self.err <- NewPeerError(EmptyPayload, "")
|
|
||||||
length = 8
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// packet complete (payload fully consumed)
|
|
||||||
payloads <- packet[:length]
|
|
||||||
packet = packet[length:] // resclice packet
|
|
||||||
start = true
|
|
||||||
length = 8
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// this stops partials read via the connection, should we?
|
|
||||||
//if err != nil {
|
|
||||||
// select {
|
|
||||||
// case errc <- err
|
|
||||||
// default:
|
|
||||||
//}
|
|
||||||
done <- err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
|
|
||||||
defer close(partials)
|
|
||||||
for {
|
|
||||||
// Give buffering some time
|
|
||||||
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
|
|
||||||
buffer := make([]byte, readBufferLength)
|
|
||||||
// read partial from connection
|
|
||||||
bytesRead, err := self.conn.Read(buffer)
|
|
||||||
if err == nil || err.Error() == "EOF" {
|
|
||||||
if bytesRead > 0 {
|
|
||||||
partials <- buffer[:bytesRead]
|
|
||||||
}
|
|
||||||
if err != nil && err.Error() == "EOF" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// unexpected error, report to errc
|
|
||||||
err := NewPeerError(ReadError, " %v", err)
|
|
||||||
logger.Debugln(err)
|
|
||||||
errc <- err
|
|
||||||
return // will close partials channel
|
|
||||||
}
|
|
||||||
}
|
|
||||||
close(errc)
|
|
||||||
}
|
|
@ -1,222 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TestNetworkConnection struct {
|
|
||||||
in chan []byte
|
|
||||||
current []byte
|
|
||||||
Out [][]byte
|
|
||||||
addr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
|
|
||||||
return &TestNetworkConnection{
|
|
||||||
in: make(chan []byte),
|
|
||||||
current: []byte{},
|
|
||||||
Out: [][]byte{},
|
|
||||||
addr: addr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
|
|
||||||
time.Sleep(latency)
|
|
||||||
for _, s := range packets {
|
|
||||||
self.in <- s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
|
|
||||||
if len(self.current) == 0 {
|
|
||||||
select {
|
|
||||||
case self.current = <-self.in:
|
|
||||||
default:
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
length := len(self.current)
|
|
||||||
if length > len(buff) {
|
|
||||||
copy(buff[:], self.current[:len(buff)])
|
|
||||||
self.current = self.current[len(buff):]
|
|
||||||
return len(buff), nil
|
|
||||||
} else {
|
|
||||||
copy(buff[:length], self.current[:])
|
|
||||||
self.current = []byte{}
|
|
||||||
return length, io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
|
|
||||||
self.Out = append(self.Out, buff)
|
|
||||||
fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
|
|
||||||
return len(buff), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Close() (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
|
|
||||||
return self.addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupConnection() (*Connection, *TestNetworkConnection) {
|
|
||||||
addr := &TestAddr{"test:30303"}
|
|
||||||
net := NewTestNetworkConnection(addr)
|
|
||||||
conn := NewConnection(net, NewPeerErrorChannel())
|
|
||||||
conn.Open()
|
|
||||||
return conn, net
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingNilPacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{})
|
|
||||||
// time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
t.Errorf("incorrect error %v", err)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingShortPacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{0})
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
if err.Code != PacketTooShort {
|
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingInvalidPacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
if err.Code != MagicTokenMismatch {
|
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingInvalidPayload(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
if err.Code != PayloadTooShort {
|
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingEmptyPayload(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case err := <-conn.Error():
|
|
||||||
code := err.Code
|
|
||||||
if code != EmptyPayload {
|
|
||||||
t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
t.Errorf("no error, expected EmptyPayload")
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingCompletePacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
if bytes.Compare(packet, []byte{1}) != 0 {
|
|
||||||
t.Errorf("incorrect payload read")
|
|
||||||
}
|
|
||||||
case err := <-conn.Error():
|
|
||||||
t.Errorf("incorrect error %v", err)
|
|
||||||
default:
|
|
||||||
t.Errorf("nothing read")
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingTwoCompletePackets(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
|
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
if bytes.Compare(packet, []byte{byte(i)}) != 0 {
|
|
||||||
t.Errorf("incorrect payload read")
|
|
||||||
}
|
|
||||||
case err := <-conn.Error():
|
|
||||||
t.Errorf("incorrect error %v", err)
|
|
||||||
default:
|
|
||||||
t.Errorf("nothing read")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriting(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
conn.Write() <- []byte{0}
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
if len(net.Out) == 0 {
|
|
||||||
t.Errorf("no output")
|
|
||||||
} else {
|
|
||||||
out := net.Out[0]
|
|
||||||
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
|
|
||||||
t.Errorf("incorrect packet %v", out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
|
|
205
p2p/message.go
205
p2p/message.go
@ -1,75 +1,174 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "fmt"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/big"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MsgCode uint8
|
type MsgCode uint64
|
||||||
|
|
||||||
|
// Msg defines the structure of a p2p message.
|
||||||
|
//
|
||||||
|
// Note that a Msg can only be sent once since the Payload reader is
|
||||||
|
// consumed during sending. It is not possible to create a Msg and
|
||||||
|
// send it any number of times. If you want to reuse an encoded
|
||||||
|
// structure, encode the payload into a byte array and create a
|
||||||
|
// separate Msg with a bytes.Reader as Payload for each send.
|
||||||
type Msg struct {
|
type Msg struct {
|
||||||
code MsgCode // this is the raw code as per adaptive msg code scheme
|
Code MsgCode
|
||||||
data *ethutil.Value
|
Size uint32 // size of the paylod
|
||||||
encoded []byte
|
Payload io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Msg) Code() MsgCode {
|
// NewMsg creates an RLP-encoded message with the given code.
|
||||||
return self.code
|
func NewMsg(code MsgCode, params ...interface{}) Msg {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
for _, p := range params {
|
||||||
|
buf.Write(ethutil.Encode(p))
|
||||||
|
}
|
||||||
|
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Msg) Data() *ethutil.Value {
|
func encodePayload(params ...interface{}) []byte {
|
||||||
return self.data
|
buf := new(bytes.Buffer)
|
||||||
|
for _, p := range params {
|
||||||
|
buf.Write(ethutil.Encode(p))
|
||||||
|
}
|
||||||
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) {
|
// Data returns the decoded RLP payload items in a message.
|
||||||
|
func (msg Msg) Data() (*ethutil.Value, error) {
|
||||||
|
// TODO: avoid copying when we have a better RLP decoder
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
var s []interface{}
|
||||||
|
if _, err := buf.ReadFrom(msg.Payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for buf.Len() > 0 {
|
||||||
|
s = append(s, ethutil.DecodeWithReader(buf))
|
||||||
|
}
|
||||||
|
return ethutil.NewValue(s), nil
|
||||||
|
}
|
||||||
|
|
||||||
// // data := [][]interface{}{}
|
// Discard reads any remaining payload data into a black hole.
|
||||||
// data := []interface{}{}
|
func (msg Msg) Discard() error {
|
||||||
// for _, value := range params {
|
_, err := io.Copy(ioutil.Discard, msg.Payload)
|
||||||
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
|
return err
|
||||||
// data = append(data, encodable.RlpValue())
|
}
|
||||||
// } else if raw, ok := value.([]interface{}); ok {
|
|
||||||
// data = append(data, raw)
|
var magicToken = []byte{34, 64, 8, 145}
|
||||||
// } else {
|
|
||||||
// // data = append(data, interface{}(raw))
|
func writeMsg(w io.Writer, msg Msg) error {
|
||||||
// err = fmt.Errorf("Unable to encode object of type %T", value)
|
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
|
||||||
// return
|
code := ethutil.Encode(uint32(msg.Code))
|
||||||
// }
|
listhdr := makeListHeader(msg.Size + uint32(len(code)))
|
||||||
// }
|
payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size
|
||||||
return &Msg{
|
|
||||||
code: code,
|
start := make([]byte, 8)
|
||||||
data: ethutil.NewValue(interface{}(params)),
|
copy(start, magicToken)
|
||||||
|
binary.BigEndian.PutUint32(start[4:], payloadLen)
|
||||||
|
|
||||||
|
for _, b := range [][]byte{start, listhdr, code} {
|
||||||
|
if _, err := w.Write(b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err := io.CopyN(w, msg.Payload, int64(msg.Size))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeListHeader(length uint32) []byte {
|
||||||
|
if length < 56 {
|
||||||
|
return []byte{byte(length + 0xc0)}
|
||||||
|
}
|
||||||
|
enc := big.NewInt(int64(length)).Bytes()
|
||||||
|
lenb := byte(len(enc)) + 0xf7
|
||||||
|
return append([]byte{lenb}, enc...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type byteReader interface {
|
||||||
|
io.Reader
|
||||||
|
io.ByteReader
|
||||||
|
}
|
||||||
|
|
||||||
|
// readMsg reads a message header.
|
||||||
|
func readMsg(r byteReader) (msg Msg, err error) {
|
||||||
|
// read magic and payload size
|
||||||
|
start := make([]byte, 8)
|
||||||
|
if _, err = io.ReadFull(r, start); err != nil {
|
||||||
|
return msg, NewPeerError(ReadError, "%v", err)
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(start, magicToken) {
|
||||||
|
return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
|
||||||
|
}
|
||||||
|
size := binary.BigEndian.Uint32(start[4:])
|
||||||
|
|
||||||
|
// decode start of RLP message to get the message code
|
||||||
|
_, hdrlen, err := readListHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
code, codelen, err := readMsgCode(r)
|
||||||
|
if err != nil {
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rlpsize := size - hdrlen - codelen
|
||||||
|
return Msg{
|
||||||
|
Code: code,
|
||||||
|
Size: rlpsize,
|
||||||
|
Payload: io.LimitReader(r, int64(rlpsize)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) {
|
// readListHeader reads an RLP list header from r.
|
||||||
value := ethutil.NewValueFromBytes(encoded)
|
func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
|
||||||
// Type of message
|
b, err := r.ReadByte()
|
||||||
code := value.Get(0).Uint()
|
if err != nil {
|
||||||
// Actual data
|
return 0, 0, err
|
||||||
data := value.SliceFrom(1)
|
|
||||||
|
|
||||||
msg = &Msg{
|
|
||||||
code: MsgCode(code),
|
|
||||||
data: data,
|
|
||||||
// data: ethutil.NewValue(data),
|
|
||||||
encoded: encoded,
|
|
||||||
}
|
}
|
||||||
return
|
if b < 0xC0 {
|
||||||
}
|
return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b)
|
||||||
|
} else if b < 0xF7 {
|
||||||
func (self *Msg) Decode(offset MsgCode) {
|
len = uint64(b - 0xc0)
|
||||||
self.code = self.code - offset
|
hdrlen = 1
|
||||||
}
|
|
||||||
|
|
||||||
// encode takes an offset argument to implement adaptive message coding
|
|
||||||
// the encoded message is memoized to make msgs relayed to several peers more efficient
|
|
||||||
func (self *Msg) Encode(offset MsgCode) (res []byte) {
|
|
||||||
if len(self.encoded) == 0 {
|
|
||||||
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
|
|
||||||
self.encoded = res
|
|
||||||
} else {
|
} else {
|
||||||
res = self.encoded
|
lenlen := b - 0xF7
|
||||||
|
lenbuf := make([]byte, 8)
|
||||||
|
if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil {
|
||||||
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
return
|
len = binary.BigEndian.Uint64(lenbuf)
|
||||||
|
hdrlen = 1 + uint32(lenlen)
|
||||||
|
}
|
||||||
|
return len, hdrlen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readUint reads an RLP-encoded unsigned integer from r.
|
||||||
|
func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
|
||||||
|
b, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
if b < 0x80 {
|
||||||
|
return MsgCode(b), 1, nil
|
||||||
|
} else if b < 0x89 { // max length for uint64 is 8 bytes
|
||||||
|
codelen = uint32(b - 0x80)
|
||||||
|
if codelen == 0 {
|
||||||
|
return 0, 1, nil
|
||||||
|
}
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil
|
||||||
|
}
|
||||||
|
return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
|
||||||
}
|
}
|
||||||
|
@ -1,38 +1,67 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewMsg(t *testing.T) {
|
func TestNewMsg(t *testing.T) {
|
||||||
msg, _ := NewMsg(3, 1, "000")
|
msg := NewMsg(3, 1, "000")
|
||||||
if msg.Code() != 3 {
|
if msg.Code != 3 {
|
||||||
t.Errorf("incorrect code %v", msg.Code())
|
t.Errorf("incorrect code %d, want %d", msg.Code)
|
||||||
}
|
}
|
||||||
data0 := msg.Data().Get(0).Uint()
|
if msg.Size != 5 {
|
||||||
data1 := string(msg.Data().Get(1).Bytes())
|
t.Errorf("incorrect size %d, want %d", msg.Size, 5)
|
||||||
if data0 != 1 {
|
|
||||||
t.Errorf("incorrect data %v", data0)
|
|
||||||
}
|
}
|
||||||
if data1 != "000" {
|
pl, _ := ioutil.ReadAll(msg.Payload)
|
||||||
t.Errorf("incorrect data %v", data1)
|
expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
|
||||||
|
if !bytes.Equal(pl, expect) {
|
||||||
|
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeDecodeMsg(t *testing.T) {
|
func TestEncodeDecodeMsg(t *testing.T) {
|
||||||
msg, _ := NewMsg(3, 1, "000")
|
msg := NewMsg(3, 1, "000")
|
||||||
encoded := msg.Encode(3)
|
buf := new(bytes.Buffer)
|
||||||
msg, _ = NewMsgFromBytes(encoded)
|
if err := writeMsg(buf, msg); err != nil {
|
||||||
msg.Decode(3)
|
t.Fatalf("encodeMsg error: %v", err)
|
||||||
if msg.Code() != 3 {
|
|
||||||
t.Errorf("incorrect code %v", msg.Code())
|
|
||||||
}
|
}
|
||||||
data0 := msg.Data().Get(0).Uint()
|
|
||||||
data1 := msg.Data().Get(1).Str()
|
t.Logf("encoded: %x", buf.Bytes())
|
||||||
if data0 != 1 {
|
|
||||||
t.Errorf("incorrect data %v", data0)
|
decmsg, err := readMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readMsg error: %v", err)
|
||||||
}
|
}
|
||||||
if data1 != "000" {
|
if decmsg.Code != 3 {
|
||||||
t.Errorf("incorrect data %v", data1)
|
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
|
||||||
|
}
|
||||||
|
if decmsg.Size != 5 {
|
||||||
|
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
|
||||||
|
}
|
||||||
|
data, err := decmsg.Data()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first payload item decode error: %v", err)
|
||||||
|
}
|
||||||
|
if v := data.Get(0).Uint(); v != 1 {
|
||||||
|
t.Errorf("incorrect data[0]: got %v, expected %d", v, 1)
|
||||||
|
}
|
||||||
|
if v := data.Get(1).Str(); v != "000" {
|
||||||
|
t.Errorf("incorrect data[1]: got %q, expected %q", v, "000")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeRealMsg(t *testing.T) {
|
||||||
|
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
|
||||||
|
msg, err := readMsg(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Code != 0 {
|
||||||
|
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
331
p2p/messenger.go
331
p2p/messenger.go
@ -1,220 +1,221 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type Handlers map[string]func() Protocol
|
||||||
handlerTimeout = 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
type Handlers map[string](func(p *Peer) Protocol)
|
type proto struct {
|
||||||
|
in chan Msg
|
||||||
|
maxcode, offset MsgCode
|
||||||
|
messenger *messenger
|
||||||
|
}
|
||||||
|
|
||||||
type Messenger struct {
|
func (rw *proto) WriteMsg(msg Msg) error {
|
||||||
conn *Connection
|
if msg.Code >= rw.maxcode {
|
||||||
|
return NewPeerError(InvalidMsgCode, "not handled")
|
||||||
|
}
|
||||||
|
return rw.messenger.writeMsg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *proto) ReadMsg() (Msg, error) {
|
||||||
|
msg, ok := <-rw.in
|
||||||
|
if !ok {
|
||||||
|
return msg, io.EOF
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// eofSignal is used to 'lend' the network connection
|
||||||
|
// to a protocol. when the protocol's read loop has read the
|
||||||
|
// whole payload, the done channel is closed.
|
||||||
|
type eofSignal struct {
|
||||||
|
wrapped io.Reader
|
||||||
|
eof chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *eofSignal) Read(buf []byte) (int, error) {
|
||||||
|
n, err := r.wrapped.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
close(r.eof) // tell messenger that msg has been consumed
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// messenger represents a message-oriented peer connection.
|
||||||
|
// It keeps track of the set of protocols understood
|
||||||
|
// by the remote peer.
|
||||||
|
type messenger struct {
|
||||||
peer *Peer
|
peer *Peer
|
||||||
handlers Handlers
|
handlers Handlers
|
||||||
|
|
||||||
|
// the mutex protects the connection
|
||||||
|
// so only one protocol can write at a time.
|
||||||
|
writeMu sync.Mutex
|
||||||
|
conn net.Conn
|
||||||
|
bufconn *bufio.ReadWriter
|
||||||
|
|
||||||
protocolLock sync.RWMutex
|
protocolLock sync.RWMutex
|
||||||
protocols []Protocol
|
protocols map[string]*proto
|
||||||
offsets []MsgCode // offsets for adaptive message idss
|
offsets map[MsgCode]*proto
|
||||||
protocolTable map[string]int
|
protoWG sync.WaitGroup
|
||||||
quit chan chan bool
|
|
||||||
err chan *PeerError
|
err chan error
|
||||||
pulse chan bool
|
pulse chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger {
|
func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
|
||||||
baseProtocol := NewBaseProtocol(peer)
|
return &messenger{
|
||||||
return &Messenger{
|
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
|
||||||
peer: peer,
|
peer: peer,
|
||||||
offsets: []MsgCode{baseProtocol.Offset()},
|
|
||||||
handlers: handlers,
|
handlers: handlers,
|
||||||
protocols: []Protocol{baseProtocol},
|
protocols: make(map[string]*proto),
|
||||||
protocolTable: make(map[string]int),
|
|
||||||
err: errchan,
|
err: errchan,
|
||||||
pulse: make(chan bool, 1),
|
pulse: make(chan bool, 1),
|
||||||
quit: make(chan chan bool, 1),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Messenger) Start() {
|
func (m *messenger) Start() {
|
||||||
self.conn.Open()
|
m.protocols[""] = m.startProto(0, "", &baseProtocol{})
|
||||||
go self.messenger()
|
go m.readLoop()
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
self.protocols[0].Start()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Messenger) Stop() {
|
func (m *messenger) Stop() {
|
||||||
// close pulse to stop ping pong monitoring
|
m.conn.Close()
|
||||||
close(self.pulse)
|
m.protoWG.Wait()
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
for _, protocol := range self.protocols {
|
|
||||||
protocol.Stop() // could be parallel
|
|
||||||
}
|
|
||||||
q := make(chan bool)
|
|
||||||
self.quit <- q
|
|
||||||
<-q
|
|
||||||
self.conn.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Messenger) messenger() {
|
const (
|
||||||
in := self.conn.Read()
|
// maximum amount of time allowed for reading a message
|
||||||
|
msgReadTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// messages smaller than this many bytes will be read at
|
||||||
|
// once before passing them to a protocol.
|
||||||
|
wholePayloadSize = 64 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *messenger) readLoop() {
|
||||||
|
defer m.closeProtocols()
|
||||||
for {
|
for {
|
||||||
select {
|
m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
|
||||||
case payload, ok := <-in:
|
msg, err := readMsg(m.bufconn)
|
||||||
//dispatches message to the protocol asynchronously
|
if err != nil {
|
||||||
if ok {
|
m.err <- err
|
||||||
go self.handle(payload)
|
|
||||||
} else {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case q := <-self.quit:
|
|
||||||
q <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handles each message by dispatching to the appropriate protocol
|
|
||||||
// using adaptive message codes
|
|
||||||
// this function is started as a separate go routine for each message
|
|
||||||
// it waits for the protocol response
|
|
||||||
// then encodes and sends outgoing messages to the connection's write channel
|
|
||||||
func (self *Messenger) handle(payload []byte) {
|
|
||||||
// send ping to heartbeat channel signalling time of last message
|
// send ping to heartbeat channel signalling time of last message
|
||||||
// select {
|
m.pulse <- true
|
||||||
// case self.pulse <- true:
|
proto, err := m.getProto(msg.Code)
|
||||||
// default:
|
|
||||||
// }
|
|
||||||
self.pulse <- true
|
|
||||||
// initialise message from payload
|
|
||||||
msg, err := NewMsgFromBytes(payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
self.err <- NewPeerError(MiscError, " %v", err)
|
m.err <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// retrieves protocol based on message Code
|
msg.Code -= proto.offset
|
||||||
protocol, offset, peerErr := self.getProtocol(msg.Code())
|
if msg.Size <= wholePayloadSize {
|
||||||
|
// optimization: msg is small enough, read all
|
||||||
|
// of it and move on to the next message
|
||||||
|
buf, err := ioutil.ReadAll(msg.Payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
self.err <- peerErr
|
m.err <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// reset message code based on adaptive offset
|
msg.Payload = bytes.NewReader(buf)
|
||||||
msg.Decode(offset)
|
proto.in <- msg
|
||||||
// dispatches
|
|
||||||
response := make(chan *Msg)
|
|
||||||
go protocol.HandleIn(msg, response)
|
|
||||||
// protocol reponse timeout to prevent leaks
|
|
||||||
timer := time.After(handlerTimeout * time.Millisecond)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case outgoing, ok := <-response:
|
|
||||||
// we check if response channel is not closed
|
|
||||||
if ok {
|
|
||||||
self.conn.Write() <- outgoing.Encode(offset)
|
|
||||||
} else {
|
} else {
|
||||||
return
|
pr := &eofSignal{msg.Payload, make(chan struct{})}
|
||||||
}
|
msg.Payload = pr
|
||||||
case <-timer:
|
proto.in <- msg
|
||||||
return
|
<-pr.eof
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// negotiated protocols
|
func (m *messenger) closeProtocols() {
|
||||||
// stores offsets needed for adaptive message id scheme
|
m.protocolLock.RLock()
|
||||||
|
for _, p := range m.protocols {
|
||||||
// based on offsets set at handshake
|
close(p.in)
|
||||||
// get the right protocol to handle the message
|
|
||||||
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) {
|
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
base := MsgCode(0)
|
|
||||||
for index, offset := range self.offsets {
|
|
||||||
if code < offset {
|
|
||||||
return self.protocols[index], base, nil
|
|
||||||
}
|
}
|
||||||
base = offset
|
m.protocolLock.RUnlock()
|
||||||
}
|
|
||||||
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) {
|
func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
|
||||||
fmt.Printf("pingpong keepalive started at %v", time.Now())
|
proto := &proto{
|
||||||
|
in: make(chan Msg),
|
||||||
timer := time.After(timeout)
|
offset: offset,
|
||||||
pinged := false
|
maxcode: impl.Offset(),
|
||||||
for {
|
messenger: m,
|
||||||
select {
|
|
||||||
case _, ok := <-self.pulse:
|
|
||||||
if ok {
|
|
||||||
pinged = false
|
|
||||||
timer = time.After(timeout)
|
|
||||||
} else {
|
|
||||||
// pulse is closed, stop monitoring
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-timer:
|
|
||||||
if pinged {
|
|
||||||
fmt.Printf("timeout at %v", time.Now())
|
|
||||||
timeoutCallback()
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
fmt.Printf("pinged at %v", time.Now())
|
|
||||||
pingCallback()
|
|
||||||
timer = time.After(gracePeriod)
|
|
||||||
pinged = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
m.protoWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
|
||||||
|
logger.Errorf("protocol %q error: %v\n", name, err)
|
||||||
|
m.err <- err
|
||||||
}
|
}
|
||||||
|
m.protoWG.Done()
|
||||||
|
}()
|
||||||
|
return proto
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Messenger) AddProtocols(protocols []string) {
|
// getProto finds the protocol responsible for handling
|
||||||
self.protocolLock.Lock()
|
// the given message code.
|
||||||
defer self.protocolLock.Unlock()
|
func (m *messenger) getProto(code MsgCode) (*proto, error) {
|
||||||
i := len(self.offsets)
|
m.protocolLock.RLock()
|
||||||
offset := self.offsets[i-1]
|
defer m.protocolLock.RUnlock()
|
||||||
|
for _, proto := range m.protocols {
|
||||||
|
if code >= proto.offset && code < proto.offset+proto.maxcode {
|
||||||
|
return proto, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, NewPeerError(InvalidMsgCode, "%d", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setProtocols starts all subprotocols shared with the
|
||||||
|
// remote peer. the protocols must be sorted alphabetically.
|
||||||
|
func (m *messenger) setRemoteProtocols(protocols []string) {
|
||||||
|
m.protocolLock.Lock()
|
||||||
|
defer m.protocolLock.Unlock()
|
||||||
|
offset := baseProtocolOffset
|
||||||
for _, name := range protocols {
|
for _, name := range protocols {
|
||||||
protocolFunc, ok := self.handlers[name]
|
protocolFunc, ok := m.handlers[name]
|
||||||
if ok {
|
|
||||||
protocol := protocolFunc(self.peer)
|
|
||||||
self.protocolTable[name] = i
|
|
||||||
i++
|
|
||||||
offset += protocol.Offset()
|
|
||||||
fmt.Println("offset ", name, offset)
|
|
||||||
|
|
||||||
self.offsets = append(self.offsets, offset)
|
|
||||||
self.protocols = append(self.protocols, protocol)
|
|
||||||
protocol.Start()
|
|
||||||
} else {
|
|
||||||
fmt.Println("no ", name)
|
|
||||||
// protocol not handled
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Messenger) Write(protocol string, msg *Msg) error {
|
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
i := 0
|
|
||||||
offset := MsgCode(0)
|
|
||||||
if len(protocol) > 0 {
|
|
||||||
var ok bool
|
|
||||||
i, ok = self.protocolTable[protocol]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("protocol %v not handled by peer", protocol)
|
continue // not handled
|
||||||
}
|
}
|
||||||
offset = self.offsets[i-1]
|
inst := protocolFunc()
|
||||||
|
m.protocols[name] = m.startProto(offset, name, inst)
|
||||||
|
offset += inst.Offset()
|
||||||
}
|
}
|
||||||
handler := self.protocols[i]
|
}
|
||||||
// checking if protocol status/caps allows the message to be sent out
|
|
||||||
if handler.HandleOut(msg) {
|
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||||
self.conn.Write() <- msg.Encode(offset)
|
func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
|
||||||
}
|
m.protocolLock.RLock()
|
||||||
return nil
|
proto, ok := m.protocols[protoName]
|
||||||
|
m.protocolLock.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
||||||
|
}
|
||||||
|
if msg.Code >= proto.maxcode {
|
||||||
|
return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
||||||
|
}
|
||||||
|
msg.Code += proto.offset
|
||||||
|
return m.writeMsg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeMsg writes a message to the connection.
|
||||||
|
func (m *messenger) writeMsg(msg Msg) error {
|
||||||
|
m.writeMu.Lock()
|
||||||
|
defer m.writeMu.Unlock()
|
||||||
|
if err := writeMsg(m.bufconn, msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return m.bufconn.Flush()
|
||||||
}
|
}
|
||||||
|
@ -1,147 +1,157 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "fmt"
|
"bufio"
|
||||||
"bytes"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
|
func init() {
|
||||||
errchan := NewPeerErrorChannel()
|
ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel))
|
||||||
addr := &TestAddr{"test:30303"}
|
|
||||||
net := NewTestNetworkConnection(addr)
|
|
||||||
conn := NewConnection(net, errchan)
|
|
||||||
mess := NewMessenger(nil, conn, errchan, handlers)
|
|
||||||
mess.Start()
|
|
||||||
return net, errchan, mess
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TestProtocol struct {
|
func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
|
||||||
Msgs []*Msg
|
conn1, conn2 := net.Pipe()
|
||||||
|
id := NewSimpleClientIdentity("test", "0", "0", "public key")
|
||||||
|
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
|
||||||
|
peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
|
||||||
|
return conn2, peer, peer.messenger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *TestProtocol) Start() {
|
func performTestHandshake(r *bufio.Reader, w io.Writer) error {
|
||||||
}
|
// read remote handshake
|
||||||
|
msg, err := readMsg(r)
|
||||||
func (self *TestProtocol) Stop() {
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("read error: %v", err)
|
||||||
|
|
||||||
func (self *TestProtocol) Offset() MsgCode {
|
|
||||||
return MsgCode(5)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
|
|
||||||
self.Msgs = append(self.Msgs, msg)
|
|
||||||
close(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) HandleOut(msg *Msg) bool {
|
|
||||||
if msg.Code() > 3 {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
if msg.Code != handshakeMsg {
|
||||||
|
return fmt.Errorf("first message should be handshake, got %x", msg.Code)
|
||||||
|
}
|
||||||
|
if err := msg.Discard(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// send empty handshake
|
||||||
|
pubkey := make([]byte, 64)
|
||||||
|
msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
|
||||||
|
return writeMsg(w, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *TestProtocol) Name() string {
|
type testMsg struct {
|
||||||
return "a"
|
code MsgCode
|
||||||
|
data *ethutil.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
|
type testProto struct {
|
||||||
msg, _ := NewMsg(code, params...)
|
recv chan testMsg
|
||||||
encoded := msg.Encode(offset)
|
}
|
||||||
packet := []byte{34, 64, 8, 145}
|
|
||||||
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
|
func (*testProto) Offset() MsgCode { return 5 }
|
||||||
return append(packet, encoded...)
|
|
||||||
|
func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error {
|
||||||
|
return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error {
|
||||||
|
logger.Debugf("testprotocol got msg: %d\n", code)
|
||||||
|
tp.recv <- testMsg{code, data}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRead(t *testing.T) {
|
func TestRead(t *testing.T) {
|
||||||
handlers := make(Handlers)
|
testProtocol := &testProto{make(chan testMsg)}
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
handlers := Handlers{"a": func() Protocol { return testProtocol }}
|
||||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
|
net, peer, mess := setupMessenger(handlers)
|
||||||
net, _, mess := setupMessenger(handlers)
|
bufr := bufio.NewReader(net)
|
||||||
mess.AddProtocols([]string{"a"})
|
defer peer.Stop()
|
||||||
defer mess.Stop()
|
if err := performTestHandshake(bufr, net); err != nil {
|
||||||
wait := 1 * time.Millisecond
|
t.Fatalf("handshake failed: %v", err)
|
||||||
packet := Packet(16, 1, uint32(1), "000")
|
|
||||||
go net.In(0, packet)
|
|
||||||
time.Sleep(wait)
|
|
||||||
if len(testProtocol.Msgs) != 1 {
|
|
||||||
t.Errorf("msg not relayed to correct protocol")
|
|
||||||
} else {
|
|
||||||
if testProtocol.Msgs[0].Code() != 1 {
|
|
||||||
t.Errorf("incorrect msg code relayed to protocol")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mess.setRemoteProtocols([]string{"a"})
|
||||||
|
writeMsg(net, NewMsg(17, uint32(1), "000"))
|
||||||
|
select {
|
||||||
|
case msg := <-testProtocol.recv:
|
||||||
|
if msg.code != 1 {
|
||||||
|
t.Errorf("incorrect msg code %d relayed to protocol", msg.code)
|
||||||
|
}
|
||||||
|
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
|
||||||
|
if !reflect.DeepEqual(msg.data.Slice(), expdata) {
|
||||||
|
t.Errorf("incorrect msg data %#v", msg.data.Slice())
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Errorf("receive timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrite(t *testing.T) {
|
func TestWriteProtoMsg(t *testing.T) {
|
||||||
handlers := make(Handlers)
|
handlers := make(Handlers)
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
testProtocol := &testProto{recv: make(chan testMsg, 1)}
|
||||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
|
handlers["a"] = func() Protocol { return testProtocol }
|
||||||
net, _, mess := setupMessenger(handlers)
|
net, peer, mess := setupMessenger(handlers)
|
||||||
mess.AddProtocols([]string{"a"})
|
defer peer.Stop()
|
||||||
defer mess.Stop()
|
bufr := bufio.NewReader(net)
|
||||||
wait := 1 * time.Millisecond
|
if err := performTestHandshake(bufr, net); err != nil {
|
||||||
msg, _ := NewMsg(3, uint32(1), "000")
|
t.Fatalf("handshake failed: %v", err)
|
||||||
err := mess.Write("b", msg)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expect error for unknown protocol")
|
|
||||||
}
|
}
|
||||||
err = mess.Write("a", msg)
|
mess.setRemoteProtocols([]string{"a"})
|
||||||
if err != nil {
|
|
||||||
|
// test write errors
|
||||||
|
if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
|
||||||
|
t.Errorf("expected error for unknown protocol, got nil")
|
||||||
|
}
|
||||||
|
if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
|
||||||
|
t.Errorf("expected error for out-of-range msg code, got nil")
|
||||||
|
} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
|
||||||
|
t.Errorf("wrong error for out-of-range msg code, got %#v")
|
||||||
|
}
|
||||||
|
|
||||||
|
// test succcessful write
|
||||||
|
read, readerr := make(chan Msg), make(chan error)
|
||||||
|
go func() {
|
||||||
|
if msg, err := readMsg(bufr); err != nil {
|
||||||
|
readerr <- err
|
||||||
|
} else {
|
||||||
|
read <- msg
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil {
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
} else {
|
|
||||||
time.Sleep(wait)
|
|
||||||
if len(net.Out) != 1 {
|
|
||||||
t.Errorf("msg not written")
|
|
||||||
} else {
|
|
||||||
out := net.Out[0]
|
|
||||||
packet := Packet(16, 3, uint32(1), "000")
|
|
||||||
if bytes.Compare(out, packet) != 0 {
|
|
||||||
t.Errorf("incorrect packet %v", out)
|
|
||||||
}
|
}
|
||||||
|
select {
|
||||||
|
case msg := <-read:
|
||||||
|
if msg.Code != 19 {
|
||||||
|
t.Errorf("wrong code, got %d, expected %d", msg.Code, 19)
|
||||||
}
|
}
|
||||||
|
msg.Discard()
|
||||||
|
case err := <-readerr:
|
||||||
|
t.Errorf("read error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPulse(t *testing.T) {
|
func TestPulse(t *testing.T) {
|
||||||
net, _, mess := setupMessenger(make(Handlers))
|
net, peer, _ := setupMessenger(nil)
|
||||||
defer mess.Stop()
|
defer peer.Stop()
|
||||||
ping := false
|
bufr := bufio.NewReader(net)
|
||||||
timeout := false
|
if err := performTestHandshake(bufr, net); err != nil {
|
||||||
pingTimeout := 10 * time.Millisecond
|
t.Fatalf("handshake failed: %v", err)
|
||||||
gracePeriod := 200 * time.Millisecond
|
|
||||||
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
|
|
||||||
net.In(0, Packet(0, 1))
|
|
||||||
if ping {
|
|
||||||
t.Errorf("ping sent too early")
|
|
||||||
}
|
}
|
||||||
time.Sleep(pingTimeout + 100*time.Millisecond)
|
|
||||||
if !ping {
|
before := time.Now()
|
||||||
t.Errorf("no ping sent after timeout")
|
msg, err := readMsg(bufr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read error: %v", err)
|
||||||
}
|
}
|
||||||
if timeout {
|
after := time.Now()
|
||||||
t.Errorf("timeout too early")
|
if msg.Code != pingMsg {
|
||||||
|
t.Errorf("expected ping message, got %x", msg.Code)
|
||||||
}
|
}
|
||||||
ping = false
|
if d := after.Sub(before); d < pingTimeout {
|
||||||
net.In(0, Packet(0, 1))
|
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
|
||||||
time.Sleep(pingTimeout + 100*time.Millisecond)
|
|
||||||
if !ping {
|
|
||||||
t.Errorf("no ping sent after timeout")
|
|
||||||
}
|
|
||||||
if timeout {
|
|
||||||
t.Errorf("timeout too early")
|
|
||||||
}
|
|
||||||
ping = false
|
|
||||||
time.Sleep(gracePeriod)
|
|
||||||
if ping {
|
|
||||||
t.Errorf("ping called twice")
|
|
||||||
}
|
|
||||||
if !timeout {
|
|
||||||
t.Errorf("no timeout after grace period")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
29
p2p/peer.go
29
p2p/peer.go
@ -7,7 +7,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
// quit chan chan bool
|
|
||||||
Inbound bool // inbound (via listener) or outbound (via dialout)
|
Inbound bool // inbound (via listener) or outbound (via dialout)
|
||||||
Address net.Addr
|
Address net.Addr
|
||||||
Host []byte
|
Host []byte
|
||||||
@ -15,24 +14,12 @@ type Peer struct {
|
|||||||
Pubkey []byte
|
Pubkey []byte
|
||||||
Id string
|
Id string
|
||||||
Caps []string
|
Caps []string
|
||||||
peerErrorChan chan *PeerError
|
peerErrorChan chan error
|
||||||
messenger *Messenger
|
messenger *messenger
|
||||||
peerErrorHandler *PeerErrorHandler
|
peerErrorHandler *PeerErrorHandler
|
||||||
server *Server
|
server *Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) Messenger() *Messenger {
|
|
||||||
return self.messenger
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Peer) PeerErrorChan() chan *PeerError {
|
|
||||||
return self.peerErrorChan
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Peer) Server() *Server {
|
|
||||||
return self.server
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
|
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
|
||||||
peerErrorChan := NewPeerErrorChannel()
|
peerErrorChan := NewPeerErrorChannel()
|
||||||
host, port, _ := net.SplitHostPort(address.String())
|
host, port, _ := net.SplitHostPort(address.String())
|
||||||
@ -45,9 +32,8 @@ func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Pee
|
|||||||
peerErrorChan: peerErrorChan,
|
peerErrorChan: peerErrorChan,
|
||||||
server: server,
|
server: server,
|
||||||
}
|
}
|
||||||
connection := NewConnection(conn, peerErrorChan)
|
peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers())
|
||||||
peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers())
|
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan)
|
||||||
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist())
|
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,8 +47,8 @@ func (self *Peer) String() string {
|
|||||||
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
|
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) Write(protocol string, msg *Msg) error {
|
func (self *Peer) Write(protocol string, msg Msg) error {
|
||||||
return self.messenger.Write(protocol, msg)
|
return self.messenger.writeProtoMsg(protocol, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) Start() {
|
func (self *Peer) Start() {
|
||||||
@ -73,9 +59,6 @@ func (self *Peer) Start() {
|
|||||||
func (self *Peer) Stop() {
|
func (self *Peer) Stop() {
|
||||||
self.peerErrorHandler.Stop()
|
self.peerErrorHandler.Stop()
|
||||||
self.messenger.Stop()
|
self.messenger.Stop()
|
||||||
// q := make(chan bool)
|
|
||||||
// self.quit <- q
|
|
||||||
// <-q
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) Encode() []interface{} {
|
func (p *Peer) Encode() []interface{} {
|
||||||
|
@ -9,10 +9,9 @@ type ErrorCode int
|
|||||||
const errorChanCapacity = 10
|
const errorChanCapacity = 10
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PacketTooShort = iota
|
PacketTooLong = iota
|
||||||
PayloadTooShort
|
PayloadTooShort
|
||||||
MagicTokenMismatch
|
MagicTokenMismatch
|
||||||
EmptyPayload
|
|
||||||
ReadError
|
ReadError
|
||||||
WriteError
|
WriteError
|
||||||
MiscError
|
MiscError
|
||||||
@ -31,10 +30,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var errorToString = map[ErrorCode]string{
|
var errorToString = map[ErrorCode]string{
|
||||||
PacketTooShort: "Packet too short",
|
PacketTooLong: "Packet too long",
|
||||||
PayloadTooShort: "Payload too short",
|
PayloadTooShort: "Payload too short",
|
||||||
MagicTokenMismatch: "Magic token mismatch",
|
MagicTokenMismatch: "Magic token mismatch",
|
||||||
EmptyPayload: "Empty payload",
|
|
||||||
ReadError: "Read error",
|
ReadError: "Read error",
|
||||||
WriteError: "Write error",
|
WriteError: "Write error",
|
||||||
MiscError: "Misc error",
|
MiscError: "Misc error",
|
||||||
@ -71,6 +69,6 @@ func (self *PeerError) Error() string {
|
|||||||
return self.message
|
return self.message
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPeerErrorChannel() chan *PeerError {
|
func NewPeerErrorChannel() chan error {
|
||||||
return make(chan *PeerError, errorChanCapacity)
|
return make(chan error, errorChanCapacity)
|
||||||
}
|
}
|
||||||
|
@ -18,17 +18,15 @@ type PeerErrorHandler struct {
|
|||||||
address net.Addr
|
address net.Addr
|
||||||
peerDisconnect chan DisconnectRequest
|
peerDisconnect chan DisconnectRequest
|
||||||
severity int
|
severity int
|
||||||
peerErrorChan chan *PeerError
|
errc chan error
|
||||||
blacklist Blacklist
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
|
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
|
||||||
return &PeerErrorHandler{
|
return &PeerErrorHandler{
|
||||||
quit: make(chan chan bool),
|
quit: make(chan chan bool),
|
||||||
address: address,
|
address: address,
|
||||||
peerDisconnect: peerDisconnect,
|
peerDisconnect: peerDisconnect,
|
||||||
peerErrorChan: peerErrorChan,
|
errc: errc,
|
||||||
blacklist: blacklist,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,10 +43,10 @@ func (self *PeerErrorHandler) Stop() {
|
|||||||
func (self *PeerErrorHandler) listen() {
|
func (self *PeerErrorHandler) listen() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case peerError, ok := <-self.peerErrorChan:
|
case err, ok := <-self.errc:
|
||||||
if ok {
|
if ok {
|
||||||
logger.Debugf("error %v\n", peerError)
|
logger.Debugf("error %v\n", err)
|
||||||
go self.handle(peerError)
|
go self.handle(err)
|
||||||
} else {
|
} else {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -59,8 +57,12 @@ func (self *PeerErrorHandler) listen() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *PeerErrorHandler) handle(peerError *PeerError) {
|
func (self *PeerErrorHandler) handle(err error) {
|
||||||
reason := DiscReason(' ')
|
reason := DiscReason(' ')
|
||||||
|
peerError, ok := err.(*PeerError)
|
||||||
|
if !ok {
|
||||||
|
peerError = NewPeerError(MiscError, " %v", err)
|
||||||
|
}
|
||||||
switch peerError.Code {
|
switch peerError.Code {
|
||||||
case P2PVersionMismatch:
|
case P2PVersionMismatch:
|
||||||
reason = DiscIncompatibleVersion
|
reason = DiscIncompatibleVersion
|
||||||
@ -68,11 +70,11 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) {
|
|||||||
reason = DiscInvalidIdentity
|
reason = DiscInvalidIdentity
|
||||||
case PubkeyForbidden:
|
case PubkeyForbidden:
|
||||||
reason = DiscUselessPeer
|
reason = DiscUselessPeer
|
||||||
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
|
case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
|
||||||
reason = DiscProtocolError
|
reason = DiscProtocolError
|
||||||
case PingTimeout:
|
case PingTimeout:
|
||||||
reason = DiscReadTimeout
|
reason = DiscReadTimeout
|
||||||
case WriteError, MiscError:
|
case ReadError, WriteError, MiscError:
|
||||||
reason = DiscNetworkError
|
reason = DiscNetworkError
|
||||||
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
|
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
|
||||||
reason = DiscSubprotocolError
|
reason = DiscSubprotocolError
|
||||||
@ -92,10 +94,5 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
|
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
|
||||||
switch peerError.Code {
|
|
||||||
case ReadError:
|
|
||||||
return 4 //tolerate 3 :)
|
|
||||||
default:
|
|
||||||
return 1
|
return 1
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ func TestPeerErrorHandler(t *testing.T) {
|
|||||||
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
|
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
|
||||||
peerDisconnect := make(chan DisconnectRequest)
|
peerDisconnect := make(chan DisconnectRequest)
|
||||||
peerErrorChan := NewPeerErrorChannel()
|
peerErrorChan := NewPeerErrorChannel()
|
||||||
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
|
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
|
||||||
peh.Start()
|
peh.Start()
|
||||||
defer peh.Stop()
|
defer peh.Stop()
|
||||||
for i := 0; i < 11; i++ {
|
for i := 0; i < 11; i++ {
|
||||||
|
170
p2p/peer_test.go
170
p2p/peer_test.go
@ -1,96 +1,90 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
// "net"
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
// "net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPeer(t *testing.T) {
|
// func TestPeer(t *testing.T) {
|
||||||
handlers := make(Handlers)
|
// handlers := make(Handlers)
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
// testProtocol := &TestProtocol{recv: make(chan testMsg)}
|
||||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
|
// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
addr := &TestAddr{"test:30"}
|
// addr := &TestAddr{"test:30"}
|
||||||
conn := NewTestNetworkConnection(addr)
|
// conn := NewTestNetworkConnection(addr)
|
||||||
_, server := SetupTestServer(handlers)
|
// _, server := SetupTestServer(handlers)
|
||||||
server.Handshake()
|
// server.Handshake()
|
||||||
peer := NewPeer(conn, addr, true, server)
|
// peer := NewPeer(conn, addr, true, server)
|
||||||
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
||||||
peer.Start()
|
// peer.Start()
|
||||||
defer peer.Stop()
|
// defer peer.Stop()
|
||||||
time.Sleep(2 * time.Millisecond)
|
// time.Sleep(2 * time.Millisecond)
|
||||||
if len(conn.Out) != 1 {
|
// if len(conn.Out) != 1 {
|
||||||
t.Errorf("handshake not sent")
|
// t.Errorf("handshake not sent")
|
||||||
} else {
|
// } else {
|
||||||
out := conn.Out[0]
|
// out := conn.Out[0]
|
||||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
|
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
|
||||||
if bytes.Compare(out, packet) != 0 {
|
// if bytes.Compare(out, packet) != 0 {
|
||||||
t.Errorf("incorrect handshake packet %v != %v", out, packet)
|
// t.Errorf("incorrect handshake packet %v != %v", out, packet)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
|
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
|
||||||
conn.In(0, packet)
|
// conn.In(0, packet)
|
||||||
time.Sleep(10 * time.Millisecond)
|
// time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
|
// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
|
||||||
if pro.state != handshakeReceived {
|
// if pro.state != handshakeReceived {
|
||||||
t.Errorf("handshake not received")
|
// t.Errorf("handshake not received")
|
||||||
}
|
// }
|
||||||
if peer.Port != 30 {
|
// if peer.Port != 30 {
|
||||||
t.Errorf("port incorrectly set")
|
// t.Errorf("port incorrectly set")
|
||||||
}
|
// }
|
||||||
if peer.Id != "peer" {
|
// if peer.Id != "peer" {
|
||||||
t.Errorf("id incorrectly set")
|
// t.Errorf("id incorrectly set")
|
||||||
}
|
// }
|
||||||
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
|
// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||||
t.Errorf("pubkey incorrectly set")
|
// t.Errorf("pubkey incorrectly set")
|
||||||
}
|
// }
|
||||||
fmt.Println(peer.Caps)
|
// fmt.Println(peer.Caps)
|
||||||
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
|
// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
|
||||||
t.Errorf("protocols incorrectly set")
|
// t.Errorf("protocols incorrectly set")
|
||||||
}
|
// }
|
||||||
|
|
||||||
msg, _ := NewMsg(3)
|
// msg := NewMsg(3)
|
||||||
err := peer.Write("aaa", msg)
|
// err := peer.Write("aaa", msg)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
// t.Errorf("expect no error for known protocol: %v", err)
|
||||||
} else {
|
// } else {
|
||||||
time.Sleep(1 * time.Millisecond)
|
// time.Sleep(1 * time.Millisecond)
|
||||||
if len(conn.Out) != 2 {
|
// if len(conn.Out) != 2 {
|
||||||
t.Errorf("msg not written")
|
// t.Errorf("msg not written")
|
||||||
} else {
|
// } else {
|
||||||
out := conn.Out[1]
|
// out := conn.Out[1]
|
||||||
packet := Packet(16, 3)
|
// packet := Packet(16, 3)
|
||||||
if bytes.Compare(out, packet) != 0 {
|
// if bytes.Compare(out, packet) != 0 {
|
||||||
t.Errorf("incorrect packet %v != %v", out, packet)
|
// t.Errorf("incorrect packet %v != %v", out, packet)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
msg, _ = NewMsg(2)
|
// msg = NewMsg(2)
|
||||||
err = peer.Write("ccc", msg)
|
// err = peer.Write("ccc", msg)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
// t.Errorf("expect no error for known protocol: %v", err)
|
||||||
} else {
|
// } else {
|
||||||
time.Sleep(1 * time.Millisecond)
|
// time.Sleep(1 * time.Millisecond)
|
||||||
if len(conn.Out) != 3 {
|
// if len(conn.Out) != 3 {
|
||||||
t.Errorf("msg not written")
|
// t.Errorf("msg not written")
|
||||||
} else {
|
// } else {
|
||||||
out := conn.Out[2]
|
// out := conn.Out[2]
|
||||||
packet := Packet(21, 2)
|
// packet := Packet(21, 2)
|
||||||
if bytes.Compare(out, packet) != 0 {
|
// if bytes.Compare(out, packet) != 0 {
|
||||||
t.Errorf("incorrect packet %v != %v", out, packet)
|
// t.Errorf("incorrect packet %v != %v", out, packet)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
err = peer.Write("bbb", msg)
|
// err = peer.Write("bbb", msg)
|
||||||
time.Sleep(1 * time.Millisecond)
|
// time.Sleep(1 * time.Millisecond)
|
||||||
if err == nil {
|
// if err == nil {
|
||||||
t.Errorf("expect error for unknown protocol")
|
// t.Errorf("expect error for unknown protocol")
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
353
p2p/protocol.go
353
p2p/protocol.go
@ -2,43 +2,101 @@ package p2p
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Protocol is implemented by P2P subprotocols.
|
||||||
type Protocol interface {
|
type Protocol interface {
|
||||||
Start()
|
// Start is called when the protocol becomes active.
|
||||||
Stop()
|
// It should read and write messages from rw.
|
||||||
HandleIn(*Msg, chan *Msg)
|
// Messages must be fully consumed.
|
||||||
HandleOut(*Msg) bool
|
//
|
||||||
|
// The connection is closed when Start returns. It should return
|
||||||
|
// any protocol-level error (such as an I/O error) that is
|
||||||
|
// encountered.
|
||||||
|
Start(peer *Peer, rw MsgReadWriter) error
|
||||||
|
|
||||||
|
// Offset should return the number of message codes
|
||||||
|
// used by the protocol.
|
||||||
Offset() MsgCode
|
Offset() MsgCode
|
||||||
Name() string
|
}
|
||||||
|
|
||||||
|
type MsgReader interface {
|
||||||
|
ReadMsg() (Msg, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MsgWriter interface {
|
||||||
|
WriteMsg(Msg) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MsgReadWriter is passed to protocols. Protocol implementations can
|
||||||
|
// use it to write messages back to a connected peer.
|
||||||
|
type MsgReadWriter interface {
|
||||||
|
MsgReader
|
||||||
|
MsgWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type MsgHandler func(code MsgCode, data *ethutil.Value) error
|
||||||
|
|
||||||
|
// MsgLoop reads messages off the given reader and
|
||||||
|
// calls the handler function for each decoded message until
|
||||||
|
// it returns an error or the peer connection is closed.
|
||||||
|
//
|
||||||
|
// If a message is larger than the given maximum size, RunProtocol
|
||||||
|
// returns an appropriate error.n
|
||||||
|
func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error {
|
||||||
|
for {
|
||||||
|
msg, err := r.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Size > maxsize {
|
||||||
|
return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
|
||||||
|
}
|
||||||
|
value, err := msg.Data()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := handler(msg.Code, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// the ÐΞVp2p base protocol
|
||||||
|
type baseProtocol struct {
|
||||||
|
rw MsgReadWriter
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
type bpMsg struct {
|
||||||
|
code MsgCode
|
||||||
|
data *ethutil.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
P2PVersion = 0
|
p2pVersion = 0
|
||||||
pingTimeout = 2
|
pingTimeout = 2 * time.Second
|
||||||
pingGracePeriod = 2
|
pingGracePeriod = 2 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HandshakeMsg = iota
|
// message codes
|
||||||
DiscMsg
|
handshakeMsg = iota
|
||||||
PingMsg
|
discMsg
|
||||||
PongMsg
|
pingMsg
|
||||||
GetPeersMsg
|
pongMsg
|
||||||
PeersMsg
|
getPeersMsg
|
||||||
offset = 16
|
peersMsg
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProtocolState uint8
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
nullState = iota
|
baseProtocolOffset MsgCode = 16
|
||||||
handshakeReceived
|
baseProtocolMaxMsgSize = 500 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
type DiscReason byte
|
type DiscReason byte
|
||||||
@ -62,7 +120,7 @@ const (
|
|||||||
DiscSubprotocolError = 0x10
|
DiscSubprotocolError = 0x10
|
||||||
)
|
)
|
||||||
|
|
||||||
var discReasonToString = map[DiscReason]string{
|
var discReasonToString = [DiscSubprotocolError + 1]string{
|
||||||
DiscRequested: "Disconnect requested",
|
DiscRequested: "Disconnect requested",
|
||||||
DiscNetworkError: "Network error",
|
DiscNetworkError: "Network error",
|
||||||
DiscProtocolError: "Breach of protocol",
|
DiscProtocolError: "Breach of protocol",
|
||||||
@ -82,197 +140,178 @@ func (d DiscReason) String() string {
|
|||||||
if len(discReasonToString) < int(d) {
|
if len(discReasonToString) < int(d) {
|
||||||
return "Unknown"
|
return "Unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
return discReasonToString[d]
|
return discReasonToString[d]
|
||||||
}
|
}
|
||||||
|
|
||||||
type BaseProtocol struct {
|
func (bp *baseProtocol) Ping() {
|
||||||
peer *Peer
|
|
||||||
state ProtocolState
|
|
||||||
stateLock sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBaseProtocol(peer *Peer) *BaseProtocol {
|
func (bp *baseProtocol) Offset() MsgCode {
|
||||||
self := &BaseProtocol{
|
return baseProtocolOffset
|
||||||
peer: peer,
|
}
|
||||||
|
|
||||||
|
func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error {
|
||||||
|
bp.peer, bp.rw = peer, rw
|
||||||
|
|
||||||
|
// Do the handshake.
|
||||||
|
// TODO: disconnect is valid before handshake, too.
|
||||||
|
rw.WriteMsg(bp.peer.server.handshakeMsg())
|
||||||
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Code != handshakeMsg {
|
||||||
|
return NewPeerError(ProtocolBreach, " first message must be handshake")
|
||||||
|
}
|
||||||
|
data, err := msg.Data()
|
||||||
|
if err != nil {
|
||||||
|
return NewPeerError(InvalidMsg, "%v", err)
|
||||||
|
}
|
||||||
|
if err := bp.handleHandshake(data); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return self
|
msgin := make(chan bpMsg)
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- MsgLoop(rw, baseProtocolMaxMsgSize,
|
||||||
|
func(code MsgCode, data *ethutil.Value) error {
|
||||||
|
msgin <- bpMsg{code, data}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
return bp.loop(msgin, done)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *BaseProtocol) Start() {
|
func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error {
|
||||||
if self.peer != nil {
|
logger.Debugf("pingpong keepalive started at %v\n", time.Now())
|
||||||
self.peer.Write("", self.peer.Server().Handshake())
|
messenger := bp.rw.(*proto).messenger
|
||||||
go self.peer.Messenger().PingPong(
|
pingTimer := time.NewTimer(pingTimeout)
|
||||||
pingTimeout*time.Second,
|
pinged := true
|
||||||
pingGracePeriod*time.Second,
|
|
||||||
self.Ping,
|
for {
|
||||||
self.Timeout,
|
select {
|
||||||
)
|
case msg := <-msgin:
|
||||||
|
if err := bp.handle(msg.code, msg.data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case err := <-quit:
|
||||||
|
return err
|
||||||
|
case <-messenger.pulse:
|
||||||
|
pingTimer.Reset(pingTimeout)
|
||||||
|
pinged = false
|
||||||
|
case <-pingTimer.C:
|
||||||
|
if pinged {
|
||||||
|
return NewPeerError(PingTimeout, "")
|
||||||
|
}
|
||||||
|
logger.Debugf("pinging at %v\n", time.Now())
|
||||||
|
if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil {
|
||||||
|
return NewPeerError(WriteError, "%v", err)
|
||||||
|
}
|
||||||
|
pinged = true
|
||||||
|
pingTimer.Reset(pingTimeout)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *BaseProtocol) Stop() {
|
func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
|
||||||
}
|
switch code {
|
||||||
|
case handshakeMsg:
|
||||||
|
return NewPeerError(ProtocolBreach, " extra handshake received")
|
||||||
|
|
||||||
func (self *BaseProtocol) Ping() {
|
case discMsg:
|
||||||
msg, _ := NewMsg(PingMsg)
|
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint()))
|
||||||
self.peer.Write("", msg)
|
bp.peer.server.PeerDisconnect() <- DisconnectRequest{
|
||||||
}
|
addr: bp.peer.Address,
|
||||||
|
|
||||||
func (self *BaseProtocol) Timeout() {
|
|
||||||
self.peerError(PingTimeout, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Name() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Offset() MsgCode {
|
|
||||||
return offset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) CheckState(state ProtocolState) bool {
|
|
||||||
self.stateLock.RLock()
|
|
||||||
self.stateLock.RUnlock()
|
|
||||||
if self.state != state {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) {
|
|
||||||
if msg.Code() == HandshakeMsg {
|
|
||||||
self.handleHandshake(msg)
|
|
||||||
} else {
|
|
||||||
if !self.CheckState(handshakeReceived) {
|
|
||||||
self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code())
|
|
||||||
close(response)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
switch msg.Code() {
|
|
||||||
case DiscMsg:
|
|
||||||
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint()))
|
|
||||||
self.peer.Server().PeerDisconnect() <- DisconnectRequest{
|
|
||||||
addr: self.peer.Address,
|
|
||||||
reason: DiscRequested,
|
reason: DiscRequested,
|
||||||
}
|
}
|
||||||
case PingMsg:
|
|
||||||
out, _ := NewMsg(PongMsg)
|
case pingMsg:
|
||||||
response <- out
|
return bp.rw.WriteMsg(NewMsg(pongMsg))
|
||||||
case PongMsg:
|
|
||||||
case GetPeersMsg:
|
case pongMsg:
|
||||||
// Peer asked for list of connected peers
|
// reply for ping
|
||||||
if out, err := self.peer.Server().PeersMessage(); err != nil {
|
|
||||||
response <- out
|
case getPeersMsg:
|
||||||
|
// Peer asked for list of connected peers.
|
||||||
|
peersRLP := bp.peer.server.encodedPeerList()
|
||||||
|
if peersRLP != nil {
|
||||||
|
msg := Msg{
|
||||||
|
Code: peersMsg,
|
||||||
|
Size: uint32(len(peersRLP)),
|
||||||
|
Payload: bytes.NewReader(peersRLP),
|
||||||
}
|
}
|
||||||
case PeersMsg:
|
return bp.rw.WriteMsg(msg)
|
||||||
self.handlePeers(msg)
|
}
|
||||||
|
|
||||||
|
case peersMsg:
|
||||||
|
bp.handlePeers(data)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code())
|
return NewPeerError(InvalidMsgCode, "unknown message code %v", code)
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
close(response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) {
|
func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
|
||||||
// somewhat overly paranoid
|
it := data.NewIterator()
|
||||||
allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) {
|
|
||||||
err := NewPeerError(errorCode, format, v...)
|
|
||||||
logger.Warnln(err)
|
|
||||||
fmt.Println(self.peer, err)
|
|
||||||
if self.peer != nil {
|
|
||||||
self.peer.PeerErrorChan() <- err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) handlePeers(msg *Msg) {
|
|
||||||
it := msg.Data().NewIterator()
|
|
||||||
for it.Next() {
|
for it.Next() {
|
||||||
ip := net.IP(it.Value().Get(0).Bytes())
|
ip := net.IP(it.Value().Get(0).Bytes())
|
||||||
port := it.Value().Get(1).Uint()
|
port := it.Value().Get(1).Uint()
|
||||||
address := &net.TCPAddr{IP: ip, Port: int(port)}
|
address := &net.TCPAddr{IP: ip, Port: int(port)}
|
||||||
go self.peer.Server().PeerConnect(address)
|
go bp.peer.server.PeerConnect(address)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *BaseProtocol) handleHandshake(msg *Msg) {
|
func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
|
||||||
self.stateLock.Lock()
|
|
||||||
defer self.stateLock.Unlock()
|
|
||||||
if self.state != nullState {
|
|
||||||
self.peerError(ProtocolBreach, "extra handshake")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c := msg.Data()
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
p2pVersion = c.Get(0).Uint()
|
remoteVersion = c.Get(0).Uint()
|
||||||
id = c.Get(1).Str()
|
id = c.Get(1).Str()
|
||||||
caps = c.Get(2)
|
caps = c.Get(2)
|
||||||
port = c.Get(3).Uint()
|
port = c.Get(3).Uint()
|
||||||
pubkey = c.Get(4).Bytes()
|
pubkey = c.Get(4).Bytes()
|
||||||
)
|
)
|
||||||
fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey)
|
|
||||||
|
|
||||||
// Check correctness of p2p protocol version
|
// Check correctness of p2p protocol version
|
||||||
if p2pVersion != P2PVersion {
|
if remoteVersion != p2pVersion {
|
||||||
self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion)
|
return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle the pub key (validation, uniqueness)
|
// Handle the pub key (validation, uniqueness)
|
||||||
if len(pubkey) == 0 {
|
if len(pubkey) == 0 {
|
||||||
self.peerError(PubkeyMissing, "not supplied in handshake.")
|
return NewPeerError(PubkeyMissing, "not supplied in handshake.")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(pubkey) != 64 {
|
if len(pubkey) != 64 {
|
||||||
self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
|
return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Self connect detection
|
// self connect detection
|
||||||
if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 {
|
if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
|
||||||
self.peerError(PubkeyForbidden, "not allowed to connect to self")
|
return NewPeerError(PubkeyForbidden, "not allowed to connect to bp")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// register pubkey on server. this also sets the pubkey on the peer (need lock)
|
// register pubkey on server. this also sets the pubkey on the peer (need lock)
|
||||||
if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil {
|
if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil {
|
||||||
self.peerError(PubkeyForbidden, err.Error())
|
return NewPeerError(PubkeyForbidden, err.Error())
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check port
|
// check port
|
||||||
if self.peer.Inbound {
|
if bp.peer.Inbound {
|
||||||
uint16port := uint16(port)
|
uint16port := uint16(port)
|
||||||
if self.peer.Port > 0 && self.peer.Port != uint16port {
|
if bp.peer.Port > 0 && bp.peer.Port != uint16port {
|
||||||
self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port)
|
return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port)
|
||||||
return
|
|
||||||
} else {
|
} else {
|
||||||
self.peer.Port = uint16port
|
bp.peer.Port = uint16port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
capsIt := caps.NewIterator()
|
capsIt := caps.NewIterator()
|
||||||
for capsIt.Next() {
|
for capsIt.Next() {
|
||||||
cap := capsIt.Value().Str()
|
cap := capsIt.Value().Str()
|
||||||
self.peer.Caps = append(self.peer.Caps, cap)
|
bp.peer.Caps = append(bp.peer.Caps, cap)
|
||||||
}
|
}
|
||||||
sort.Strings(self.peer.Caps)
|
sort.Strings(bp.peer.Caps)
|
||||||
self.peer.Messenger().AddProtocols(self.peer.Caps)
|
bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps)
|
||||||
|
bp.peer.Id = id
|
||||||
self.peer.Id = id
|
return nil
|
||||||
|
|
||||||
self.state = handshakeReceived
|
|
||||||
|
|
||||||
//p.ethereum.PushPeer(p)
|
|
||||||
// p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
128
p2p/server.go
128
p2p/server.go
@ -84,8 +84,8 @@ type Server struct {
|
|||||||
peers []*Peer
|
peers []*Peer
|
||||||
peerSlots chan int
|
peerSlots chan int
|
||||||
peersTable map[string]int
|
peersTable map[string]int
|
||||||
peersMsg *Msg
|
|
||||||
peerCount int
|
peerCount int
|
||||||
|
cachedEncodedPeers []byte
|
||||||
|
|
||||||
peerConnect chan net.Addr
|
peerConnect chan net.Addr
|
||||||
peerDisconnect chan DisconnectRequest
|
peerDisconnect chan DisconnectRequest
|
||||||
@ -147,27 +147,6 @@ func (self *Server) ClientIdentity() ClientIdentity {
|
|||||||
return self.identity
|
return self.identity
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Server) PeersMessage() (msg *Msg, err error) {
|
|
||||||
// TODO: memoize and reset when peers change
|
|
||||||
self.peersLock.RLock()
|
|
||||||
defer self.peersLock.RUnlock()
|
|
||||||
msg = self.peersMsg
|
|
||||||
if msg == nil {
|
|
||||||
var peerData []interface{}
|
|
||||||
for _, i := range self.peersTable {
|
|
||||||
peer := self.peers[i]
|
|
||||||
peerData = append(peerData, peer.Encode())
|
|
||||||
}
|
|
||||||
if len(peerData) == 0 {
|
|
||||||
err = fmt.Errorf("no peers")
|
|
||||||
} else {
|
|
||||||
msg, err = NewMsg(PeersMsg, peerData...)
|
|
||||||
self.peersMsg = msg //memoize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) Peers() (peers []*Peer) {
|
func (self *Server) Peers() (peers []*Peer) {
|
||||||
self.peersLock.RLock()
|
self.peersLock.RLock()
|
||||||
defer self.peersLock.RUnlock()
|
defer self.peersLock.RUnlock()
|
||||||
@ -185,8 +164,6 @@ func (self *Server) PeerCount() int {
|
|||||||
return self.peerCount
|
return self.peerCount
|
||||||
}
|
}
|
||||||
|
|
||||||
var getPeersMsg, _ = NewMsg(GetPeersMsg)
|
|
||||||
|
|
||||||
func (self *Server) PeerConnect(addr net.Addr) {
|
func (self *Server) PeerConnect(addr net.Addr) {
|
||||||
// TODO: should buffer, filter and uniq
|
// TODO: should buffer, filter and uniq
|
||||||
// send GetPeersMsg if not blocking
|
// send GetPeersMsg if not blocking
|
||||||
@ -209,12 +186,21 @@ func (self *Server) Handlers() Handlers {
|
|||||||
return self.handlers
|
return self.handlers
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Server) Broadcast(protocol string, msg *Msg) {
|
func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) {
|
||||||
|
var payload []byte
|
||||||
|
if data != nil {
|
||||||
|
payload = encodePayload(data...)
|
||||||
|
}
|
||||||
self.peersLock.RLock()
|
self.peersLock.RLock()
|
||||||
defer self.peersLock.RUnlock()
|
defer self.peersLock.RUnlock()
|
||||||
for _, peer := range self.peers {
|
for _, peer := range self.peers {
|
||||||
if peer != nil {
|
if peer != nil {
|
||||||
peer.Write(protocol, msg)
|
var msg = Msg{Code: code}
|
||||||
|
if data != nil {
|
||||||
|
msg.Payload = bytes.NewReader(payload)
|
||||||
|
msg.Size = uint32(len(payload))
|
||||||
|
}
|
||||||
|
peer.messenger.writeProtoMsg(protocol, msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -296,7 +282,7 @@ FOR:
|
|||||||
select {
|
select {
|
||||||
case slot := <-self.peerSlots:
|
case slot := <-self.peerSlots:
|
||||||
i++
|
i++
|
||||||
fmt.Printf("%v: found slot %v", i, slot)
|
fmt.Printf("%v: found slot %v\n", i, slot)
|
||||||
if i == self.maxPeers {
|
if i == self.maxPeers {
|
||||||
break FOR
|
break FOR
|
||||||
}
|
}
|
||||||
@ -358,70 +344,68 @@ func (self *Server) outboundPeerHandler(dialer Dialer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if peer address already connected
|
// check if peer address already connected
|
||||||
func (self *Server) connected(address net.Addr) (err error) {
|
func (self *Server) isConnected(address net.Addr) bool {
|
||||||
self.peersLock.RLock()
|
self.peersLock.RLock()
|
||||||
defer self.peersLock.RUnlock()
|
defer self.peersLock.RUnlock()
|
||||||
// fmt.Printf("address: %v\n", address)
|
_, found := self.peersTable[address.String()]
|
||||||
slot, found := self.peersTable[address.String()]
|
return found
|
||||||
if found {
|
|
||||||
err = fmt.Errorf("already connected as peer %v (%v)", slot, address)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to peer via listener.Accept()
|
// connect to peer via listener.Accept()
|
||||||
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
|
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
|
||||||
var address net.Addr
|
var address net.Addr
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err == nil {
|
|
||||||
address = conn.RemoteAddr()
|
|
||||||
err = self.connected(address)
|
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debugln(err)
|
logger.Debugln(err)
|
||||||
self.peerSlots <- slot
|
self.peerSlots <- slot
|
||||||
} else {
|
return
|
||||||
|
}
|
||||||
|
address = conn.RemoteAddr()
|
||||||
|
// XXX: this won't work because the remote socket
|
||||||
|
// address does not identify the peer. we should
|
||||||
|
// probably get rid of this check and rely on public
|
||||||
|
// key detection in the base protocol.
|
||||||
|
if self.isConnected(address) {
|
||||||
|
conn.Close()
|
||||||
|
self.peerSlots <- slot
|
||||||
|
return
|
||||||
|
}
|
||||||
fmt.Printf("adding %v\n", address)
|
fmt.Printf("adding %v\n", address)
|
||||||
go self.addPeer(conn, address, true, slot)
|
go self.addPeer(conn, address, true, slot)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to peer via dial out
|
// connect to peer via dial out
|
||||||
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
|
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
|
||||||
var conn net.Conn
|
if self.isConnected(address) {
|
||||||
err := self.connected(address)
|
return
|
||||||
if err == nil {
|
|
||||||
conn, err = dialer.Dial(address.Network(), address.String())
|
|
||||||
}
|
}
|
||||||
|
conn, err := dialer.Dial(address.Network(), address.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debugln(err)
|
|
||||||
self.peerSlots <- slot
|
self.peerSlots <- slot
|
||||||
} else {
|
return
|
||||||
go self.addPeer(conn, address, false, slot)
|
|
||||||
}
|
}
|
||||||
|
go self.addPeer(conn, address, false, slot)
|
||||||
}
|
}
|
||||||
|
|
||||||
// creates the new peer object and inserts it into its slot
|
// creates the new peer object and inserts it into its slot
|
||||||
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) {
|
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer {
|
||||||
self.peersLock.Lock()
|
self.peersLock.Lock()
|
||||||
defer self.peersLock.Unlock()
|
defer self.peersLock.Unlock()
|
||||||
if self.closed {
|
if self.closed {
|
||||||
fmt.Println("oopsy, not no longer need peer")
|
fmt.Println("oopsy, not no longer need peer")
|
||||||
conn.Close() //oopsy our bad
|
conn.Close() //oopsy our bad
|
||||||
self.peerSlots <- slot // release slot
|
self.peerSlots <- slot // release slot
|
||||||
} else {
|
return nil
|
||||||
|
}
|
||||||
|
logger.Infoln("adding new peer", address)
|
||||||
peer := NewPeer(conn, address, inbound, self)
|
peer := NewPeer(conn, address, inbound, self)
|
||||||
self.peers[slot] = peer
|
self.peers[slot] = peer
|
||||||
self.peersTable[address.String()] = slot
|
self.peersTable[address.String()] = slot
|
||||||
self.peerCount++
|
self.peerCount++
|
||||||
// reset peersmsg
|
self.cachedEncodedPeers = nil
|
||||||
self.peersMsg = nil
|
|
||||||
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
|
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
|
||||||
peer.Start()
|
peer.Start()
|
||||||
}
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
||||||
@ -441,13 +425,12 @@ func (self *Server) removePeer(request DisconnectRequest) {
|
|||||||
self.peerCount--
|
self.peerCount--
|
||||||
self.peers[slot] = nil
|
self.peers[slot] = nil
|
||||||
delete(self.peersTable, address.String())
|
delete(self.peersTable, address.String())
|
||||||
// reset peersmsg
|
self.cachedEncodedPeers = nil
|
||||||
self.peersMsg = nil
|
|
||||||
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
|
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
|
||||||
self.peersLock.Unlock()
|
self.peersLock.Unlock()
|
||||||
|
|
||||||
// sending disconnect message
|
// sending disconnect message
|
||||||
disconnectMsg, _ := NewMsg(DiscMsg, request.reason)
|
disconnectMsg := NewMsg(discMsg, request.reason)
|
||||||
peer.Write("", disconnectMsg)
|
peer.Write("", disconnectMsg)
|
||||||
// be nice and wait
|
// be nice and wait
|
||||||
time.Sleep(disconnectGracePeriod * time.Second)
|
time.Sleep(disconnectGracePeriod * time.Second)
|
||||||
@ -459,11 +442,32 @@ func (self *Server) removePeer(request DisconnectRequest) {
|
|||||||
self.peerSlots <- slot
|
self.peerSlots <- slot
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// encodedPeerList returns an RLP-encoded list of peers.
|
||||||
|
// the returned slice will be nil if there are no peers.
|
||||||
|
func (self *Server) encodedPeerList() []byte {
|
||||||
|
// TODO: memoize and reset when peers change
|
||||||
|
self.peersLock.RLock()
|
||||||
|
defer self.peersLock.RUnlock()
|
||||||
|
if self.cachedEncodedPeers == nil && self.peerCount > 0 {
|
||||||
|
var peerData []interface{}
|
||||||
|
for _, i := range self.peersTable {
|
||||||
|
peer := self.peers[i]
|
||||||
|
peerData = append(peerData, peer.Encode())
|
||||||
|
}
|
||||||
|
self.cachedEncodedPeers = encodePayload(peerData)
|
||||||
|
}
|
||||||
|
return self.cachedEncodedPeers
|
||||||
|
}
|
||||||
|
|
||||||
// fix handshake message to push to peers
|
// fix handshake message to push to peers
|
||||||
func (self *Server) Handshake() *Msg {
|
func (self *Server) handshakeMsg() Msg {
|
||||||
fmt.Println(self.identity.Pubkey()[1:])
|
return NewMsg(handshakeMsg,
|
||||||
msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:])
|
p2pVersion,
|
||||||
return msg
|
[]byte(self.identity.String()),
|
||||||
|
[]interface{}{self.protocols},
|
||||||
|
self.port,
|
||||||
|
self.identity.Pubkey()[1:],
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
|
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -32,6 +32,7 @@ func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
|
|||||||
connections: self.connections,
|
connections: self.connections,
|
||||||
addr: addr,
|
addr: addr,
|
||||||
max: self.maxinbound,
|
max: self.maxinbound,
|
||||||
|
close: make(chan struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,24 +77,25 @@ type TestListener struct {
|
|||||||
addr net.Addr
|
addr net.Addr
|
||||||
max int
|
max int
|
||||||
i int
|
i int
|
||||||
|
close chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *TestListener) Accept() (conn net.Conn, err error) {
|
func (self *TestListener) Accept() (net.Conn, error) {
|
||||||
self.i++
|
self.i++
|
||||||
if self.i > self.max {
|
if self.i > self.max {
|
||||||
err = fmt.Errorf("no more")
|
<-self.close
|
||||||
} else {
|
return nil, io.EOF
|
||||||
|
}
|
||||||
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
|
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
|
||||||
tconn := NewTestNetworkConnection(addr)
|
tconn := NewTestNetworkConnection(addr)
|
||||||
key := tconn.RemoteAddr().String()
|
key := tconn.RemoteAddr().String()
|
||||||
self.connections[key] = tconn
|
self.connections[key] = tconn
|
||||||
conn = net.Conn(tconn)
|
|
||||||
fmt.Printf("accepted connection from: %v \n", addr)
|
fmt.Printf("accepted connection from: %v \n", addr)
|
||||||
}
|
return tconn, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *TestListener) Close() error {
|
func (self *TestListener) Close() error {
|
||||||
|
close(self.close)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -101,6 +103,86 @@ func (self *TestListener) Addr() net.Addr {
|
|||||||
return self.addr
|
return self.addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TestNetworkConnection struct {
|
||||||
|
in chan []byte
|
||||||
|
close chan struct{}
|
||||||
|
current []byte
|
||||||
|
Out [][]byte
|
||||||
|
addr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
|
||||||
|
return &TestNetworkConnection{
|
||||||
|
in: make(chan []byte),
|
||||||
|
close: make(chan struct{}),
|
||||||
|
current: []byte{},
|
||||||
|
Out: [][]byte{},
|
||||||
|
addr: addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
|
||||||
|
time.Sleep(latency)
|
||||||
|
for _, s := range packets {
|
||||||
|
self.in <- s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
|
||||||
|
if len(self.current) == 0 {
|
||||||
|
var ok bool
|
||||||
|
select {
|
||||||
|
case self.current, ok = <-self.in:
|
||||||
|
if !ok {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
case <-self.close:
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
length := len(self.current)
|
||||||
|
if length > len(buff) {
|
||||||
|
copy(buff[:], self.current[:len(buff)])
|
||||||
|
self.current = self.current[len(buff):]
|
||||||
|
return len(buff), nil
|
||||||
|
} else {
|
||||||
|
copy(buff[:length], self.current[:])
|
||||||
|
self.current = []byte{}
|
||||||
|
return length, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
|
||||||
|
self.Out = append(self.Out, buff)
|
||||||
|
fmt.Printf("net write(%d): %x\n", len(self.Out), buff)
|
||||||
|
return len(buff), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) Close() error {
|
||||||
|
close(self.close)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
|
||||||
|
return self.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
|
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
|
||||||
network = NewTestNetwork(1)
|
network = NewTestNetwork(1)
|
||||||
addr := &TestAddr{"test:30303"}
|
addr := &TestAddr{"test:30303"}
|
||||||
@ -124,12 +206,10 @@ func TestServerListener(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Error("not found inbound peer 1")
|
t.Error("not found inbound peer 1")
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("out: %v\n", peer1.Out)
|
|
||||||
if len(peer1.Out) != 2 {
|
if len(peer1.Out) != 2 {
|
||||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerDialer(t *testing.T) {
|
func TestServerDialer(t *testing.T) {
|
||||||
@ -142,65 +222,63 @@ func TestServerDialer(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Error("not found outbound peer 1")
|
t.Error("not found outbound peer 1")
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("out: %v\n", peer1.Out)
|
|
||||||
if len(peer1.Out) != 2 {
|
if len(peer1.Out) != 2 {
|
||||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerBroadcast(t *testing.T) {
|
// func TestServerBroadcast(t *testing.T) {
|
||||||
handlers := make(Handlers)
|
// handlers := make(Handlers)
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
// testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
||||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
network, server := SetupTestServer(handlers)
|
// network, server := SetupTestServer(handlers)
|
||||||
server.Start(true, true)
|
// server.Start(true, true)
|
||||||
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
// server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
||||||
time.Sleep(10 * time.Millisecond)
|
// time.Sleep(10 * time.Millisecond)
|
||||||
msg, _ := NewMsg(0)
|
// msg := NewMsg(0)
|
||||||
server.Broadcast("", msg)
|
// server.Broadcast("", msg)
|
||||||
packet := Packet(0, 0)
|
// packet := Packet(0, 0)
|
||||||
time.Sleep(10 * time.Millisecond)
|
// time.Sleep(10 * time.Millisecond)
|
||||||
server.Stop()
|
// server.Stop()
|
||||||
peer1, ok := network.connections["outboundpeer-1"]
|
// peer1, ok := network.connections["outboundpeer-1"]
|
||||||
if !ok {
|
// if !ok {
|
||||||
t.Error("not found outbound peer 1")
|
// t.Error("not found outbound peer 1")
|
||||||
} else {
|
// } else {
|
||||||
fmt.Printf("out: %v\n", peer1.Out)
|
// fmt.Printf("out: %v\n", peer1.Out)
|
||||||
if len(peer1.Out) != 3 {
|
// if len(peer1.Out) != 3 {
|
||||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
||||||
} else {
|
// } else {
|
||||||
if bytes.Compare(peer1.Out[1], packet) != 0 {
|
// if bytes.Compare(peer1.Out[1], packet) != 0 {
|
||||||
t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
|
// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
peer2, ok := network.connections["inboundpeer-1"]
|
// peer2, ok := network.connections["inboundpeer-1"]
|
||||||
if !ok {
|
// if !ok {
|
||||||
t.Error("not found inbound peer 2")
|
// t.Error("not found inbound peer 2")
|
||||||
} else {
|
// } else {
|
||||||
fmt.Printf("out: %v\n", peer2.Out)
|
// fmt.Printf("out: %v\n", peer2.Out)
|
||||||
if len(peer1.Out) != 3 {
|
// if len(peer1.Out) != 3 {
|
||||||
t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
|
// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
|
||||||
} else {
|
// } else {
|
||||||
if bytes.Compare(peer2.Out[1], packet) != 0 {
|
// if bytes.Compare(peer2.Out[1], packet) != 0 {
|
||||||
t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
|
// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func TestServerPeersMessage(t *testing.T) {
|
func TestServerPeersMessage(t *testing.T) {
|
||||||
handlers := make(Handlers)
|
_, server := SetupTestServer(nil)
|
||||||
_, server := SetupTestServer(handlers)
|
|
||||||
server.Start(true, true)
|
server.Start(true, true)
|
||||||
defer server.Stop()
|
defer server.Stop()
|
||||||
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(2000 * time.Millisecond)
|
||||||
peersMsg, err := server.PeersMessage()
|
|
||||||
fmt.Println(peersMsg)
|
pl := server.encodedPeerList()
|
||||||
if err != nil {
|
if pl == nil {
|
||||||
t.Errorf("expect no error, got %v", err)
|
t.Errorf("expect non-nil peer list")
|
||||||
}
|
}
|
||||||
if c := server.PeerCount(); c != 2 {
|
if c := server.PeerCount(); c != 2 {
|
||||||
t.Errorf("expect 2 peers, got %v", c)
|
t.Errorf("expect 2 peers, got %v", c)
|
||||||
|
Loading…
Reference in New Issue
Block a user