forked from cerc-io/plugeth
Merge branch 'feature/p2p-protocol-interface' of https://github.com/fjl/go-ethereum into fjl-feature/p2p-protocol-interface
This commit is contained in:
commit
384b8c75f0
@ -5,10 +5,10 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc.
|
// ClientIdentity represents the identity of a peer.
|
||||||
type ClientIdentity interface {
|
type ClientIdentity interface {
|
||||||
String() string
|
String() string // human readable identity
|
||||||
Pubkey() []byte
|
Pubkey() []byte // 512-bit public key
|
||||||
}
|
}
|
||||||
|
|
||||||
type SimpleClientIdentity struct {
|
type SimpleClientIdentity struct {
|
||||||
|
@ -1,275 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
// "fmt"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Connection struct {
|
|
||||||
conn net.Conn
|
|
||||||
// conn NetworkConnection
|
|
||||||
timeout time.Duration
|
|
||||||
in chan []byte
|
|
||||||
out chan []byte
|
|
||||||
err chan *PeerError
|
|
||||||
closingIn chan chan bool
|
|
||||||
closingOut chan chan bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// const readBufferLength = 2 //for testing
|
|
||||||
|
|
||||||
const readBufferLength = 1440
|
|
||||||
const partialsQueueSize = 10
|
|
||||||
const maxPendingQueueSize = 1
|
|
||||||
const defaultTimeout = 500
|
|
||||||
|
|
||||||
var magicToken = []byte{34, 64, 8, 145}
|
|
||||||
|
|
||||||
func (self *Connection) Open() {
|
|
||||||
go self.startRead()
|
|
||||||
go self.startWrite()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Close() {
|
|
||||||
self.closeIn()
|
|
||||||
self.closeOut()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) closeIn() {
|
|
||||||
errc := make(chan bool)
|
|
||||||
self.closingIn <- errc
|
|
||||||
<-errc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) closeOut() {
|
|
||||||
errc := make(chan bool)
|
|
||||||
self.closingOut <- errc
|
|
||||||
<-errc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
|
|
||||||
return &Connection{
|
|
||||||
conn: conn,
|
|
||||||
timeout: defaultTimeout,
|
|
||||||
in: make(chan []byte),
|
|
||||||
out: make(chan []byte),
|
|
||||||
err: errchan,
|
|
||||||
closingIn: make(chan chan bool, 1),
|
|
||||||
closingOut: make(chan chan bool, 1),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Read() <-chan []byte {
|
|
||||||
return self.in
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Write() chan<- []byte {
|
|
||||||
return self.out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) Error() <-chan *PeerError {
|
|
||||||
return self.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) startRead() {
|
|
||||||
payloads := make(chan []byte)
|
|
||||||
done := make(chan *PeerError)
|
|
||||||
pending := [][]byte{}
|
|
||||||
var head []byte
|
|
||||||
var wait time.Duration // initally 0 (no delay)
|
|
||||||
read := time.After(wait * time.Millisecond)
|
|
||||||
|
|
||||||
for {
|
|
||||||
// if pending empty, nil channel blocks
|
|
||||||
var in chan []byte
|
|
||||||
if len(pending) > 0 {
|
|
||||||
in = self.in // enable send case
|
|
||||||
head = pending[0]
|
|
||||||
} else {
|
|
||||||
in = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-read:
|
|
||||||
go self.read(payloads, done)
|
|
||||||
case err := <-done:
|
|
||||||
if err == nil { // no error but nothing to read
|
|
||||||
if len(pending) < maxPendingQueueSize {
|
|
||||||
wait = 100
|
|
||||||
} else if wait == 0 {
|
|
||||||
wait = 100
|
|
||||||
} else {
|
|
||||||
wait = 2 * wait
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.err <- err // report error
|
|
||||||
wait = 100
|
|
||||||
}
|
|
||||||
read = time.After(wait * time.Millisecond)
|
|
||||||
case payload := <-payloads:
|
|
||||||
pending = append(pending, payload)
|
|
||||||
if len(pending) < maxPendingQueueSize {
|
|
||||||
wait = 0
|
|
||||||
} else {
|
|
||||||
wait = 100
|
|
||||||
}
|
|
||||||
read = time.After(wait * time.Millisecond)
|
|
||||||
case in <- head:
|
|
||||||
pending = pending[1:]
|
|
||||||
case errc := <-self.closingIn:
|
|
||||||
errc <- true
|
|
||||||
close(self.in)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) startWrite() {
|
|
||||||
pending := [][]byte{}
|
|
||||||
done := make(chan *PeerError)
|
|
||||||
writing := false
|
|
||||||
for {
|
|
||||||
if len(pending) > 0 && !writing {
|
|
||||||
writing = true
|
|
||||||
go self.write(pending[0], done)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case payload := <-self.out:
|
|
||||||
pending = append(pending, payload)
|
|
||||||
case err := <-done:
|
|
||||||
if err == nil {
|
|
||||||
pending = pending[1:]
|
|
||||||
writing = false
|
|
||||||
} else {
|
|
||||||
self.err <- err // report error
|
|
||||||
}
|
|
||||||
case errc := <-self.closingOut:
|
|
||||||
errc <- true
|
|
||||||
close(self.out)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func pack(payload []byte) (packet []byte) {
|
|
||||||
length := ethutil.NumberToBytes(uint32(len(payload)), 32)
|
|
||||||
// return error if too long?
|
|
||||||
// Write magic token and payload length (first 8 bytes)
|
|
||||||
packet = append(magicToken, length...)
|
|
||||||
packet = append(packet, payload...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func avoidPanic(done chan *PeerError) {
|
|
||||||
if rec := recover(); rec != nil {
|
|
||||||
err := NewPeerError(MiscError, " %v", rec)
|
|
||||||
logger.Debugln(err)
|
|
||||||
done <- err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) write(payload []byte, done chan *PeerError) {
|
|
||||||
defer avoidPanic(done)
|
|
||||||
var err *PeerError
|
|
||||||
_, ok := self.conn.Write(pack(payload))
|
|
||||||
if ok != nil {
|
|
||||||
err = NewPeerError(WriteError, " %v", ok)
|
|
||||||
logger.Debugln(err)
|
|
||||||
}
|
|
||||||
done <- err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
|
|
||||||
//defer avoidPanic(done)
|
|
||||||
|
|
||||||
partials := make(chan []byte, partialsQueueSize)
|
|
||||||
errc := make(chan *PeerError)
|
|
||||||
go self.readPartials(partials, errc)
|
|
||||||
|
|
||||||
packet := []byte{}
|
|
||||||
length := 8
|
|
||||||
start := true
|
|
||||||
var err *PeerError
|
|
||||||
out:
|
|
||||||
for {
|
|
||||||
// appends partials read via connection until packet is
|
|
||||||
// - either parseable (>=8bytes)
|
|
||||||
// - or complete (payload fully consumed)
|
|
||||||
for len(packet) < length {
|
|
||||||
partial, ok := <-partials
|
|
||||||
if !ok { // partials channel is closed
|
|
||||||
err = <-errc
|
|
||||||
if err == nil && len(packet) > 0 {
|
|
||||||
if start {
|
|
||||||
err = NewPeerError(PacketTooShort, "%v", packet)
|
|
||||||
} else {
|
|
||||||
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break out
|
|
||||||
}
|
|
||||||
packet = append(packet, partial...)
|
|
||||||
}
|
|
||||||
if start {
|
|
||||||
// at least 8 bytes read, can validate packet
|
|
||||||
if bytes.Compare(magicToken, packet[:4]) != 0 {
|
|
||||||
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
|
|
||||||
break
|
|
||||||
}
|
|
||||||
length = int(ethutil.BytesToNumber(packet[4:8]))
|
|
||||||
packet = packet[8:]
|
|
||||||
|
|
||||||
if length > 0 {
|
|
||||||
start = false // now consuming payload
|
|
||||||
} else { //penalize peer but read on
|
|
||||||
self.err <- NewPeerError(EmptyPayload, "")
|
|
||||||
length = 8
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// packet complete (payload fully consumed)
|
|
||||||
payloads <- packet[:length]
|
|
||||||
packet = packet[length:] // resclice packet
|
|
||||||
start = true
|
|
||||||
length = 8
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// this stops partials read via the connection, should we?
|
|
||||||
//if err != nil {
|
|
||||||
// select {
|
|
||||||
// case errc <- err
|
|
||||||
// default:
|
|
||||||
//}
|
|
||||||
done <- err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
|
|
||||||
defer close(partials)
|
|
||||||
for {
|
|
||||||
// Give buffering some time
|
|
||||||
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
|
|
||||||
buffer := make([]byte, readBufferLength)
|
|
||||||
// read partial from connection
|
|
||||||
bytesRead, err := self.conn.Read(buffer)
|
|
||||||
if err == nil || err.Error() == "EOF" {
|
|
||||||
if bytesRead > 0 {
|
|
||||||
partials <- buffer[:bytesRead]
|
|
||||||
}
|
|
||||||
if err != nil && err.Error() == "EOF" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// unexpected error, report to errc
|
|
||||||
err := NewPeerError(ReadError, " %v", err)
|
|
||||||
logger.Debugln(err)
|
|
||||||
errc <- err
|
|
||||||
return // will close partials channel
|
|
||||||
}
|
|
||||||
}
|
|
||||||
close(errc)
|
|
||||||
}
|
|
@ -1,222 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TestNetworkConnection struct {
|
|
||||||
in chan []byte
|
|
||||||
current []byte
|
|
||||||
Out [][]byte
|
|
||||||
addr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
|
|
||||||
return &TestNetworkConnection{
|
|
||||||
in: make(chan []byte),
|
|
||||||
current: []byte{},
|
|
||||||
Out: [][]byte{},
|
|
||||||
addr: addr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
|
|
||||||
time.Sleep(latency)
|
|
||||||
for _, s := range packets {
|
|
||||||
self.in <- s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
|
|
||||||
if len(self.current) == 0 {
|
|
||||||
select {
|
|
||||||
case self.current = <-self.in:
|
|
||||||
default:
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
length := len(self.current)
|
|
||||||
if length > len(buff) {
|
|
||||||
copy(buff[:], self.current[:len(buff)])
|
|
||||||
self.current = self.current[len(buff):]
|
|
||||||
return len(buff), nil
|
|
||||||
} else {
|
|
||||||
copy(buff[:length], self.current[:])
|
|
||||||
self.current = []byte{}
|
|
||||||
return length, io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
|
|
||||||
self.Out = append(self.Out, buff)
|
|
||||||
fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
|
|
||||||
return len(buff), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) Close() (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
|
|
||||||
return self.addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupConnection() (*Connection, *TestNetworkConnection) {
|
|
||||||
addr := &TestAddr{"test:30303"}
|
|
||||||
net := NewTestNetworkConnection(addr)
|
|
||||||
conn := NewConnection(net, NewPeerErrorChannel())
|
|
||||||
conn.Open()
|
|
||||||
return conn, net
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingNilPacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{})
|
|
||||||
// time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
t.Errorf("incorrect error %v", err)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingShortPacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{0})
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
if err.Code != PacketTooShort {
|
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingInvalidPacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
if err.Code != MagicTokenMismatch {
|
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingInvalidPayload(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
case err := <-conn.Error():
|
|
||||||
if err.Code != PayloadTooShort {
|
|
||||||
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingEmptyPayload(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
t.Errorf("read %v", packet)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case err := <-conn.Error():
|
|
||||||
code := err.Code
|
|
||||||
if code != EmptyPayload {
|
|
||||||
t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
t.Errorf("no error, expected EmptyPayload")
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingCompletePacket(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
if bytes.Compare(packet, []byte{1}) != 0 {
|
|
||||||
t.Errorf("incorrect payload read")
|
|
||||||
}
|
|
||||||
case err := <-conn.Error():
|
|
||||||
t.Errorf("incorrect error %v", err)
|
|
||||||
default:
|
|
||||||
t.Errorf("nothing read")
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadingTwoCompletePackets(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
|
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case packet := <-conn.Read():
|
|
||||||
if bytes.Compare(packet, []byte{byte(i)}) != 0 {
|
|
||||||
t.Errorf("incorrect payload read")
|
|
||||||
}
|
|
||||||
case err := <-conn.Error():
|
|
||||||
t.Errorf("incorrect error %v", err)
|
|
||||||
default:
|
|
||||||
t.Errorf("nothing read")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriting(t *testing.T) {
|
|
||||||
conn, net := setupConnection()
|
|
||||||
conn.Write() <- []byte{0}
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
if len(net.Out) == 0 {
|
|
||||||
t.Errorf("no output")
|
|
||||||
} else {
|
|
||||||
out := net.Out[0]
|
|
||||||
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
|
|
||||||
t.Errorf("incorrect packet %v", out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
|
|
202
p2p/message.go
202
p2p/message.go
@ -1,75 +1,155 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "fmt"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/big"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MsgCode uint8
|
// Msg defines the structure of a p2p message.
|
||||||
|
//
|
||||||
|
// Note that a Msg can only be sent once since the Payload reader is
|
||||||
|
// consumed during sending. It is not possible to create a Msg and
|
||||||
|
// send it any number of times. If you want to reuse an encoded
|
||||||
|
// structure, encode the payload into a byte array and create a
|
||||||
|
// separate Msg with a bytes.Reader as Payload for each send.
|
||||||
type Msg struct {
|
type Msg struct {
|
||||||
code MsgCode // this is the raw code as per adaptive msg code scheme
|
Code uint64
|
||||||
data *ethutil.Value
|
Size uint32 // size of the paylod
|
||||||
encoded []byte
|
Payload io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Msg) Code() MsgCode {
|
// NewMsg creates an RLP-encoded message with the given code.
|
||||||
return self.code
|
func NewMsg(code uint64, params ...interface{}) Msg {
|
||||||
}
|
buf := new(bytes.Buffer)
|
||||||
|
for _, p := range params {
|
||||||
func (self *Msg) Data() *ethutil.Value {
|
buf.Write(ethutil.Encode(p))
|
||||||
return self.data
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) {
|
|
||||||
|
|
||||||
// // data := [][]interface{}{}
|
|
||||||
// data := []interface{}{}
|
|
||||||
// for _, value := range params {
|
|
||||||
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
|
|
||||||
// data = append(data, encodable.RlpValue())
|
|
||||||
// } else if raw, ok := value.([]interface{}); ok {
|
|
||||||
// data = append(data, raw)
|
|
||||||
// } else {
|
|
||||||
// // data = append(data, interface{}(raw))
|
|
||||||
// err = fmt.Errorf("Unable to encode object of type %T", value)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
return &Msg{
|
|
||||||
code: code,
|
|
||||||
data: ethutil.NewValue(interface{}(params)),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) {
|
|
||||||
value := ethutil.NewValueFromBytes(encoded)
|
|
||||||
// Type of message
|
|
||||||
code := value.Get(0).Uint()
|
|
||||||
// Actual data
|
|
||||||
data := value.SliceFrom(1)
|
|
||||||
|
|
||||||
msg = &Msg{
|
|
||||||
code: MsgCode(code),
|
|
||||||
data: data,
|
|
||||||
// data: ethutil.NewValue(data),
|
|
||||||
encoded: encoded,
|
|
||||||
}
|
}
|
||||||
return
|
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Msg) Decode(offset MsgCode) {
|
func encodePayload(params ...interface{}) []byte {
|
||||||
self.code = self.code - offset
|
buf := new(bytes.Buffer)
|
||||||
}
|
for _, p := range params {
|
||||||
|
buf.Write(ethutil.Encode(p))
|
||||||
// encode takes an offset argument to implement adaptive message coding
|
|
||||||
// the encoded message is memoized to make msgs relayed to several peers more efficient
|
|
||||||
func (self *Msg) Encode(offset MsgCode) (res []byte) {
|
|
||||||
if len(self.encoded) == 0 {
|
|
||||||
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
|
|
||||||
self.encoded = res
|
|
||||||
} else {
|
|
||||||
res = self.encoded
|
|
||||||
}
|
}
|
||||||
return
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode parse the RLP content of a message into
|
||||||
|
// the given value, which must be a pointer.
|
||||||
|
//
|
||||||
|
// For the decoding rules, please see package rlp.
|
||||||
|
func (msg Msg) Decode(val interface{}) error {
|
||||||
|
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
|
||||||
|
return s.Decode(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discard reads any remaining payload data into a black hole.
|
||||||
|
func (msg Msg) Discard() error {
|
||||||
|
_, err := io.Copy(ioutil.Discard, msg.Payload)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type MsgReader interface {
|
||||||
|
ReadMsg() (Msg, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MsgWriter interface {
|
||||||
|
// WriteMsg sends an existing message.
|
||||||
|
// The Payload reader of the message is consumed.
|
||||||
|
// Note that messages can be sent only once.
|
||||||
|
WriteMsg(Msg) error
|
||||||
|
|
||||||
|
// EncodeMsg writes an RLP-encoded message with the given
|
||||||
|
// code and data elements.
|
||||||
|
EncodeMsg(code uint64, data ...interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MsgReadWriter provides reading and writing of encoded messages.
|
||||||
|
type MsgReadWriter interface {
|
||||||
|
MsgReader
|
||||||
|
MsgWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
var magicToken = []byte{34, 64, 8, 145}
|
||||||
|
|
||||||
|
func writeMsg(w io.Writer, msg Msg) error {
|
||||||
|
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
|
||||||
|
code := ethutil.Encode(uint32(msg.Code))
|
||||||
|
listhdr := makeListHeader(msg.Size + uint32(len(code)))
|
||||||
|
payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size
|
||||||
|
|
||||||
|
start := make([]byte, 8)
|
||||||
|
copy(start, magicToken)
|
||||||
|
binary.BigEndian.PutUint32(start[4:], payloadLen)
|
||||||
|
|
||||||
|
for _, b := range [][]byte{start, listhdr, code} {
|
||||||
|
if _, err := w.Write(b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err := io.CopyN(w, msg.Payload, int64(msg.Size))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeListHeader(length uint32) []byte {
|
||||||
|
if length < 56 {
|
||||||
|
return []byte{byte(length + 0xc0)}
|
||||||
|
}
|
||||||
|
enc := big.NewInt(int64(length)).Bytes()
|
||||||
|
lenb := byte(len(enc)) + 0xf7
|
||||||
|
return append([]byte{lenb}, enc...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// readMsg reads a message header from r.
|
||||||
|
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
|
||||||
|
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
|
||||||
|
// read magic and payload size
|
||||||
|
start := make([]byte, 8)
|
||||||
|
if _, err = io.ReadFull(r, start); err != nil {
|
||||||
|
return msg, newPeerError(errRead, "%v", err)
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(start, magicToken) {
|
||||||
|
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
|
||||||
|
}
|
||||||
|
size := binary.BigEndian.Uint32(start[4:])
|
||||||
|
|
||||||
|
// decode start of RLP message to get the message code
|
||||||
|
posr := &postrack{r, 0}
|
||||||
|
s := rlp.NewStream(posr)
|
||||||
|
if _, err := s.List(); err != nil {
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
code, err := s.Uint()
|
||||||
|
if err != nil {
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
payloadsize := size - posr.p
|
||||||
|
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// postrack wraps an rlp.ByteReader with a position counter.
|
||||||
|
type postrack struct {
|
||||||
|
r rlp.ByteReader
|
||||||
|
p uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *postrack) Read(buf []byte) (int, error) {
|
||||||
|
n, err := r.r.Read(buf)
|
||||||
|
r.p += uint32(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *postrack) ReadByte() (byte, error) {
|
||||||
|
b, err := r.r.ReadByte()
|
||||||
|
if err == nil {
|
||||||
|
r.p++
|
||||||
|
}
|
||||||
|
return b, err
|
||||||
}
|
}
|
||||||
|
@ -1,38 +1,70 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewMsg(t *testing.T) {
|
func TestNewMsg(t *testing.T) {
|
||||||
msg, _ := NewMsg(3, 1, "000")
|
msg := NewMsg(3, 1, "000")
|
||||||
if msg.Code() != 3 {
|
if msg.Code != 3 {
|
||||||
t.Errorf("incorrect code %v", msg.Code())
|
t.Errorf("incorrect code %d, want %d", msg.Code)
|
||||||
}
|
}
|
||||||
data0 := msg.Data().Get(0).Uint()
|
if msg.Size != 5 {
|
||||||
data1 := string(msg.Data().Get(1).Bytes())
|
t.Errorf("incorrect size %d, want %d", msg.Size, 5)
|
||||||
if data0 != 1 {
|
|
||||||
t.Errorf("incorrect data %v", data0)
|
|
||||||
}
|
}
|
||||||
if data1 != "000" {
|
pl, _ := ioutil.ReadAll(msg.Payload)
|
||||||
t.Errorf("incorrect data %v", data1)
|
expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
|
||||||
|
if !bytes.Equal(pl, expect) {
|
||||||
|
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeDecodeMsg(t *testing.T) {
|
func TestEncodeDecodeMsg(t *testing.T) {
|
||||||
msg, _ := NewMsg(3, 1, "000")
|
msg := NewMsg(3, 1, "000")
|
||||||
encoded := msg.Encode(3)
|
buf := new(bytes.Buffer)
|
||||||
msg, _ = NewMsgFromBytes(encoded)
|
if err := writeMsg(buf, msg); err != nil {
|
||||||
msg.Decode(3)
|
t.Fatalf("encodeMsg error: %v", err)
|
||||||
if msg.Code() != 3 {
|
|
||||||
t.Errorf("incorrect code %v", msg.Code())
|
|
||||||
}
|
}
|
||||||
data0 := msg.Data().Get(0).Uint()
|
// t.Logf("encoded: %x", buf.Bytes())
|
||||||
data1 := msg.Data().Get(1).Str()
|
|
||||||
if data0 != 1 {
|
decmsg, err := readMsg(buf)
|
||||||
t.Errorf("incorrect data %v", data0)
|
if err != nil {
|
||||||
|
t.Fatalf("readMsg error: %v", err)
|
||||||
}
|
}
|
||||||
if data1 != "000" {
|
if decmsg.Code != 3 {
|
||||||
t.Errorf("incorrect data %v", data1)
|
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
|
||||||
|
}
|
||||||
|
if decmsg.Size != 5 {
|
||||||
|
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
I int
|
||||||
|
S string
|
||||||
|
}
|
||||||
|
if err := decmsg.Decode(&data); err != nil {
|
||||||
|
t.Fatalf("Decode error: %v", err)
|
||||||
|
}
|
||||||
|
if data.I != 1 {
|
||||||
|
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
|
||||||
|
}
|
||||||
|
if data.S != "000" {
|
||||||
|
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeRealMsg(t *testing.T) {
|
||||||
|
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
|
||||||
|
msg, err := readMsg(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Code != 0 {
|
||||||
|
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
220
p2p/messenger.go
220
p2p/messenger.go
@ -1,220 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
handlerTimeout = 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
type Handlers map[string](func(p *Peer) Protocol)
|
|
||||||
|
|
||||||
type Messenger struct {
|
|
||||||
conn *Connection
|
|
||||||
peer *Peer
|
|
||||||
handlers Handlers
|
|
||||||
protocolLock sync.RWMutex
|
|
||||||
protocols []Protocol
|
|
||||||
offsets []MsgCode // offsets for adaptive message idss
|
|
||||||
protocolTable map[string]int
|
|
||||||
quit chan chan bool
|
|
||||||
err chan *PeerError
|
|
||||||
pulse chan bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger {
|
|
||||||
baseProtocol := NewBaseProtocol(peer)
|
|
||||||
return &Messenger{
|
|
||||||
conn: conn,
|
|
||||||
peer: peer,
|
|
||||||
offsets: []MsgCode{baseProtocol.Offset()},
|
|
||||||
handlers: handlers,
|
|
||||||
protocols: []Protocol{baseProtocol},
|
|
||||||
protocolTable: make(map[string]int),
|
|
||||||
err: errchan,
|
|
||||||
pulse: make(chan bool, 1),
|
|
||||||
quit: make(chan chan bool, 1),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Messenger) Start() {
|
|
||||||
self.conn.Open()
|
|
||||||
go self.messenger()
|
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
self.protocols[0].Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Messenger) Stop() {
|
|
||||||
// close pulse to stop ping pong monitoring
|
|
||||||
close(self.pulse)
|
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
for _, protocol := range self.protocols {
|
|
||||||
protocol.Stop() // could be parallel
|
|
||||||
}
|
|
||||||
q := make(chan bool)
|
|
||||||
self.quit <- q
|
|
||||||
<-q
|
|
||||||
self.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Messenger) messenger() {
|
|
||||||
in := self.conn.Read()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case payload, ok := <-in:
|
|
||||||
//dispatches message to the protocol asynchronously
|
|
||||||
if ok {
|
|
||||||
go self.handle(payload)
|
|
||||||
} else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case q := <-self.quit:
|
|
||||||
q <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handles each message by dispatching to the appropriate protocol
|
|
||||||
// using adaptive message codes
|
|
||||||
// this function is started as a separate go routine for each message
|
|
||||||
// it waits for the protocol response
|
|
||||||
// then encodes and sends outgoing messages to the connection's write channel
|
|
||||||
func (self *Messenger) handle(payload []byte) {
|
|
||||||
// send ping to heartbeat channel signalling time of last message
|
|
||||||
// select {
|
|
||||||
// case self.pulse <- true:
|
|
||||||
// default:
|
|
||||||
// }
|
|
||||||
self.pulse <- true
|
|
||||||
// initialise message from payload
|
|
||||||
msg, err := NewMsgFromBytes(payload)
|
|
||||||
if err != nil {
|
|
||||||
self.err <- NewPeerError(MiscError, " %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// retrieves protocol based on message Code
|
|
||||||
protocol, offset, peerErr := self.getProtocol(msg.Code())
|
|
||||||
if err != nil {
|
|
||||||
self.err <- peerErr
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// reset message code based on adaptive offset
|
|
||||||
msg.Decode(offset)
|
|
||||||
// dispatches
|
|
||||||
response := make(chan *Msg)
|
|
||||||
go protocol.HandleIn(msg, response)
|
|
||||||
// protocol reponse timeout to prevent leaks
|
|
||||||
timer := time.After(handlerTimeout * time.Millisecond)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case outgoing, ok := <-response:
|
|
||||||
// we check if response channel is not closed
|
|
||||||
if ok {
|
|
||||||
self.conn.Write() <- outgoing.Encode(offset)
|
|
||||||
} else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-timer:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// negotiated protocols
|
|
||||||
// stores offsets needed for adaptive message id scheme
|
|
||||||
|
|
||||||
// based on offsets set at handshake
|
|
||||||
// get the right protocol to handle the message
|
|
||||||
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) {
|
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
base := MsgCode(0)
|
|
||||||
for index, offset := range self.offsets {
|
|
||||||
if code < offset {
|
|
||||||
return self.protocols[index], base, nil
|
|
||||||
}
|
|
||||||
base = offset
|
|
||||||
}
|
|
||||||
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) {
|
|
||||||
fmt.Printf("pingpong keepalive started at %v", time.Now())
|
|
||||||
|
|
||||||
timer := time.After(timeout)
|
|
||||||
pinged := false
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case _, ok := <-self.pulse:
|
|
||||||
if ok {
|
|
||||||
pinged = false
|
|
||||||
timer = time.After(timeout)
|
|
||||||
} else {
|
|
||||||
// pulse is closed, stop monitoring
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-timer:
|
|
||||||
if pinged {
|
|
||||||
fmt.Printf("timeout at %v", time.Now())
|
|
||||||
timeoutCallback()
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
fmt.Printf("pinged at %v", time.Now())
|
|
||||||
pingCallback()
|
|
||||||
timer = time.After(gracePeriod)
|
|
||||||
pinged = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Messenger) AddProtocols(protocols []string) {
|
|
||||||
self.protocolLock.Lock()
|
|
||||||
defer self.protocolLock.Unlock()
|
|
||||||
i := len(self.offsets)
|
|
||||||
offset := self.offsets[i-1]
|
|
||||||
for _, name := range protocols {
|
|
||||||
protocolFunc, ok := self.handlers[name]
|
|
||||||
if ok {
|
|
||||||
protocol := protocolFunc(self.peer)
|
|
||||||
self.protocolTable[name] = i
|
|
||||||
i++
|
|
||||||
offset += protocol.Offset()
|
|
||||||
fmt.Println("offset ", name, offset)
|
|
||||||
|
|
||||||
self.offsets = append(self.offsets, offset)
|
|
||||||
self.protocols = append(self.protocols, protocol)
|
|
||||||
protocol.Start()
|
|
||||||
} else {
|
|
||||||
fmt.Println("no ", name)
|
|
||||||
// protocol not handled
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Messenger) Write(protocol string, msg *Msg) error {
|
|
||||||
self.protocolLock.RLock()
|
|
||||||
defer self.protocolLock.RUnlock()
|
|
||||||
i := 0
|
|
||||||
offset := MsgCode(0)
|
|
||||||
if len(protocol) > 0 {
|
|
||||||
var ok bool
|
|
||||||
i, ok = self.protocolTable[protocol]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("protocol %v not handled by peer", protocol)
|
|
||||||
}
|
|
||||||
offset = self.offsets[i-1]
|
|
||||||
}
|
|
||||||
handler := self.protocols[i]
|
|
||||||
// checking if protocol status/caps allows the message to be sent out
|
|
||||||
if handler.HandleOut(msg) {
|
|
||||||
self.conn.Write() <- msg.Encode(offset)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,147 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
// "fmt"
|
|
||||||
"bytes"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
|
|
||||||
errchan := NewPeerErrorChannel()
|
|
||||||
addr := &TestAddr{"test:30303"}
|
|
||||||
net := NewTestNetworkConnection(addr)
|
|
||||||
conn := NewConnection(net, errchan)
|
|
||||||
mess := NewMessenger(nil, conn, errchan, handlers)
|
|
||||||
mess.Start()
|
|
||||||
return net, errchan, mess
|
|
||||||
}
|
|
||||||
|
|
||||||
type TestProtocol struct {
|
|
||||||
Msgs []*Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) Start() {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) Stop() {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) Offset() MsgCode {
|
|
||||||
return MsgCode(5)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
|
|
||||||
self.Msgs = append(self.Msgs, msg)
|
|
||||||
close(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) HandleOut(msg *Msg) bool {
|
|
||||||
if msg.Code() > 3 {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestProtocol) Name() string {
|
|
||||||
return "a"
|
|
||||||
}
|
|
||||||
|
|
||||||
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
|
|
||||||
msg, _ := NewMsg(code, params...)
|
|
||||||
encoded := msg.Encode(offset)
|
|
||||||
packet := []byte{34, 64, 8, 145}
|
|
||||||
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
|
|
||||||
return append(packet, encoded...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRead(t *testing.T) {
|
|
||||||
handlers := make(Handlers)
|
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
|
||||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
|
|
||||||
net, _, mess := setupMessenger(handlers)
|
|
||||||
mess.AddProtocols([]string{"a"})
|
|
||||||
defer mess.Stop()
|
|
||||||
wait := 1 * time.Millisecond
|
|
||||||
packet := Packet(16, 1, uint32(1), "000")
|
|
||||||
go net.In(0, packet)
|
|
||||||
time.Sleep(wait)
|
|
||||||
if len(testProtocol.Msgs) != 1 {
|
|
||||||
t.Errorf("msg not relayed to correct protocol")
|
|
||||||
} else {
|
|
||||||
if testProtocol.Msgs[0].Code() != 1 {
|
|
||||||
t.Errorf("incorrect msg code relayed to protocol")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWrite(t *testing.T) {
|
|
||||||
handlers := make(Handlers)
|
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
|
||||||
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
|
|
||||||
net, _, mess := setupMessenger(handlers)
|
|
||||||
mess.AddProtocols([]string{"a"})
|
|
||||||
defer mess.Stop()
|
|
||||||
wait := 1 * time.Millisecond
|
|
||||||
msg, _ := NewMsg(3, uint32(1), "000")
|
|
||||||
err := mess.Write("b", msg)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expect error for unknown protocol")
|
|
||||||
}
|
|
||||||
err = mess.Write("a", msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
|
||||||
} else {
|
|
||||||
time.Sleep(wait)
|
|
||||||
if len(net.Out) != 1 {
|
|
||||||
t.Errorf("msg not written")
|
|
||||||
} else {
|
|
||||||
out := net.Out[0]
|
|
||||||
packet := Packet(16, 3, uint32(1), "000")
|
|
||||||
if bytes.Compare(out, packet) != 0 {
|
|
||||||
t.Errorf("incorrect packet %v", out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPulse(t *testing.T) {
|
|
||||||
net, _, mess := setupMessenger(make(Handlers))
|
|
||||||
defer mess.Stop()
|
|
||||||
ping := false
|
|
||||||
timeout := false
|
|
||||||
pingTimeout := 10 * time.Millisecond
|
|
||||||
gracePeriod := 200 * time.Millisecond
|
|
||||||
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
|
|
||||||
net.In(0, Packet(0, 1))
|
|
||||||
if ping {
|
|
||||||
t.Errorf("ping sent too early")
|
|
||||||
}
|
|
||||||
time.Sleep(pingTimeout + 100*time.Millisecond)
|
|
||||||
if !ping {
|
|
||||||
t.Errorf("no ping sent after timeout")
|
|
||||||
}
|
|
||||||
if timeout {
|
|
||||||
t.Errorf("timeout too early")
|
|
||||||
}
|
|
||||||
ping = false
|
|
||||||
net.In(0, Packet(0, 1))
|
|
||||||
time.Sleep(pingTimeout + 100*time.Millisecond)
|
|
||||||
if !ping {
|
|
||||||
t.Errorf("no ping sent after timeout")
|
|
||||||
}
|
|
||||||
if timeout {
|
|
||||||
t.Errorf("timeout too early")
|
|
||||||
}
|
|
||||||
ping = false
|
|
||||||
time.Sleep(gracePeriod)
|
|
||||||
if ping {
|
|
||||||
t.Errorf("ping called twice")
|
|
||||||
}
|
|
||||||
if !timeout {
|
|
||||||
t.Errorf("no timeout after grace period")
|
|
||||||
}
|
|
||||||
}
|
|
@ -3,6 +3,7 @@ package p2p
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
natpmp "github.com/jackpal/go-nat-pmp"
|
natpmp "github.com/jackpal/go-nat-pmp"
|
||||||
)
|
)
|
||||||
@ -13,38 +14,37 @@ import (
|
|||||||
// + Register for changes to the external address.
|
// + Register for changes to the external address.
|
||||||
// + Re-register port mapping when router reboots.
|
// + Re-register port mapping when router reboots.
|
||||||
// + A mechanism for keeping a port mapping registered.
|
// + A mechanism for keeping a port mapping registered.
|
||||||
|
// + Discover gateway address automatically.
|
||||||
|
|
||||||
type natPMPClient struct {
|
type natPMPClient struct {
|
||||||
client *natpmp.Client
|
client *natpmp.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNatPMP(gateway net.IP) (nat NAT) {
|
// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
|
||||||
|
// address should be the IP of your router.
|
||||||
|
func PMP(gateway net.IP) (nat NAT) {
|
||||||
return &natPMPClient{natpmp.NewClient(gateway)}
|
return &natPMPClient{natpmp.NewClient(gateway)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) {
|
func (*natPMPClient) String() string {
|
||||||
response, err := n.client.GetExternalAddress()
|
return "NAT-PMP"
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ip := response.ExternalIPAddress
|
|
||||||
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int,
|
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
|
||||||
description string, timeout int) (mappedExternalPort int, err error) {
|
response, err := n.client.GetExternalAddress()
|
||||||
if timeout <= 0 {
|
if err != nil {
|
||||||
err = fmt.Errorf("timeout must not be <= 0")
|
return nil, err
|
||||||
return
|
}
|
||||||
|
return response.ExternalIPAddress[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
|
||||||
|
if lifetime <= 0 {
|
||||||
|
return fmt.Errorf("lifetime must not be <= 0")
|
||||||
}
|
}
|
||||||
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
|
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
|
||||||
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout)
|
_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
|
||||||
if err != nil {
|
return err
|
||||||
return
|
|
||||||
}
|
|
||||||
mappedExternalPort = int(response.MappedExternalPort)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
|
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
|
||||||
|
198
p2p/natupnp.go
198
p2p/natupnp.go
@ -7,6 +7,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@ -15,28 +16,46 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
upnpDiscoverAttempts = 3
|
||||||
|
upnpDiscoverTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
|
||||||
|
// discover the address of your router using UDP broadcasts.
|
||||||
|
func UPNP() NAT {
|
||||||
|
return &upnpNAT{}
|
||||||
|
}
|
||||||
|
|
||||||
type upnpNAT struct {
|
type upnpNAT struct {
|
||||||
serviceURL string
|
serviceURL string
|
||||||
ourIP string
|
ourIP string
|
||||||
}
|
}
|
||||||
|
|
||||||
func upnpDiscover(attempts int) (nat NAT, err error) {
|
func (n *upnpNAT) String() string {
|
||||||
|
return "UPNP"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) discover() error {
|
||||||
|
if n.serviceURL != "" {
|
||||||
|
// already discovered
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
|
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
// TODO: try on all network interfaces simultaneously.
|
||||||
|
// Broadcasting on 0.0.0.0 could select a random interface
|
||||||
|
// to send on (platform specific).
|
||||||
conn, err := net.ListenPacket("udp4", ":0")
|
conn, err := net.ListenPacket("udp4", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
|
||||||
socket := conn.(*net.UDPConn)
|
|
||||||
defer socket.Close()
|
|
||||||
|
|
||||||
err = socket.SetDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||||
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
|
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
|
||||||
buf := bytes.NewBufferString(
|
buf := bytes.NewBufferString(
|
||||||
"M-SEARCH * HTTP/1.1\r\n" +
|
"M-SEARCH * HTTP/1.1\r\n" +
|
||||||
@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
|
|||||||
"MX: 2\r\n\r\n")
|
"MX: 2\r\n\r\n")
|
||||||
message := buf.Bytes()
|
message := buf.Bytes()
|
||||||
answerBytes := make([]byte, 1024)
|
answerBytes := make([]byte, 1024)
|
||||||
for i := 0; i < attempts; i++ {
|
for i := 0; i < upnpDiscoverAttempts; i++ {
|
||||||
_, err = socket.WriteToUDP(message, ssdp)
|
_, err = conn.WriteTo(message, ssdp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
var n int
|
nn, _, err := conn.ReadFrom(answerBytes)
|
||||||
n, _, err = socket.ReadFromUDP(answerBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
// socket.Close()
|
|
||||||
// return
|
|
||||||
}
|
}
|
||||||
answer := string(answerBytes[0:n])
|
answer := string(answerBytes[0:nn])
|
||||||
if strings.Index(answer, "\r\n"+st) < 0 {
|
if strings.Index(answer, "\r\n"+st) < 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
|
|||||||
var serviceURL string
|
var serviceURL string
|
||||||
serviceURL, err = getServiceURL(locURL)
|
serviceURL, err = getServiceURL(locURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
var ourIP string
|
var ourIP string
|
||||||
ourIP, err = getOurIP()
|
ourIP, err = getOurIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP}
|
n.serviceURL = serviceURL
|
||||||
|
n.ourIP = ourIP
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("UPnP port discovery failed.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
|
||||||
|
if err := n.discover(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
info, err := n.getStatusInfo()
|
||||||
|
return net.ParseIP(info.externalIpAddress), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
|
||||||
|
if err := n.discover(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// A single concatenation would break ARM compilation.
|
||||||
|
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||||
|
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
|
||||||
|
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
|
||||||
|
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
|
||||||
|
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
|
||||||
|
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
|
||||||
|
message += description +
|
||||||
|
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
|
||||||
|
"</NewLeaseDuration></u:AddPortMapping>"
|
||||||
|
|
||||||
|
// TODO: check response to see if the port was forwarded
|
||||||
|
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
|
||||||
|
if err := n.discover(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||||
|
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
|
||||||
|
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
|
||||||
|
"</u:DeletePortMapping>"
|
||||||
|
|
||||||
|
// TODO: check response to see if the port was deleted
|
||||||
|
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type statusInfo struct {
|
||||||
|
externalIpAddress string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
|
||||||
|
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||||
|
"</u:GetStatusInfo>"
|
||||||
|
|
||||||
|
var response *http.Response
|
||||||
|
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = errors.New("UPnP port discovery failed.")
|
|
||||||
|
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
|
||||||
|
|
||||||
|
response.Body.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type statusInfo struct {
|
|
||||||
externalIpAddress string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
|
|
||||||
|
|
||||||
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
|
||||||
"</u:GetStatusInfo>"
|
|
||||||
|
|
||||||
var response *http.Response
|
|
||||||
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
|
|
||||||
|
|
||||||
response.Body.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
|
|
||||||
info, err := n.getStatusInfo()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
addr = net.ParseIP(info.externalIpAddress)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
|
|
||||||
// A single concatenation would break ARM compilation.
|
|
||||||
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
|
||||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
|
|
||||||
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
|
|
||||||
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
|
|
||||||
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
|
|
||||||
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
|
|
||||||
message += description +
|
|
||||||
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
|
|
||||||
"</NewLeaseDuration></u:AddPortMapping>"
|
|
||||||
|
|
||||||
var response *http.Response
|
|
||||||
response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: check response to see if the port was forwarded
|
|
||||||
// log.Println(message, response)
|
|
||||||
mappedExternalPort = externalPort
|
|
||||||
_ = response
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
|
|
||||||
|
|
||||||
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
|
||||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
|
|
||||||
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
|
|
||||||
"</u:DeletePortMapping>"
|
|
||||||
|
|
||||||
var response *http.Response
|
|
||||||
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: check response to see if the port was deleted
|
|
||||||
// log.Println(message, response)
|
|
||||||
_ = response
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
196
p2p/network.go
196
p2p/network.go
@ -1,196 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DialerTimeout = 180 //seconds
|
|
||||||
KeepAlivePeriod = 60 //minutes
|
|
||||||
portMappingUpdateInterval = 900 // seconds = 15 mins
|
|
||||||
upnpDiscoverAttempts = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
// Dialer is not an interface in net, so we define one
|
|
||||||
// *net.Dialer conforms to this
|
|
||||||
type Dialer interface {
|
|
||||||
Dial(network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Network interface {
|
|
||||||
Start() error
|
|
||||||
Listener(net.Addr) (net.Listener, error)
|
|
||||||
Dialer(net.Addr) (Dialer, error)
|
|
||||||
NewAddr(string, int) (addr net.Addr, err error)
|
|
||||||
ParseAddr(string) (addr net.Addr, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type NAT interface {
|
|
||||||
GetExternalAddress() (addr net.IP, err error)
|
|
||||||
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
|
|
||||||
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPNetwork struct {
|
|
||||||
nat NAT
|
|
||||||
natType NATType
|
|
||||||
quit chan chan bool
|
|
||||||
ports chan string
|
|
||||||
}
|
|
||||||
|
|
||||||
type NATType int
|
|
||||||
|
|
||||||
const (
|
|
||||||
NONE = iota
|
|
||||||
UPNP
|
|
||||||
PMP
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
portMappingTimeout = 1200 // 20 mins
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
|
|
||||||
return &TCPNetwork{
|
|
||||||
natType: natType,
|
|
||||||
ports: make(chan string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
|
|
||||||
return &net.Dialer{
|
|
||||||
Timeout: DialerTimeout * time.Second,
|
|
||||||
// KeepAlive: KeepAlivePeriod * time.Minute,
|
|
||||||
LocalAddr: addr,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
|
|
||||||
if self.natType == UPNP {
|
|
||||||
_, port, _ := net.SplitHostPort(addr.String())
|
|
||||||
if self.quit == nil {
|
|
||||||
self.quit = make(chan chan bool)
|
|
||||||
go self.updatePortMappings()
|
|
||||||
}
|
|
||||||
self.ports <- port
|
|
||||||
}
|
|
||||||
return net.Listen(addr.Network(), addr.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) Start() (err error) {
|
|
||||||
switch self.natType {
|
|
||||||
case NONE:
|
|
||||||
case UPNP:
|
|
||||||
nat, uerr := upnpDiscover(upnpDiscoverAttempts)
|
|
||||||
if uerr != nil {
|
|
||||||
err = fmt.Errorf("UPNP failed: ", uerr)
|
|
||||||
} else {
|
|
||||||
self.nat = nat
|
|
||||||
}
|
|
||||||
case PMP:
|
|
||||||
err = fmt.Errorf("PMP not implemented")
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Invalid NAT type: %v", self.natType)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) Stop() {
|
|
||||||
q := make(chan bool)
|
|
||||||
self.quit <- q
|
|
||||||
<-q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
|
|
||||||
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("unable to add port mapping on %v: %v", lport, err)
|
|
||||||
} else {
|
|
||||||
logger.Debugf("succesfully added port mapping on %v", lport)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) updatePortMappings() {
|
|
||||||
timer := time.NewTimer(portMappingUpdateInterval * time.Second)
|
|
||||||
lports := []int{}
|
|
||||||
out:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case port := <-self.ports:
|
|
||||||
int64lport, _ := strconv.ParseInt(port, 10, 16)
|
|
||||||
lport := int(int64lport)
|
|
||||||
if err := self.addPortMapping(lport); err != nil {
|
|
||||||
lports = append(lports, lport)
|
|
||||||
}
|
|
||||||
case <-timer.C:
|
|
||||||
for lport := range lports {
|
|
||||||
if err := self.addPortMapping(lport); err != nil {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case errc := <-self.quit:
|
|
||||||
errc <- true
|
|
||||||
break out
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
timer.Stop()
|
|
||||||
for lport := range lports {
|
|
||||||
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
|
|
||||||
logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
|
|
||||||
} else {
|
|
||||||
logger.Debugf("succesfully removed port mapping on %v", lport)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
|
|
||||||
ip, err := self.lookupIP(host)
|
|
||||||
if err == nil {
|
|
||||||
return &net.TCPAddr{
|
|
||||||
IP: ip,
|
|
||||||
Port: port,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
|
|
||||||
host, port, err := net.SplitHostPort(address)
|
|
||||||
if err == nil {
|
|
||||||
iport, _ := strconv.Atoi(port)
|
|
||||||
addr, e := self.NewAddr(host, iport)
|
|
||||||
return addr, e
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
|
|
||||||
if ip = net.ParseIP(host); ip != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var ips []net.IP
|
|
||||||
ips, err = net.LookupIP(host)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warnln(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(ips) == 0 {
|
|
||||||
err = fmt.Errorf("No IP addresses available for %v", host)
|
|
||||||
logger.Warnln(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(ips) > 1 {
|
|
||||||
// Pick a random IP address, simulating round-robin DNS.
|
|
||||||
rand.Seed(time.Now().UTC().UnixNano())
|
|
||||||
ip = ips[rand.Intn(len(ips))]
|
|
||||||
} else {
|
|
||||||
ip = ips[0]
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
490
p2p/peer.go
490
p2p/peer.go
@ -1,83 +1,455 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/event"
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Peer struct {
|
// peerAddr is the structure of a peer list element.
|
||||||
// quit chan chan bool
|
// It is also a valid net.Addr.
|
||||||
Inbound bool // inbound (via listener) or outbound (via dialout)
|
type peerAddr struct {
|
||||||
Address net.Addr
|
IP net.IP
|
||||||
Host []byte
|
Port uint64
|
||||||
Port uint16
|
Pubkey []byte // optional
|
||||||
Pubkey []byte
|
|
||||||
Id string
|
|
||||||
Caps []string
|
|
||||||
peerErrorChan chan *PeerError
|
|
||||||
messenger *Messenger
|
|
||||||
peerErrorHandler *PeerErrorHandler
|
|
||||||
server *Server
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) Messenger() *Messenger {
|
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
|
||||||
return self.messenger
|
n := addr.Network()
|
||||||
}
|
if n != "tcp" && n != "tcp4" && n != "tcp6" {
|
||||||
|
// for testing with non-TCP
|
||||||
func (self *Peer) PeerErrorChan() chan *PeerError {
|
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
|
||||||
return self.peerErrorChan
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Peer) Server() *Server {
|
|
||||||
return self.server
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
|
|
||||||
peerErrorChan := NewPeerErrorChannel()
|
|
||||||
host, port, _ := net.SplitHostPort(address.String())
|
|
||||||
intport, _ := strconv.Atoi(port)
|
|
||||||
peer := &Peer{
|
|
||||||
Inbound: inbound,
|
|
||||||
Address: address,
|
|
||||||
Port: uint16(intport),
|
|
||||||
Host: net.ParseIP(host),
|
|
||||||
peerErrorChan: peerErrorChan,
|
|
||||||
server: server,
|
|
||||||
}
|
}
|
||||||
connection := NewConnection(conn, peerErrorChan)
|
ta := addr.(*net.TCPAddr)
|
||||||
peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers())
|
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
|
||||||
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist())
|
}
|
||||||
|
|
||||||
|
func (d peerAddr) Network() string {
|
||||||
|
if d.IP.To4() != nil {
|
||||||
|
return "tcp4"
|
||||||
|
} else {
|
||||||
|
return "tcp6"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d peerAddr) String() string {
|
||||||
|
return fmt.Sprintf("%v:%d", d.IP, d.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d peerAddr) RlpData() interface{} {
|
||||||
|
return []interface{}{d.IP, d.Port, d.Pubkey}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peer represents a remote peer.
|
||||||
|
type Peer struct {
|
||||||
|
// Peers have all the log methods.
|
||||||
|
// Use them to display messages related to the peer.
|
||||||
|
*logger.Logger
|
||||||
|
|
||||||
|
infolock sync.Mutex
|
||||||
|
identity ClientIdentity
|
||||||
|
caps []Cap
|
||||||
|
listenAddr *peerAddr // what remote peer is listening on
|
||||||
|
dialAddr *peerAddr // non-nil if dialing
|
||||||
|
|
||||||
|
// The mutex protects the connection
|
||||||
|
// so only one protocol can write at a time.
|
||||||
|
writeMu sync.Mutex
|
||||||
|
conn net.Conn
|
||||||
|
bufconn *bufio.ReadWriter
|
||||||
|
|
||||||
|
// These fields maintain the running protocols.
|
||||||
|
protocols []Protocol
|
||||||
|
runBaseProtocol bool // for testing
|
||||||
|
|
||||||
|
runlock sync.RWMutex // protects running
|
||||||
|
running map[string]*proto
|
||||||
|
|
||||||
|
protoWG sync.WaitGroup
|
||||||
|
protoErr chan error
|
||||||
|
closed chan struct{}
|
||||||
|
disc chan DiscReason
|
||||||
|
|
||||||
|
activity event.TypeMux // for activity events
|
||||||
|
|
||||||
|
slot int // index into Server peer list
|
||||||
|
|
||||||
|
// These fields are kept so base protocol can access them.
|
||||||
|
// TODO: this should be one or more interfaces
|
||||||
|
ourID ClientIdentity // client id of the Server
|
||||||
|
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
|
||||||
|
newPeerAddr chan<- *peerAddr // tell server about received peers
|
||||||
|
otherPeers func() []*Peer // should return the list of all peers
|
||||||
|
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPeer returns a peer for testing purposes.
|
||||||
|
func NewPeer(id ClientIdentity, caps []Cap) *Peer {
|
||||||
|
conn, _ := net.Pipe()
|
||||||
|
peer := newPeer(conn, nil, nil)
|
||||||
|
peer.setHandshakeInfo(id, nil, caps)
|
||||||
|
close(peer.closed)
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) String() string {
|
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
||||||
var kind string
|
p := newPeer(conn, server.Protocols, dialAddr)
|
||||||
if self.Inbound {
|
p.ourID = server.Identity
|
||||||
kind = "inbound"
|
p.newPeerAddr = server.peerConnect
|
||||||
} else {
|
p.otherPeers = server.Peers
|
||||||
|
p.pubkeyHook = server.verifyPeer
|
||||||
|
p.runBaseProtocol = true
|
||||||
|
|
||||||
|
// laddr can be updated concurrently by NAT traversal.
|
||||||
|
// newServerPeer must be called with the server lock held.
|
||||||
|
if server.laddr != nil {
|
||||||
|
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
|
||||||
|
p := &Peer{
|
||||||
|
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
|
||||||
|
conn: conn,
|
||||||
|
dialAddr: dialAddr,
|
||||||
|
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
|
||||||
|
protocols: protocols,
|
||||||
|
running: make(map[string]*proto),
|
||||||
|
disc: make(chan DiscReason),
|
||||||
|
protoErr: make(chan error),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identity returns the client identity of the remote peer. The
|
||||||
|
// identity can be nil if the peer has not yet completed the
|
||||||
|
// handshake.
|
||||||
|
func (p *Peer) Identity() ClientIdentity {
|
||||||
|
p.infolock.Lock()
|
||||||
|
defer p.infolock.Unlock()
|
||||||
|
return p.identity
|
||||||
|
}
|
||||||
|
|
||||||
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||||
|
func (p *Peer) Caps() []Cap {
|
||||||
|
p.infolock.Lock()
|
||||||
|
defer p.infolock.Unlock()
|
||||||
|
return p.caps
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
|
||||||
|
p.infolock.Lock()
|
||||||
|
p.identity = id
|
||||||
|
p.listenAddr = laddr
|
||||||
|
p.caps = caps
|
||||||
|
p.infolock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote address of the network connection.
|
||||||
|
func (p *Peer) RemoteAddr() net.Addr {
|
||||||
|
return p.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local address of the network connection.
|
||||||
|
func (p *Peer) LocalAddr() net.Addr {
|
||||||
|
return p.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect terminates the peer connection with the given reason.
|
||||||
|
// It returns immediately and does not wait until the connection is closed.
|
||||||
|
func (p *Peer) Disconnect(reason DiscReason) {
|
||||||
|
select {
|
||||||
|
case p.disc <- reason:
|
||||||
|
case <-p.closed:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements fmt.Stringer.
|
||||||
|
func (p *Peer) String() string {
|
||||||
|
kind := "inbound"
|
||||||
|
p.infolock.Lock()
|
||||||
|
if p.dialAddr != nil {
|
||||||
kind = "outbound"
|
kind = "outbound"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
|
p.infolock.Unlock()
|
||||||
|
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) Write(protocol string, msg *Msg) error {
|
const (
|
||||||
return self.messenger.Write(protocol, msg)
|
// maximum amount of time allowed for reading a message
|
||||||
|
msgReadTimeout = 5 * time.Second
|
||||||
|
// maximum amount of time allowed for writing a message
|
||||||
|
msgWriteTimeout = 5 * time.Second
|
||||||
|
// messages smaller than this many bytes will be read at
|
||||||
|
// once before passing them to a protocol.
|
||||||
|
wholePayloadSize = 64 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
inactivityTimeout = 2 * time.Second
|
||||||
|
disconnectGracePeriod = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Peer) loop() (reason DiscReason, err error) {
|
||||||
|
defer p.activity.Stop()
|
||||||
|
defer p.closeProtocols()
|
||||||
|
defer close(p.closed)
|
||||||
|
defer p.conn.Close()
|
||||||
|
|
||||||
|
// read loop
|
||||||
|
readMsg := make(chan Msg)
|
||||||
|
readErr := make(chan error)
|
||||||
|
readNext := make(chan bool, 1)
|
||||||
|
protoDone := make(chan struct{}, 1)
|
||||||
|
go p.readLoop(readMsg, readErr, readNext)
|
||||||
|
readNext <- true
|
||||||
|
|
||||||
|
if p.runBaseProtocol {
|
||||||
|
p.startBaseProtocol()
|
||||||
|
}
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-readMsg:
|
||||||
|
// a new message has arrived.
|
||||||
|
var wait bool
|
||||||
|
if wait, err = p.dispatch(msg, protoDone); err != nil {
|
||||||
|
p.Errorf("msg dispatch error: %v\n", err)
|
||||||
|
reason = discReasonForError(err)
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
if !wait {
|
||||||
|
// Msg has already been read completely, continue with next message.
|
||||||
|
readNext <- true
|
||||||
|
}
|
||||||
|
p.activity.Post(time.Now())
|
||||||
|
case <-protoDone:
|
||||||
|
// protocol has consumed the message payload,
|
||||||
|
// we can continue reading from the socket.
|
||||||
|
readNext <- true
|
||||||
|
|
||||||
|
case err := <-readErr:
|
||||||
|
// read failed. there is no need to run the
|
||||||
|
// polite disconnect sequence because the connection
|
||||||
|
// is probably dead anyway.
|
||||||
|
// TODO: handle write errors as well
|
||||||
|
return DiscNetworkError, err
|
||||||
|
case err = <-p.protoErr:
|
||||||
|
reason = discReasonForError(err)
|
||||||
|
break loop
|
||||||
|
case reason = <-p.disc:
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for read loop to return.
|
||||||
|
close(readNext)
|
||||||
|
<-readErr
|
||||||
|
// tell the remote end to disconnect
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
|
||||||
|
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
|
||||||
|
io.Copy(ioutil.Discard, p.conn)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(disconnectGracePeriod):
|
||||||
|
}
|
||||||
|
return reason, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) Start() {
|
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
|
||||||
self.peerErrorHandler.Start()
|
for _ = range unblock {
|
||||||
self.messenger.Start()
|
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
|
||||||
|
if msg, err := readMsg(p.bufconn); err != nil {
|
||||||
|
errc <- err
|
||||||
|
} else {
|
||||||
|
msgc <- msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(errc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Peer) Stop() {
|
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
|
||||||
self.peerErrorHandler.Stop()
|
proto, err := p.getProto(msg.Code)
|
||||||
self.messenger.Stop()
|
if err != nil {
|
||||||
// q := make(chan bool)
|
return false, err
|
||||||
// self.quit <- q
|
}
|
||||||
// <-q
|
if msg.Size <= wholePayloadSize {
|
||||||
|
// optimization: msg is small enough, read all
|
||||||
|
// of it and move on to the next message
|
||||||
|
buf, err := ioutil.ReadAll(msg.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
msg.Payload = bytes.NewReader(buf)
|
||||||
|
proto.in <- msg
|
||||||
|
} else {
|
||||||
|
wait = true
|
||||||
|
pr := &eofSignal{msg.Payload, protoDone}
|
||||||
|
msg.Payload = pr
|
||||||
|
proto.in <- msg
|
||||||
|
}
|
||||||
|
return wait, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) Encode() []interface{} {
|
func (p *Peer) startBaseProtocol() {
|
||||||
return []interface{}{p.Host, p.Port, p.Pubkey}
|
p.runlock.Lock()
|
||||||
|
defer p.runlock.Unlock()
|
||||||
|
p.running[""] = p.startProto(0, Protocol{
|
||||||
|
Length: baseProtocolLength,
|
||||||
|
Run: runBaseProtocol,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// startProtocols starts matching named subprotocols.
|
||||||
|
func (p *Peer) startSubprotocols(caps []Cap) {
|
||||||
|
sort.Sort(capsByName(caps))
|
||||||
|
|
||||||
|
p.runlock.Lock()
|
||||||
|
defer p.runlock.Unlock()
|
||||||
|
offset := baseProtocolLength
|
||||||
|
outer:
|
||||||
|
for _, cap := range caps {
|
||||||
|
for _, proto := range p.protocols {
|
||||||
|
if proto.Name == cap.Name &&
|
||||||
|
proto.Version == cap.Version &&
|
||||||
|
p.running[cap.Name] == nil {
|
||||||
|
p.running[cap.Name] = p.startProto(offset, proto)
|
||||||
|
offset += proto.Length
|
||||||
|
continue outer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
||||||
|
rw := &proto{
|
||||||
|
in: make(chan Msg),
|
||||||
|
offset: offset,
|
||||||
|
maxcode: impl.Length,
|
||||||
|
peer: p,
|
||||||
|
}
|
||||||
|
p.protoWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
err := impl.Run(p, rw)
|
||||||
|
if err == nil {
|
||||||
|
p.Infof("protocol %q returned", impl.Name)
|
||||||
|
err = newPeerError(errMisc, "protocol returned")
|
||||||
|
} else {
|
||||||
|
p.Errorf("protocol %q error: %v\n", impl.Name, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case p.protoErr <- err:
|
||||||
|
case <-p.closed:
|
||||||
|
}
|
||||||
|
p.protoWG.Done()
|
||||||
|
}()
|
||||||
|
return rw
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProto finds the protocol responsible for handling
|
||||||
|
// the given message code.
|
||||||
|
func (p *Peer) getProto(code uint64) (*proto, error) {
|
||||||
|
p.runlock.RLock()
|
||||||
|
defer p.runlock.RUnlock()
|
||||||
|
for _, proto := range p.running {
|
||||||
|
if code >= proto.offset && code < proto.offset+proto.maxcode {
|
||||||
|
return proto, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, newPeerError(errInvalidMsgCode, "%d", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) closeProtocols() {
|
||||||
|
p.runlock.RLock()
|
||||||
|
for _, p := range p.running {
|
||||||
|
close(p.in)
|
||||||
|
}
|
||||||
|
p.runlock.RUnlock()
|
||||||
|
p.protoWG.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||||
|
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
||||||
|
p.runlock.RLock()
|
||||||
|
proto, ok := p.running[protoName]
|
||||||
|
p.runlock.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
||||||
|
}
|
||||||
|
if msg.Code >= proto.maxcode {
|
||||||
|
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
||||||
|
}
|
||||||
|
msg.Code += proto.offset
|
||||||
|
return p.writeMsg(msg, msgWriteTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeMsg writes a message to the connection.
|
||||||
|
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
|
||||||
|
p.writeMu.Lock()
|
||||||
|
defer p.writeMu.Unlock()
|
||||||
|
p.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||||
|
if err := writeMsg(p.bufconn, msg); err != nil {
|
||||||
|
return newPeerError(errWrite, "%v", err)
|
||||||
|
}
|
||||||
|
return p.bufconn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
type proto struct {
|
||||||
|
name string
|
||||||
|
in chan Msg
|
||||||
|
maxcode, offset uint64
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *proto) WriteMsg(msg Msg) error {
|
||||||
|
if msg.Code >= rw.maxcode {
|
||||||
|
return newPeerError(errInvalidMsgCode, "not handled")
|
||||||
|
}
|
||||||
|
msg.Code += rw.offset
|
||||||
|
return rw.peer.writeMsg(msg, msgWriteTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
|
||||||
|
return rw.WriteMsg(NewMsg(code, data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *proto) ReadMsg() (Msg, error) {
|
||||||
|
msg, ok := <-rw.in
|
||||||
|
if !ok {
|
||||||
|
return msg, io.EOF
|
||||||
|
}
|
||||||
|
msg.Code -= rw.offset
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// eofSignal wraps a reader with eof signaling.
|
||||||
|
// the eof channel is closed when the wrapped reader
|
||||||
|
// reaches EOF.
|
||||||
|
type eofSignal struct {
|
||||||
|
wrapped io.Reader
|
||||||
|
eof chan<- struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *eofSignal) Read(buf []byte) (int, error) {
|
||||||
|
n, err := r.wrapped.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
@ -4,73 +4,121 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ErrorCode int
|
|
||||||
|
|
||||||
const errorChanCapacity = 10
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PacketTooShort = iota
|
errMagicTokenMismatch = iota
|
||||||
PayloadTooShort
|
errRead
|
||||||
MagicTokenMismatch
|
errWrite
|
||||||
EmptyPayload
|
errMisc
|
||||||
ReadError
|
errInvalidMsgCode
|
||||||
WriteError
|
errInvalidMsg
|
||||||
MiscError
|
errP2PVersionMismatch
|
||||||
InvalidMsgCode
|
errPubkeyMissing
|
||||||
InvalidMsg
|
errPubkeyInvalid
|
||||||
P2PVersionMismatch
|
errPubkeyForbidden
|
||||||
PubkeyMissing
|
errProtocolBreach
|
||||||
PubkeyInvalid
|
errPingTimeout
|
||||||
PubkeyForbidden
|
errInvalidNetworkId
|
||||||
ProtocolBreach
|
errInvalidProtocolVersion
|
||||||
PortMismatch
|
|
||||||
PingTimeout
|
|
||||||
InvalidGenesis
|
|
||||||
InvalidNetworkId
|
|
||||||
InvalidProtocolVersion
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var errorToString = map[ErrorCode]string{
|
var errorToString = map[int]string{
|
||||||
PacketTooShort: "Packet too short",
|
errMagicTokenMismatch: "Magic token mismatch",
|
||||||
PayloadTooShort: "Payload too short",
|
errRead: "Read error",
|
||||||
MagicTokenMismatch: "Magic token mismatch",
|
errWrite: "Write error",
|
||||||
EmptyPayload: "Empty payload",
|
errMisc: "Misc error",
|
||||||
ReadError: "Read error",
|
errInvalidMsgCode: "Invalid message code",
|
||||||
WriteError: "Write error",
|
errInvalidMsg: "Invalid message",
|
||||||
MiscError: "Misc error",
|
errP2PVersionMismatch: "P2P Version Mismatch",
|
||||||
InvalidMsgCode: "Invalid message code",
|
errPubkeyMissing: "Public key missing",
|
||||||
InvalidMsg: "Invalid message",
|
errPubkeyInvalid: "Public key invalid",
|
||||||
P2PVersionMismatch: "P2P Version Mismatch",
|
errPubkeyForbidden: "Public key forbidden",
|
||||||
PubkeyMissing: "Public key missing",
|
errProtocolBreach: "Protocol Breach",
|
||||||
PubkeyInvalid: "Public key invalid",
|
errPingTimeout: "Ping timeout",
|
||||||
PubkeyForbidden: "Public key forbidden",
|
errInvalidNetworkId: "Invalid network id",
|
||||||
ProtocolBreach: "Protocol Breach",
|
errInvalidProtocolVersion: "Invalid protocol version",
|
||||||
PortMismatch: "Port mismatch",
|
|
||||||
PingTimeout: "Ping timeout",
|
|
||||||
InvalidGenesis: "Invalid genesis block",
|
|
||||||
InvalidNetworkId: "Invalid network id",
|
|
||||||
InvalidProtocolVersion: "Invalid protocol version",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeerError struct {
|
type peerError struct {
|
||||||
Code ErrorCode
|
Code int
|
||||||
message string
|
message string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError {
|
func newPeerError(code int, format string, v ...interface{}) *peerError {
|
||||||
desc, ok := errorToString[code]
|
desc, ok := errorToString[code]
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("invalid error code")
|
panic("invalid error code")
|
||||||
}
|
}
|
||||||
format = desc + ": " + format
|
err := &peerError{code, desc}
|
||||||
message := fmt.Sprintf(format, v...)
|
if format != "" {
|
||||||
return &PeerError{code, message}
|
err.message += ": " + fmt.Sprintf(format, v...)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *PeerError) Error() string {
|
func (self *peerError) Error() string {
|
||||||
return self.message
|
return self.message
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPeerErrorChannel() chan *PeerError {
|
type DiscReason byte
|
||||||
return make(chan *PeerError, errorChanCapacity)
|
|
||||||
|
const (
|
||||||
|
DiscRequested DiscReason = 0x00
|
||||||
|
DiscNetworkError = 0x01
|
||||||
|
DiscProtocolError = 0x02
|
||||||
|
DiscUselessPeer = 0x03
|
||||||
|
DiscTooManyPeers = 0x04
|
||||||
|
DiscAlreadyConnected = 0x05
|
||||||
|
DiscIncompatibleVersion = 0x06
|
||||||
|
DiscInvalidIdentity = 0x07
|
||||||
|
DiscQuitting = 0x08
|
||||||
|
DiscUnexpectedIdentity = 0x09
|
||||||
|
DiscSelf = 0x0a
|
||||||
|
DiscReadTimeout = 0x0b
|
||||||
|
DiscSubprotocolError = 0x10
|
||||||
|
)
|
||||||
|
|
||||||
|
var discReasonToString = [DiscSubprotocolError + 1]string{
|
||||||
|
DiscRequested: "Disconnect requested",
|
||||||
|
DiscNetworkError: "Network error",
|
||||||
|
DiscProtocolError: "Breach of protocol",
|
||||||
|
DiscUselessPeer: "Useless peer",
|
||||||
|
DiscTooManyPeers: "Too many peers",
|
||||||
|
DiscAlreadyConnected: "Already connected",
|
||||||
|
DiscIncompatibleVersion: "Incompatible P2P protocol version",
|
||||||
|
DiscInvalidIdentity: "Invalid node identity",
|
||||||
|
DiscQuitting: "Client quitting",
|
||||||
|
DiscUnexpectedIdentity: "Unexpected identity",
|
||||||
|
DiscSelf: "Connected to self",
|
||||||
|
DiscReadTimeout: "Read timeout",
|
||||||
|
DiscSubprotocolError: "Subprotocol error",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DiscReason) String() string {
|
||||||
|
if len(discReasonToString) < int(d) {
|
||||||
|
return fmt.Sprintf("Unknown Reason(%d)", d)
|
||||||
|
}
|
||||||
|
return discReasonToString[d]
|
||||||
|
}
|
||||||
|
|
||||||
|
func discReasonForError(err error) DiscReason {
|
||||||
|
peerError, ok := err.(*peerError)
|
||||||
|
if !ok {
|
||||||
|
return DiscSubprotocolError
|
||||||
|
}
|
||||||
|
switch peerError.Code {
|
||||||
|
case errP2PVersionMismatch:
|
||||||
|
return DiscIncompatibleVersion
|
||||||
|
case errPubkeyMissing, errPubkeyInvalid:
|
||||||
|
return DiscInvalidIdentity
|
||||||
|
case errPubkeyForbidden:
|
||||||
|
return DiscUselessPeer
|
||||||
|
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
|
||||||
|
return DiscProtocolError
|
||||||
|
case errPingTimeout:
|
||||||
|
return DiscReadTimeout
|
||||||
|
case errRead, errWrite, errMisc:
|
||||||
|
return DiscNetworkError
|
||||||
|
default:
|
||||||
|
return DiscSubprotocolError
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,101 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
severityThreshold = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
type DisconnectRequest struct {
|
|
||||||
addr net.Addr
|
|
||||||
reason DiscReason
|
|
||||||
}
|
|
||||||
|
|
||||||
type PeerErrorHandler struct {
|
|
||||||
quit chan chan bool
|
|
||||||
address net.Addr
|
|
||||||
peerDisconnect chan DisconnectRequest
|
|
||||||
severity int
|
|
||||||
peerErrorChan chan *PeerError
|
|
||||||
blacklist Blacklist
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
|
|
||||||
return &PeerErrorHandler{
|
|
||||||
quit: make(chan chan bool),
|
|
||||||
address: address,
|
|
||||||
peerDisconnect: peerDisconnect,
|
|
||||||
peerErrorChan: peerErrorChan,
|
|
||||||
blacklist: blacklist,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *PeerErrorHandler) Start() {
|
|
||||||
go self.listen()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *PeerErrorHandler) Stop() {
|
|
||||||
q := make(chan bool)
|
|
||||||
self.quit <- q
|
|
||||||
<-q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *PeerErrorHandler) listen() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case peerError, ok := <-self.peerErrorChan:
|
|
||||||
if ok {
|
|
||||||
logger.Debugf("error %v\n", peerError)
|
|
||||||
go self.handle(peerError)
|
|
||||||
} else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case q := <-self.quit:
|
|
||||||
q <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *PeerErrorHandler) handle(peerError *PeerError) {
|
|
||||||
reason := DiscReason(' ')
|
|
||||||
switch peerError.Code {
|
|
||||||
case P2PVersionMismatch:
|
|
||||||
reason = DiscIncompatibleVersion
|
|
||||||
case PubkeyMissing, PubkeyInvalid:
|
|
||||||
reason = DiscInvalidIdentity
|
|
||||||
case PubkeyForbidden:
|
|
||||||
reason = DiscUselessPeer
|
|
||||||
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
|
|
||||||
reason = DiscProtocolError
|
|
||||||
case PingTimeout:
|
|
||||||
reason = DiscReadTimeout
|
|
||||||
case WriteError, MiscError:
|
|
||||||
reason = DiscNetworkError
|
|
||||||
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
|
|
||||||
reason = DiscSubprotocolError
|
|
||||||
default:
|
|
||||||
self.severity += self.getSeverity(peerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.severity >= severityThreshold {
|
|
||||||
reason = DiscSubprotocolError
|
|
||||||
}
|
|
||||||
if reason != DiscReason(' ') {
|
|
||||||
self.peerDisconnect <- DisconnectRequest{
|
|
||||||
addr: self.address,
|
|
||||||
reason: reason,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
|
|
||||||
switch peerError.Code {
|
|
||||||
case ReadError:
|
|
||||||
return 4 //tolerate 3 :)
|
|
||||||
default:
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
// "fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPeerErrorHandler(t *testing.T) {
|
|
||||||
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
|
|
||||||
peerDisconnect := make(chan DisconnectRequest)
|
|
||||||
peerErrorChan := NewPeerErrorChannel()
|
|
||||||
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
|
|
||||||
peh.Start()
|
|
||||||
defer peh.Stop()
|
|
||||||
for i := 0; i < 11; i++ {
|
|
||||||
select {
|
|
||||||
case <-peerDisconnect:
|
|
||||||
t.Errorf("expected no disconnect request")
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
peerErrorChan <- NewPeerError(MiscError, "")
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
select {
|
|
||||||
case request := <-peerDisconnect:
|
|
||||||
if request.addr.String() != address.String() {
|
|
||||||
t.Errorf("incorrect address %v != %v", request.addr, address)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
t.Errorf("expected disconnect request")
|
|
||||||
}
|
|
||||||
}
|
|
303
p2p/peer_test.go
303
p2p/peer_test.go
@ -1,96 +1,239 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"encoding/hex"
|
||||||
// "net"
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeer(t *testing.T) {
|
var discard = Protocol{
|
||||||
handlers := make(Handlers)
|
Name: "discard",
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
Length: 1,
|
||||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
Run: func(p *Peer, rw MsgReadWriter) error {
|
||||||
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
|
for {
|
||||||
addr := &TestAddr{"test:30"}
|
msg, err := rw.ReadMsg()
|
||||||
conn := NewTestNetworkConnection(addr)
|
if err != nil {
|
||||||
_, server := SetupTestServer(handlers)
|
return err
|
||||||
server.Handshake()
|
}
|
||||||
peer := NewPeer(conn, addr, true, server)
|
if err = msg.Discard(); err != nil {
|
||||||
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
return err
|
||||||
peer.Start()
|
|
||||||
defer peer.Stop()
|
|
||||||
time.Sleep(2 * time.Millisecond)
|
|
||||||
if len(conn.Out) != 1 {
|
|
||||||
t.Errorf("handshake not sent")
|
|
||||||
} else {
|
|
||||||
out := conn.Out[0]
|
|
||||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
|
|
||||||
if bytes.Compare(out, packet) != 0 {
|
|
||||||
t.Errorf("incorrect handshake packet %v != %v", out, packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
|
|
||||||
conn.In(0, packet)
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
|
|
||||||
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
|
|
||||||
if pro.state != handshakeReceived {
|
|
||||||
t.Errorf("handshake not received")
|
|
||||||
}
|
|
||||||
if peer.Port != 30 {
|
|
||||||
t.Errorf("port incorrectly set")
|
|
||||||
}
|
|
||||||
if peer.Id != "peer" {
|
|
||||||
t.Errorf("id incorrectly set")
|
|
||||||
}
|
|
||||||
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
|
|
||||||
t.Errorf("pubkey incorrectly set")
|
|
||||||
}
|
|
||||||
fmt.Println(peer.Caps)
|
|
||||||
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
|
|
||||||
t.Errorf("protocols incorrectly set")
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, _ := NewMsg(3)
|
|
||||||
err := peer.Write("aaa", msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
|
||||||
} else {
|
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
if len(conn.Out) != 2 {
|
|
||||||
t.Errorf("msg not written")
|
|
||||||
} else {
|
|
||||||
out := conn.Out[1]
|
|
||||||
packet := Packet(16, 3)
|
|
||||||
if bytes.Compare(out, packet) != 0 {
|
|
||||||
t.Errorf("incorrect packet %v != %v", out, packet)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
|
||||||
msg, _ = NewMsg(2)
|
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
|
||||||
err = peer.Write("ccc", msg)
|
conn1, conn2 := net.Pipe()
|
||||||
if err != nil {
|
id := NewSimpleClientIdentity("test", "0", "0", "public key")
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
peer := newPeer(conn1, protos, nil)
|
||||||
} else {
|
peer.ourID = id
|
||||||
time.Sleep(1 * time.Millisecond)
|
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
||||||
if len(conn.Out) != 3 {
|
errc := make(chan error, 1)
|
||||||
t.Errorf("msg not written")
|
go func() {
|
||||||
} else {
|
_, err := peer.loop()
|
||||||
out := conn.Out[2]
|
errc <- err
|
||||||
packet := Packet(21, 2)
|
}()
|
||||||
if bytes.Compare(out, packet) != 0 {
|
return conn2, peer, errc
|
||||||
t.Errorf("incorrect packet %v != %v", out, packet)
|
}
|
||||||
|
|
||||||
|
func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
|
defer testlog(t).detach()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
proto := Protocol{
|
||||||
|
Name: "a",
|
||||||
|
Length: 5,
|
||||||
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||||
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
if msg.Code != 2 {
|
||||||
|
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
|
||||||
|
}
|
||||||
|
data, err := ioutil.ReadAll(msg.Payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("payload read error: %v", err)
|
||||||
|
}
|
||||||
|
expdata, _ := hex.DecodeString("0183303030")
|
||||||
|
if !bytes.Equal(expdata, data) {
|
||||||
|
t.Errorf("incorrect msg data %x", data)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = peer.Write("bbb", msg)
|
net, peer, errc := testPeer([]Protocol{proto})
|
||||||
time.Sleep(1 * time.Millisecond)
|
defer net.Close()
|
||||||
if err == nil {
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
t.Errorf("expect error for unknown protocol")
|
|
||||||
|
writeMsg(net, NewMsg(18, 1, "000"))
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case err := <-errc:
|
||||||
|
t.Errorf("peer returned: %v", err)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Errorf("receive timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPeerProtoReadLargeMsg(t *testing.T) {
|
||||||
|
defer testlog(t).detach()
|
||||||
|
|
||||||
|
msgsize := uint32(10 * 1024 * 1024)
|
||||||
|
done := make(chan struct{})
|
||||||
|
proto := Protocol{
|
||||||
|
Name: "a",
|
||||||
|
Length: 5,
|
||||||
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||||
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read error: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Size != msgsize+4 {
|
||||||
|
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
|
||||||
|
}
|
||||||
|
msg.Discard()
|
||||||
|
close(done)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
net, peer, errc := testPeer([]Protocol{proto})
|
||||||
|
defer net.Close()
|
||||||
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
|
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case err := <-errc:
|
||||||
|
t.Errorf("peer returned: %v", err)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Errorf("receive timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
|
defer testlog(t).detach()
|
||||||
|
|
||||||
|
proto := Protocol{
|
||||||
|
Name: "a",
|
||||||
|
Length: 2,
|
||||||
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||||
|
if err := rw.EncodeMsg(2); err == nil {
|
||||||
|
t.Error("expected error for out-of-range msg code, got nil")
|
||||||
|
}
|
||||||
|
if err := rw.EncodeMsg(1); err != nil {
|
||||||
|
t.Errorf("write error: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
net, peer, _ := testPeer([]Protocol{proto})
|
||||||
|
defer net.Close()
|
||||||
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
|
bufr := bufio.NewReader(net)
|
||||||
|
msg, err := readMsg(bufr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read error: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Code != 17 {
|
||||||
|
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerWrite(t *testing.T) {
|
||||||
|
defer testlog(t).detach()
|
||||||
|
|
||||||
|
net, peer, peerErr := testPeer([]Protocol{discard})
|
||||||
|
defer net.Close()
|
||||||
|
peer.startSubprotocols([]Cap{discard.cap()})
|
||||||
|
|
||||||
|
// test write errors
|
||||||
|
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
||||||
|
t.Errorf("expected error for unknown protocol, got nil")
|
||||||
|
}
|
||||||
|
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
|
||||||
|
t.Errorf("expected error for out-of-range msg code, got nil")
|
||||||
|
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
|
||||||
|
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setup for reading the message on the other end
|
||||||
|
read := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
bufr := bufio.NewReader(net)
|
||||||
|
msg, err := readMsg(bufr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read error: %v", err)
|
||||||
|
} else if msg.Code != 16 {
|
||||||
|
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
|
||||||
|
}
|
||||||
|
msg.Discard()
|
||||||
|
close(read)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// test succcessful write
|
||||||
|
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
|
||||||
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-read:
|
||||||
|
case err := <-peerErr:
|
||||||
|
t.Fatalf("peer stopped: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerActivity(t *testing.T) {
|
||||||
|
// shorten inactivityTimeout while this test is running
|
||||||
|
oldT := inactivityTimeout
|
||||||
|
defer func() { inactivityTimeout = oldT }()
|
||||||
|
inactivityTimeout = 20 * time.Millisecond
|
||||||
|
|
||||||
|
net, peer, peerErr := testPeer([]Protocol{discard})
|
||||||
|
defer net.Close()
|
||||||
|
peer.startSubprotocols([]Cap{discard.cap()})
|
||||||
|
|
||||||
|
sub := peer.activity.Subscribe(time.Time{})
|
||||||
|
defer sub.Unsubscribe()
|
||||||
|
|
||||||
|
for i := 0; i < 6; i++ {
|
||||||
|
writeMsg(net, NewMsg(16))
|
||||||
|
select {
|
||||||
|
case <-sub.Chan():
|
||||||
|
case <-time.After(inactivityTimeout / 2):
|
||||||
|
t.Fatal("no event within ", inactivityTimeout/2)
|
||||||
|
case err := <-peerErr:
|
||||||
|
t.Fatal("peer error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(inactivityTimeout * 2):
|
||||||
|
case <-sub.Chan():
|
||||||
|
t.Fatal("got activity event while connection was inactive")
|
||||||
|
case err := <-peerErr:
|
||||||
|
t.Fatal("peer error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPeer(t *testing.T) {
|
||||||
|
id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey")
|
||||||
|
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
||||||
|
p := NewPeer(id, caps)
|
||||||
|
if !reflect.DeepEqual(p.Caps(), caps) {
|
||||||
|
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
|
||||||
|
}
|
||||||
|
if p.Identity() != id {
|
||||||
|
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
|
||||||
|
}
|
||||||
|
// Should not hang.
|
||||||
|
p.Disconnect(DiscAlreadyConnected)
|
||||||
|
}
|
||||||
|
501
p2p/protocol.go
501
p2p/protocol.go
@ -2,277 +2,294 @@ package p2p
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Protocol interface {
|
// Protocol represents a P2P subprotocol implementation.
|
||||||
Start()
|
type Protocol struct {
|
||||||
Stop()
|
// Name should contain the official protocol name,
|
||||||
HandleIn(*Msg, chan *Msg)
|
// often a three-letter word.
|
||||||
HandleOut(*Msg) bool
|
Name string
|
||||||
Offset() MsgCode
|
|
||||||
Name() string
|
// Version should contain the version number of the protocol.
|
||||||
|
Version uint
|
||||||
|
|
||||||
|
// Length should contain the number of message codes used
|
||||||
|
// by the protocol.
|
||||||
|
Length uint64
|
||||||
|
|
||||||
|
// Run is called in a new groutine when the protocol has been
|
||||||
|
// negotiated with a peer. It should read and write messages from
|
||||||
|
// rw. The Payload for each message must be fully consumed.
|
||||||
|
//
|
||||||
|
// The peer connection is closed when Start returns. It should return
|
||||||
|
// any protocol-level error (such as an I/O error) that is
|
||||||
|
// encountered.
|
||||||
|
Run func(peer *Peer, rw MsgReadWriter) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Protocol) cap() Cap {
|
||||||
|
return Cap{p.Name, p.Version}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
P2PVersion = 0
|
baseProtocolVersion = 2
|
||||||
pingTimeout = 2
|
baseProtocolLength = uint64(16)
|
||||||
pingGracePeriod = 2
|
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HandshakeMsg = iota
|
// devp2p message codes
|
||||||
DiscMsg
|
handshakeMsg = 0x00
|
||||||
PingMsg
|
discMsg = 0x01
|
||||||
PongMsg
|
pingMsg = 0x02
|
||||||
GetPeersMsg
|
pongMsg = 0x03
|
||||||
PeersMsg
|
getPeersMsg = 0x04
|
||||||
offset = 16
|
peersMsg = 0x05
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProtocolState uint8
|
// handshake is the structure of a handshake list.
|
||||||
|
type handshake struct {
|
||||||
const (
|
Version uint64
|
||||||
nullState = iota
|
ID string
|
||||||
handshakeReceived
|
Caps []Cap
|
||||||
)
|
ListenPort uint64
|
||||||
|
NodeID []byte
|
||||||
type DiscReason byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Values are given explicitly instead of by iota because these values are
|
|
||||||
// defined by the wire protocol spec; it is easier for humans to ensure
|
|
||||||
// correctness when values are explicit.
|
|
||||||
DiscRequested = 0x00
|
|
||||||
DiscNetworkError = 0x01
|
|
||||||
DiscProtocolError = 0x02
|
|
||||||
DiscUselessPeer = 0x03
|
|
||||||
DiscTooManyPeers = 0x04
|
|
||||||
DiscAlreadyConnected = 0x05
|
|
||||||
DiscIncompatibleVersion = 0x06
|
|
||||||
DiscInvalidIdentity = 0x07
|
|
||||||
DiscQuitting = 0x08
|
|
||||||
DiscUnexpectedIdentity = 0x09
|
|
||||||
DiscSelf = 0x0a
|
|
||||||
DiscReadTimeout = 0x0b
|
|
||||||
DiscSubprotocolError = 0x10
|
|
||||||
)
|
|
||||||
|
|
||||||
var discReasonToString = map[DiscReason]string{
|
|
||||||
DiscRequested: "Disconnect requested",
|
|
||||||
DiscNetworkError: "Network error",
|
|
||||||
DiscProtocolError: "Breach of protocol",
|
|
||||||
DiscUselessPeer: "Useless peer",
|
|
||||||
DiscTooManyPeers: "Too many peers",
|
|
||||||
DiscAlreadyConnected: "Already connected",
|
|
||||||
DiscIncompatibleVersion: "Incompatible P2P protocol version",
|
|
||||||
DiscInvalidIdentity: "Invalid node identity",
|
|
||||||
DiscQuitting: "Client quitting",
|
|
||||||
DiscUnexpectedIdentity: "Unexpected identity",
|
|
||||||
DiscSelf: "Connected to self",
|
|
||||||
DiscReadTimeout: "Read timeout",
|
|
||||||
DiscSubprotocolError: "Subprotocol error",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d DiscReason) String() string {
|
func (h *handshake) String() string {
|
||||||
if len(discReasonToString) < int(d) {
|
return h.ID
|
||||||
return "Unknown"
|
}
|
||||||
|
func (h *handshake) Pubkey() []byte {
|
||||||
|
return h.NodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap is the structure of a peer capability.
|
||||||
|
type Cap struct {
|
||||||
|
Name string
|
||||||
|
Version uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cap Cap) RlpData() interface{} {
|
||||||
|
return []interface{}{cap.Name, cap.Version}
|
||||||
|
}
|
||||||
|
|
||||||
|
type capsByName []Cap
|
||||||
|
|
||||||
|
func (cs capsByName) Len() int { return len(cs) }
|
||||||
|
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
|
||||||
|
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
|
||||||
|
|
||||||
|
type baseProtocol struct {
|
||||||
|
rw MsgReadWriter
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
|
||||||
|
bp := &baseProtocol{rw, peer}
|
||||||
|
if err := bp.doHandshake(rw); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
// run main loop
|
||||||
return discReasonToString[d]
|
quit := make(chan error, 1)
|
||||||
}
|
go func() {
|
||||||
|
for {
|
||||||
type BaseProtocol struct {
|
if err := bp.handle(rw); err != nil {
|
||||||
peer *Peer
|
quit <- err
|
||||||
state ProtocolState
|
break
|
||||||
stateLock sync.RWMutex
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func NewBaseProtocol(peer *Peer) *BaseProtocol {
|
|
||||||
self := &BaseProtocol{
|
|
||||||
peer: peer,
|
|
||||||
}
|
|
||||||
|
|
||||||
return self
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Start() {
|
|
||||||
if self.peer != nil {
|
|
||||||
self.peer.Write("", self.peer.Server().Handshake())
|
|
||||||
go self.peer.Messenger().PingPong(
|
|
||||||
pingTimeout*time.Second,
|
|
||||||
pingGracePeriod*time.Second,
|
|
||||||
self.Ping,
|
|
||||||
self.Timeout,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Stop() {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Ping() {
|
|
||||||
msg, _ := NewMsg(PingMsg)
|
|
||||||
self.peer.Write("", msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Timeout() {
|
|
||||||
self.peerError(PingTimeout, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Name() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) Offset() MsgCode {
|
|
||||||
return offset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) CheckState(state ProtocolState) bool {
|
|
||||||
self.stateLock.RLock()
|
|
||||||
self.stateLock.RUnlock()
|
|
||||||
if self.state != state {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) {
|
|
||||||
if msg.Code() == HandshakeMsg {
|
|
||||||
self.handleHandshake(msg)
|
|
||||||
} else {
|
|
||||||
if !self.CheckState(handshakeReceived) {
|
|
||||||
self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code())
|
|
||||||
close(response)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
switch msg.Code() {
|
}()
|
||||||
case DiscMsg:
|
return bp.loop(quit)
|
||||||
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint()))
|
}
|
||||||
self.peer.Server().PeerDisconnect() <- DisconnectRequest{
|
|
||||||
addr: self.peer.Address,
|
var pingTimeout = 2 * time.Second
|
||||||
reason: DiscRequested,
|
|
||||||
|
func (bp *baseProtocol) loop(quit <-chan error) error {
|
||||||
|
ping := time.NewTimer(pingTimeout)
|
||||||
|
activity := bp.peer.activity.Subscribe(time.Time{})
|
||||||
|
lastActive := time.Time{}
|
||||||
|
defer ping.Stop()
|
||||||
|
defer activity.Unsubscribe()
|
||||||
|
|
||||||
|
getPeersTick := time.NewTicker(10 * time.Second)
|
||||||
|
defer getPeersTick.Stop()
|
||||||
|
err := bp.rw.EncodeMsg(getPeersMsg)
|
||||||
|
|
||||||
|
for err == nil {
|
||||||
|
select {
|
||||||
|
case err = <-quit:
|
||||||
|
return err
|
||||||
|
case <-getPeersTick.C:
|
||||||
|
err = bp.rw.EncodeMsg(getPeersMsg)
|
||||||
|
case event := <-activity.Chan():
|
||||||
|
ping.Reset(pingTimeout)
|
||||||
|
lastActive = event.(time.Time)
|
||||||
|
case t := <-ping.C:
|
||||||
|
if lastActive.Add(pingTimeout * 2).Before(t) {
|
||||||
|
err = newPeerError(errPingTimeout, "")
|
||||||
|
} else if lastActive.Add(pingTimeout).Before(t) {
|
||||||
|
err = bp.rw.EncodeMsg(pingMsg)
|
||||||
}
|
}
|
||||||
case PingMsg:
|
|
||||||
out, _ := NewMsg(PongMsg)
|
|
||||||
response <- out
|
|
||||||
case PongMsg:
|
|
||||||
case GetPeersMsg:
|
|
||||||
// Peer asked for list of connected peers
|
|
||||||
if out, err := self.peer.Server().PeersMessage(); err != nil {
|
|
||||||
response <- out
|
|
||||||
}
|
|
||||||
case PeersMsg:
|
|
||||||
self.handlePeers(msg)
|
|
||||||
default:
|
|
||||||
self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(response)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) {
|
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
|
||||||
// somewhat overly paranoid
|
msg, err := rw.ReadMsg()
|
||||||
allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived)
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) {
|
|
||||||
err := NewPeerError(errorCode, format, v...)
|
|
||||||
logger.Warnln(err)
|
|
||||||
fmt.Println(self.peer, err)
|
|
||||||
if self.peer != nil {
|
|
||||||
self.peer.PeerErrorChan() <- err
|
|
||||||
}
|
}
|
||||||
}
|
if msg.Size > baseProtocolMaxMsgSize {
|
||||||
|
return newPeerError(errMisc, "message too big")
|
||||||
func (self *BaseProtocol) handlePeers(msg *Msg) {
|
|
||||||
it := msg.Data().NewIterator()
|
|
||||||
for it.Next() {
|
|
||||||
ip := net.IP(it.Value().Get(0).Bytes())
|
|
||||||
port := it.Value().Get(1).Uint()
|
|
||||||
address := &net.TCPAddr{IP: ip, Port: int(port)}
|
|
||||||
go self.peer.Server().PeerConnect(address)
|
|
||||||
}
|
}
|
||||||
|
// make sure that the payload has been fully consumed
|
||||||
|
defer msg.Discard()
|
||||||
|
|
||||||
|
switch msg.Code {
|
||||||
|
case handshakeMsg:
|
||||||
|
return newPeerError(errProtocolBreach, "extra handshake received")
|
||||||
|
|
||||||
|
case discMsg:
|
||||||
|
var reason DiscReason
|
||||||
|
if err := msg.Decode(&reason); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bp.peer.Disconnect(reason)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case pingMsg:
|
||||||
|
return bp.rw.EncodeMsg(pongMsg)
|
||||||
|
|
||||||
|
case pongMsg:
|
||||||
|
|
||||||
|
case getPeersMsg:
|
||||||
|
peers := bp.peerList()
|
||||||
|
// this is dangerous. the spec says that we should _delay_
|
||||||
|
// sending the response if no new information is available.
|
||||||
|
// this means that would need to send a response later when
|
||||||
|
// new peers become available.
|
||||||
|
//
|
||||||
|
// TODO: add event mechanism to notify baseProtocol for new peers
|
||||||
|
if len(peers) > 0 {
|
||||||
|
return bp.rw.EncodeMsg(peersMsg, peers)
|
||||||
|
}
|
||||||
|
|
||||||
|
case peersMsg:
|
||||||
|
var peers []*peerAddr
|
||||||
|
if err := msg.Decode(&peers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, addr := range peers {
|
||||||
|
bp.peer.Debugf("received peer suggestion: %v", addr)
|
||||||
|
bp.peer.newPeerAddr <- addr
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *BaseProtocol) handleHandshake(msg *Msg) {
|
func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
|
||||||
self.stateLock.Lock()
|
// send our handshake
|
||||||
defer self.stateLock.Unlock()
|
if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
|
||||||
if self.state != nullState {
|
return err
|
||||||
self.peerError(ProtocolBreach, "extra handshake")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c := msg.Data()
|
// read and handle remote handshake
|
||||||
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Code != handshakeMsg {
|
||||||
|
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
|
||||||
|
}
|
||||||
|
if msg.Size > baseProtocolMaxMsgSize {
|
||||||
|
return newPeerError(errMisc, "message too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
var hs handshake
|
||||||
|
if err := msg.Decode(&hs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate handshake info
|
||||||
|
if hs.Version != baseProtocolVersion {
|
||||||
|
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
|
||||||
|
baseProtocolVersion, hs.Version)
|
||||||
|
}
|
||||||
|
if len(hs.NodeID) == 0 {
|
||||||
|
return newPeerError(errPubkeyMissing, "")
|
||||||
|
}
|
||||||
|
if len(hs.NodeID) != 64 {
|
||||||
|
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
|
||||||
|
}
|
||||||
|
if da := bp.peer.dialAddr; da != nil {
|
||||||
|
// verify that the peer we wanted to connect to
|
||||||
|
// actually holds the target public key.
|
||||||
|
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
|
||||||
|
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
|
||||||
|
if err := bp.peer.pubkeyHook(pa); err != nil {
|
||||||
|
return newPeerError(errPubkeyForbidden, "%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove Caps with empty name
|
||||||
|
|
||||||
|
var addr *peerAddr
|
||||||
|
if hs.ListenPort != 0 {
|
||||||
|
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
|
||||||
|
addr.Port = hs.ListenPort
|
||||||
|
}
|
||||||
|
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
|
||||||
|
bp.peer.startSubprotocols(hs.Caps)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bp *baseProtocol) handshakeMsg() Msg {
|
||||||
var (
|
var (
|
||||||
p2pVersion = c.Get(0).Uint()
|
port uint64
|
||||||
id = c.Get(1).Str()
|
caps []interface{}
|
||||||
caps = c.Get(2)
|
|
||||||
port = c.Get(3).Uint()
|
|
||||||
pubkey = c.Get(4).Bytes()
|
|
||||||
)
|
)
|
||||||
fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey)
|
if bp.peer.ourListenAddr != nil {
|
||||||
|
port = bp.peer.ourListenAddr.Port
|
||||||
// Check correctness of p2p protocol version
|
|
||||||
if p2pVersion != P2PVersion {
|
|
||||||
self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
for _, proto := range bp.peer.protocols {
|
||||||
// Handle the pub key (validation, uniqueness)
|
caps = append(caps, proto.cap())
|
||||||
if len(pubkey) == 0 {
|
|
||||||
self.peerError(PubkeyMissing, "not supplied in handshake.")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return NewMsg(handshakeMsg,
|
||||||
if len(pubkey) != 64 {
|
baseProtocolVersion,
|
||||||
self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
|
bp.peer.ourID.String(),
|
||||||
return
|
caps,
|
||||||
}
|
port,
|
||||||
|
bp.peer.ourID.Pubkey()[1:],
|
||||||
// Self connect detection
|
)
|
||||||
if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 {
|
}
|
||||||
self.peerError(PubkeyForbidden, "not allowed to connect to self")
|
|
||||||
return
|
func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
|
||||||
}
|
peers := bp.peer.otherPeers()
|
||||||
|
ds := make([]ethutil.RlpEncodable, 0, len(peers))
|
||||||
// register pubkey on server. this also sets the pubkey on the peer (need lock)
|
for _, p := range peers {
|
||||||
if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil {
|
p.infolock.Lock()
|
||||||
self.peerError(PubkeyForbidden, err.Error())
|
addr := p.listenAddr
|
||||||
return
|
p.infolock.Unlock()
|
||||||
}
|
// filter out this peer and peers that are not listening or
|
||||||
|
// have not completed the handshake.
|
||||||
// check port
|
// TODO: track previously sent peers and exclude them as well.
|
||||||
if self.peer.Inbound {
|
if p == bp.peer || addr == nil {
|
||||||
uint16port := uint16(port)
|
continue
|
||||||
if self.peer.Port > 0 && self.peer.Port != uint16port {
|
}
|
||||||
self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port)
|
ds = append(ds, addr)
|
||||||
return
|
}
|
||||||
} else {
|
ourAddr := bp.peer.ourListenAddr
|
||||||
self.peer.Port = uint16port
|
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
|
||||||
}
|
ds = append(ds, ourAddr)
|
||||||
}
|
}
|
||||||
|
return ds
|
||||||
capsIt := caps.NewIterator()
|
|
||||||
for capsIt.Next() {
|
|
||||||
cap := capsIt.Value().Str()
|
|
||||||
self.peer.Caps = append(self.peer.Caps, cap)
|
|
||||||
}
|
|
||||||
sort.Strings(self.peer.Caps)
|
|
||||||
self.peer.Messenger().AddProtocols(self.peer.Caps)
|
|
||||||
|
|
||||||
self.peer.Id = id
|
|
||||||
|
|
||||||
self.state = handshakeReceived
|
|
||||||
|
|
||||||
//p.ethereum.PushPeer(p)
|
|
||||||
// p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
825
p2p/server.go
825
p2p/server.go
@ -2,21 +2,420 @@ package p2p
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
logpkg "github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
outboundAddressPoolSize = 10
|
outboundAddressPoolSize = 500
|
||||||
disconnectGracePeriod = 2
|
defaultDialTimeout = 10 * time.Second
|
||||||
|
portMappingUpdateInterval = 15 * time.Minute
|
||||||
|
portMappingTimeout = 20 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var srvlog = logger.NewLogger("P2P Server")
|
||||||
|
|
||||||
|
// Server manages all peer connections.
|
||||||
|
//
|
||||||
|
// The fields of Server are used as configuration parameters.
|
||||||
|
// You should set them before starting the Server. Fields may not be
|
||||||
|
// modified while the server is running.
|
||||||
|
type Server struct {
|
||||||
|
// This field must be set to a valid client identity.
|
||||||
|
Identity ClientIdentity
|
||||||
|
|
||||||
|
// MaxPeers is the maximum number of peers that can be
|
||||||
|
// connected. It must be greater than zero.
|
||||||
|
MaxPeers int
|
||||||
|
|
||||||
|
// Protocols should contain the protocols supported
|
||||||
|
// by the server. Matching protocols are launched for
|
||||||
|
// each peer.
|
||||||
|
Protocols []Protocol
|
||||||
|
|
||||||
|
// If Blacklist is set to a non-nil value, the given Blacklist
|
||||||
|
// is used to verify peer connections.
|
||||||
|
Blacklist Blacklist
|
||||||
|
|
||||||
|
// If ListenAddr is set to a non-nil address, the server
|
||||||
|
// will listen for incoming connections.
|
||||||
|
//
|
||||||
|
// If the port is zero, the operating system will pick a port. The
|
||||||
|
// ListenAddr field will be updated with the actual address when
|
||||||
|
// the server is started.
|
||||||
|
ListenAddr string
|
||||||
|
|
||||||
|
// If set to a non-nil value, the given NAT port mapper
|
||||||
|
// is used to make the listening port available to the
|
||||||
|
// Internet.
|
||||||
|
NAT NAT
|
||||||
|
|
||||||
|
// If Dialer is set to a non-nil value, the given Dialer
|
||||||
|
// is used to dial outbound peer connections.
|
||||||
|
Dialer *net.Dialer
|
||||||
|
|
||||||
|
// If NoDial is true, the server will not dial any peers.
|
||||||
|
NoDial bool
|
||||||
|
|
||||||
|
// Hook for testing. This is useful because we can inhibit
|
||||||
|
// the whole protocol stack.
|
||||||
|
newPeerFunc peerFunc
|
||||||
|
|
||||||
|
lock sync.RWMutex
|
||||||
|
running bool
|
||||||
|
listener net.Listener
|
||||||
|
laddr *net.TCPAddr // real listen addr
|
||||||
|
peers []*Peer
|
||||||
|
peerSlots chan int
|
||||||
|
peerCount int
|
||||||
|
|
||||||
|
quit chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
peerConnect chan *peerAddr
|
||||||
|
peerDisconnect chan *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NAT is implemented by NAT traversal methods.
|
||||||
|
type NAT interface {
|
||||||
|
GetExternalAddress() (net.IP, error)
|
||||||
|
AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
|
||||||
|
DeletePortMapping(protocol string, extport, intport int) error
|
||||||
|
|
||||||
|
// Should return name of the method.
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
|
||||||
|
|
||||||
|
// Peers returns all connected peers.
|
||||||
|
func (srv *Server) Peers() (peers []*Peer) {
|
||||||
|
srv.lock.RLock()
|
||||||
|
defer srv.lock.RUnlock()
|
||||||
|
for _, peer := range srv.peers {
|
||||||
|
if peer != nil {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerCount returns the number of connected peers.
|
||||||
|
func (srv *Server) PeerCount() int {
|
||||||
|
srv.lock.RLock()
|
||||||
|
defer srv.lock.RUnlock()
|
||||||
|
return srv.peerCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// SuggestPeer injects an address into the outbound address pool.
|
||||||
|
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
|
||||||
|
select {
|
||||||
|
case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}:
|
||||||
|
default: // don't block
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast sends an RLP-encoded message to all connected peers.
|
||||||
|
// This method is deprecated and will be removed later.
|
||||||
|
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
|
||||||
|
var payload []byte
|
||||||
|
if data != nil {
|
||||||
|
payload = encodePayload(data...)
|
||||||
|
}
|
||||||
|
srv.lock.RLock()
|
||||||
|
defer srv.lock.RUnlock()
|
||||||
|
for _, peer := range srv.peers {
|
||||||
|
if peer != nil {
|
||||||
|
var msg = Msg{Code: code}
|
||||||
|
if data != nil {
|
||||||
|
msg.Payload = bytes.NewReader(payload)
|
||||||
|
msg.Size = uint32(len(payload))
|
||||||
|
}
|
||||||
|
peer.writeProtoMsg(protocol, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts running the server.
|
||||||
|
// Servers can be re-used and started again after stopping.
|
||||||
|
func (srv *Server) Start() (err error) {
|
||||||
|
srv.lock.Lock()
|
||||||
|
defer srv.lock.Unlock()
|
||||||
|
if srv.running {
|
||||||
|
return errors.New("server already running")
|
||||||
|
}
|
||||||
|
srvlog.Infoln("Starting Server")
|
||||||
|
|
||||||
|
// initialize fields
|
||||||
|
if srv.Identity == nil {
|
||||||
|
return fmt.Errorf("Server.Identity must be set to a non-nil identity")
|
||||||
|
}
|
||||||
|
if srv.MaxPeers <= 0 {
|
||||||
|
return fmt.Errorf("Server.MaxPeers must be > 0")
|
||||||
|
}
|
||||||
|
srv.quit = make(chan struct{})
|
||||||
|
srv.peers = make([]*Peer, srv.MaxPeers)
|
||||||
|
srv.peerSlots = make(chan int, srv.MaxPeers)
|
||||||
|
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
|
||||||
|
srv.peerDisconnect = make(chan *Peer)
|
||||||
|
if srv.newPeerFunc == nil {
|
||||||
|
srv.newPeerFunc = newServerPeer
|
||||||
|
}
|
||||||
|
if srv.Blacklist == nil {
|
||||||
|
srv.Blacklist = NewBlacklist()
|
||||||
|
}
|
||||||
|
if srv.Dialer == nil {
|
||||||
|
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.ListenAddr != "" {
|
||||||
|
if err := srv.startListening(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !srv.NoDial {
|
||||||
|
srv.wg.Add(1)
|
||||||
|
go srv.dialLoop()
|
||||||
|
}
|
||||||
|
if srv.NoDial && srv.ListenAddr == "" {
|
||||||
|
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// make all slots available
|
||||||
|
for i := range srv.peers {
|
||||||
|
srv.peerSlots <- i
|
||||||
|
}
|
||||||
|
// note: discLoop is not part of WaitGroup
|
||||||
|
go srv.discLoop()
|
||||||
|
srv.running = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) startListening() error {
|
||||||
|
listener, err := net.Listen("tcp", srv.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
srv.ListenAddr = listener.Addr().String()
|
||||||
|
srv.laddr = listener.Addr().(*net.TCPAddr)
|
||||||
|
srv.listener = listener
|
||||||
|
srv.wg.Add(1)
|
||||||
|
go srv.listenLoop()
|
||||||
|
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||||
|
srv.wg.Add(1)
|
||||||
|
go srv.natLoop(srv.laddr.Port)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop terminates the server and all active peer connections.
|
||||||
|
// It blocks until all active connections have been closed.
|
||||||
|
func (srv *Server) Stop() {
|
||||||
|
srv.lock.Lock()
|
||||||
|
if !srv.running {
|
||||||
|
srv.lock.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv.running = false
|
||||||
|
srv.lock.Unlock()
|
||||||
|
|
||||||
|
srvlog.Infoln("Stopping server")
|
||||||
|
if srv.listener != nil {
|
||||||
|
// this unblocks listener Accept
|
||||||
|
srv.listener.Close()
|
||||||
|
}
|
||||||
|
close(srv.quit)
|
||||||
|
for _, peer := range srv.Peers() {
|
||||||
|
peer.Disconnect(DiscQuitting)
|
||||||
|
}
|
||||||
|
srv.wg.Wait()
|
||||||
|
|
||||||
|
// wait till they actually disconnect
|
||||||
|
// this is checked by claiming all peerSlots.
|
||||||
|
// slots become available as the peers disconnect.
|
||||||
|
for i := 0; i < cap(srv.peerSlots); i++ {
|
||||||
|
<-srv.peerSlots
|
||||||
|
}
|
||||||
|
// terminate discLoop
|
||||||
|
close(srv.peerDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) discLoop() {
|
||||||
|
for peer := range srv.peerDisconnect {
|
||||||
|
// peer has just disconnected. free up its slot.
|
||||||
|
srvlog.Infof("%v is gone", peer)
|
||||||
|
srv.peerSlots <- peer.slot
|
||||||
|
srv.lock.Lock()
|
||||||
|
srv.peers[peer.slot] = nil
|
||||||
|
srv.lock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// main loop for adding connections via listening
|
||||||
|
func (srv *Server) listenLoop() {
|
||||||
|
defer srv.wg.Done()
|
||||||
|
|
||||||
|
srvlog.Infoln("Listening on", srv.listener.Addr())
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case slot := <-srv.peerSlots:
|
||||||
|
conn, err := srv.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
srv.peerSlots <- slot
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
|
||||||
|
srv.addPeer(conn, nil, slot)
|
||||||
|
case <-srv.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) natLoop(port int) {
|
||||||
|
defer srv.wg.Done()
|
||||||
|
for {
|
||||||
|
srv.updatePortMapping(port)
|
||||||
|
select {
|
||||||
|
case <-time.After(portMappingUpdateInterval):
|
||||||
|
// one more round
|
||||||
|
case <-srv.quit:
|
||||||
|
srv.removePortMapping(port)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) updatePortMapping(port int) {
|
||||||
|
srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
|
||||||
|
err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
|
||||||
|
if err != nil {
|
||||||
|
srvlog.Errorln("Port mapping error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
extip, err := srv.NAT.GetExternalAddress()
|
||||||
|
if err != nil {
|
||||||
|
srvlog.Errorln("Error getting external IP:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv.lock.Lock()
|
||||||
|
extaddr := *(srv.listener.Addr().(*net.TCPAddr))
|
||||||
|
extaddr.IP = extip
|
||||||
|
srvlog.Infoln("Mapped port, external addr is", &extaddr)
|
||||||
|
srv.laddr = &extaddr
|
||||||
|
srv.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) removePortMapping(port int) {
|
||||||
|
srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
|
||||||
|
srv.NAT.DeletePortMapping("tcp", port, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) dialLoop() {
|
||||||
|
defer srv.wg.Done()
|
||||||
|
var (
|
||||||
|
suggest chan *peerAddr
|
||||||
|
slot *int
|
||||||
|
slots = srv.peerSlots
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case i := <-slots:
|
||||||
|
// we need a peer in slot i, slot reserved
|
||||||
|
slot = &i
|
||||||
|
// now we can watch for candidate peers in the next loop
|
||||||
|
suggest = srv.peerConnect
|
||||||
|
// do not consume more until candidate peer is found
|
||||||
|
slots = nil
|
||||||
|
|
||||||
|
case desc := <-suggest:
|
||||||
|
// candidate peer found, will dial out asyncronously
|
||||||
|
// if connection fails slot will be released
|
||||||
|
go srv.dialPeer(desc, *slot)
|
||||||
|
// we can watch if more peers needed in the next loop
|
||||||
|
slots = srv.peerSlots
|
||||||
|
// until then we dont care about candidate peers
|
||||||
|
suggest = nil
|
||||||
|
|
||||||
|
case <-srv.quit:
|
||||||
|
// give back the currently reserved slot
|
||||||
|
if slot != nil {
|
||||||
|
srv.peerSlots <- *slot
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect to peer via dial out
|
||||||
|
func (srv *Server) dialPeer(desc *peerAddr, slot int) {
|
||||||
|
srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
|
||||||
|
conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
|
||||||
|
if err != nil {
|
||||||
|
srvlog.Errorf("Dial error: %v", err)
|
||||||
|
srv.peerSlots <- slot
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go srv.addPeer(conn, desc, slot)
|
||||||
|
}
|
||||||
|
|
||||||
|
// creates the new peer object and inserts it into its slot
|
||||||
|
func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
|
||||||
|
srv.lock.Lock()
|
||||||
|
defer srv.lock.Unlock()
|
||||||
|
if !srv.running {
|
||||||
|
conn.Close()
|
||||||
|
srv.peerSlots <- slot // release slot
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
peer := srv.newPeerFunc(srv, conn, desc)
|
||||||
|
peer.slot = slot
|
||||||
|
srv.peers[slot] = peer
|
||||||
|
srv.peerCount++
|
||||||
|
go func() { peer.loop(); srv.peerDisconnect <- peer }()
|
||||||
|
return peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
||||||
|
func (srv *Server) removePeer(peer *Peer) {
|
||||||
|
srv.lock.Lock()
|
||||||
|
defer srv.lock.Unlock()
|
||||||
|
srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot)
|
||||||
|
if srv.peers[peer.slot] != peer {
|
||||||
|
srvlog.Warnln("Invalid peer to remove:", peer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// remove from list and index
|
||||||
|
srv.peerCount--
|
||||||
|
srv.peers[peer.slot] = nil
|
||||||
|
// release slot to signal need for a new peer, last!
|
||||||
|
srv.peerSlots <- peer.slot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) verifyPeer(addr *peerAddr) error {
|
||||||
|
if srv.Blacklist.Exists(addr.Pubkey) {
|
||||||
|
return errors.New("blacklisted")
|
||||||
|
}
|
||||||
|
if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
|
||||||
|
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
|
||||||
|
}
|
||||||
|
srv.lock.RLock()
|
||||||
|
defer srv.lock.RUnlock()
|
||||||
|
for _, peer := range srv.peers {
|
||||||
|
if peer != nil {
|
||||||
|
id := peer.Identity()
|
||||||
|
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
|
||||||
|
return errors.New("already connected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Blacklist interface {
|
type Blacklist interface {
|
||||||
Get([]byte) (bool, error)
|
Get([]byte) (bool, error)
|
||||||
Put([]byte) error
|
Put([]byte) error
|
||||||
@ -66,419 +465,3 @@ func (self *BlacklistMap) Delete(pubkey []byte) error {
|
|||||||
delete(self.blacklist, string(pubkey))
|
delete(self.blacklist, string(pubkey))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
network Network
|
|
||||||
listening bool //needed?
|
|
||||||
dialing bool //needed?
|
|
||||||
closed bool
|
|
||||||
identity ClientIdentity
|
|
||||||
addr net.Addr
|
|
||||||
port uint16
|
|
||||||
protocols []string
|
|
||||||
|
|
||||||
quit chan chan bool
|
|
||||||
peersLock sync.RWMutex
|
|
||||||
|
|
||||||
maxPeers int
|
|
||||||
peers []*Peer
|
|
||||||
peerSlots chan int
|
|
||||||
peersTable map[string]int
|
|
||||||
peersMsg *Msg
|
|
||||||
peerCount int
|
|
||||||
|
|
||||||
peerConnect chan net.Addr
|
|
||||||
peerDisconnect chan DisconnectRequest
|
|
||||||
blacklist Blacklist
|
|
||||||
handlers Handlers
|
|
||||||
}
|
|
||||||
|
|
||||||
var logger = logpkg.NewLogger("P2P")
|
|
||||||
|
|
||||||
func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server {
|
|
||||||
// get alphabetical list of protocol names from handlers map
|
|
||||||
protocols := []string{}
|
|
||||||
for protocol := range handlers {
|
|
||||||
protocols = append(protocols, protocol)
|
|
||||||
}
|
|
||||||
sort.Strings(protocols)
|
|
||||||
|
|
||||||
_, port, _ := net.SplitHostPort(addr.String())
|
|
||||||
intport, _ := strconv.Atoi(port)
|
|
||||||
|
|
||||||
self := &Server{
|
|
||||||
// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
|
|
||||||
network: network,
|
|
||||||
identity: identity,
|
|
||||||
addr: addr,
|
|
||||||
port: uint16(intport),
|
|
||||||
protocols: protocols,
|
|
||||||
|
|
||||||
quit: make(chan chan bool),
|
|
||||||
|
|
||||||
maxPeers: maxPeers,
|
|
||||||
peers: make([]*Peer, maxPeers),
|
|
||||||
peerSlots: make(chan int, maxPeers),
|
|
||||||
peersTable: make(map[string]int),
|
|
||||||
|
|
||||||
peerConnect: make(chan net.Addr, outboundAddressPoolSize),
|
|
||||||
peerDisconnect: make(chan DisconnectRequest),
|
|
||||||
blacklist: blacklist,
|
|
||||||
|
|
||||||
handlers: handlers,
|
|
||||||
}
|
|
||||||
for i := 0; i < maxPeers; i++ {
|
|
||||||
self.peerSlots <- i // fill up with indexes
|
|
||||||
}
|
|
||||||
return self
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) {
|
|
||||||
addr, err = self.network.NewAddr(host, port)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) ParseAddr(address string) (addr net.Addr, err error) {
|
|
||||||
addr, err = self.network.ParseAddr(address)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) ClientIdentity() ClientIdentity {
|
|
||||||
return self.identity
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) PeersMessage() (msg *Msg, err error) {
|
|
||||||
// TODO: memoize and reset when peers change
|
|
||||||
self.peersLock.RLock()
|
|
||||||
defer self.peersLock.RUnlock()
|
|
||||||
msg = self.peersMsg
|
|
||||||
if msg == nil {
|
|
||||||
var peerData []interface{}
|
|
||||||
for _, i := range self.peersTable {
|
|
||||||
peer := self.peers[i]
|
|
||||||
peerData = append(peerData, peer.Encode())
|
|
||||||
}
|
|
||||||
if len(peerData) == 0 {
|
|
||||||
err = fmt.Errorf("no peers")
|
|
||||||
} else {
|
|
||||||
msg, err = NewMsg(PeersMsg, peerData...)
|
|
||||||
self.peersMsg = msg //memoize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) Peers() (peers []*Peer) {
|
|
||||||
self.peersLock.RLock()
|
|
||||||
defer self.peersLock.RUnlock()
|
|
||||||
for _, peer := range self.peers {
|
|
||||||
if peer != nil {
|
|
||||||
peers = append(peers, peer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) PeerCount() int {
|
|
||||||
self.peersLock.RLock()
|
|
||||||
defer self.peersLock.RUnlock()
|
|
||||||
return self.peerCount
|
|
||||||
}
|
|
||||||
|
|
||||||
var getPeersMsg, _ = NewMsg(GetPeersMsg)
|
|
||||||
|
|
||||||
func (self *Server) PeerConnect(addr net.Addr) {
|
|
||||||
// TODO: should buffer, filter and uniq
|
|
||||||
// send GetPeersMsg if not blocking
|
|
||||||
select {
|
|
||||||
case self.peerConnect <- addr: // not enough peers
|
|
||||||
self.Broadcast("", getPeersMsg)
|
|
||||||
default: // we dont care
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) PeerDisconnect() chan DisconnectRequest {
|
|
||||||
return self.peerDisconnect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) Blacklist() Blacklist {
|
|
||||||
return self.blacklist
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) Handlers() Handlers {
|
|
||||||
return self.handlers
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) Broadcast(protocol string, msg *Msg) {
|
|
||||||
self.peersLock.RLock()
|
|
||||||
defer self.peersLock.RUnlock()
|
|
||||||
for _, peer := range self.peers {
|
|
||||||
if peer != nil {
|
|
||||||
peer.Write(protocol, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the server
|
|
||||||
func (self *Server) Start(listen bool, dial bool) {
|
|
||||||
self.network.Start()
|
|
||||||
if listen {
|
|
||||||
listener, err := self.network.Listener(self.addr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warnf("Error initializing listener: %v", err)
|
|
||||||
logger.Warnf("Connection listening disabled")
|
|
||||||
self.listening = false
|
|
||||||
} else {
|
|
||||||
self.listening = true
|
|
||||||
logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr())
|
|
||||||
go self.inboundPeerHandler(listener)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if dial {
|
|
||||||
dialer, err := self.network.Dialer(self.addr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warnf("Error initializing dialer: %v", err)
|
|
||||||
logger.Warnf("Connection dialout disabled")
|
|
||||||
self.dialing = false
|
|
||||||
} else {
|
|
||||||
self.dialing = true
|
|
||||||
logger.Infoln("Dial peers watching outbound address pool")
|
|
||||||
go self.outboundPeerHandler(dialer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Infoln("server started")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) Stop() {
|
|
||||||
logger.Infoln("server stopping...")
|
|
||||||
// // quit one loop if dialing
|
|
||||||
if self.dialing {
|
|
||||||
logger.Infoln("stop dialout...")
|
|
||||||
dialq := make(chan bool)
|
|
||||||
self.quit <- dialq
|
|
||||||
<-dialq
|
|
||||||
fmt.Println("quit another")
|
|
||||||
}
|
|
||||||
// quit the other loop if listening
|
|
||||||
if self.listening {
|
|
||||||
logger.Infoln("stop listening...")
|
|
||||||
listenq := make(chan bool)
|
|
||||||
self.quit <- listenq
|
|
||||||
<-listenq
|
|
||||||
fmt.Println("quit one")
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("quit waited")
|
|
||||||
|
|
||||||
logger.Infoln("stopping peers...")
|
|
||||||
peers := []net.Addr{}
|
|
||||||
self.peersLock.RLock()
|
|
||||||
self.closed = true
|
|
||||||
for _, peer := range self.peers {
|
|
||||||
if peer != nil {
|
|
||||||
peers = append(peers, peer.Address)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.peersLock.RUnlock()
|
|
||||||
for _, address := range peers {
|
|
||||||
go self.removePeer(DisconnectRequest{
|
|
||||||
addr: address,
|
|
||||||
reason: DiscQuitting,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
// wait till they actually disconnect
|
|
||||||
// this is checked by draining the peerSlots (slots are released back if a peer is removed)
|
|
||||||
i := 0
|
|
||||||
fmt.Println("draining peers")
|
|
||||||
|
|
||||||
FOR:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case slot := <-self.peerSlots:
|
|
||||||
i++
|
|
||||||
fmt.Printf("%v: found slot %v", i, slot)
|
|
||||||
if i == self.maxPeers {
|
|
||||||
break FOR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Infoln("server stopped")
|
|
||||||
}
|
|
||||||
|
|
||||||
// main loop for adding connections via listening
|
|
||||||
func (self *Server) inboundPeerHandler(listener net.Listener) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case slot := <-self.peerSlots:
|
|
||||||
go self.connectInboundPeer(listener, slot)
|
|
||||||
case errc := <-self.quit:
|
|
||||||
listener.Close()
|
|
||||||
fmt.Println("quit listenloop")
|
|
||||||
errc <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// main loop for adding outbound peers based on peerConnect address pool
|
|
||||||
// this same loop handles peer disconnect requests as well
|
|
||||||
func (self *Server) outboundPeerHandler(dialer Dialer) {
|
|
||||||
// addressChan initially set to nil (only watches peerConnect if we need more peers)
|
|
||||||
var addressChan chan net.Addr
|
|
||||||
slots := self.peerSlots
|
|
||||||
var slot *int
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case i := <-slots:
|
|
||||||
// we need a peer in slot i, slot reserved
|
|
||||||
slot = &i
|
|
||||||
// now we can watch for candidate peers in the next loop
|
|
||||||
addressChan = self.peerConnect
|
|
||||||
// do not consume more until candidate peer is found
|
|
||||||
slots = nil
|
|
||||||
case address := <-addressChan:
|
|
||||||
// candidate peer found, will dial out asyncronously
|
|
||||||
// if connection fails slot will be released
|
|
||||||
go self.connectOutboundPeer(dialer, address, *slot)
|
|
||||||
// we can watch if more peers needed in the next loop
|
|
||||||
slots = self.peerSlots
|
|
||||||
// until then we dont care about candidate peers
|
|
||||||
addressChan = nil
|
|
||||||
case request := <-self.peerDisconnect:
|
|
||||||
go self.removePeer(request)
|
|
||||||
case errc := <-self.quit:
|
|
||||||
if addressChan != nil && slot != nil {
|
|
||||||
self.peerSlots <- *slot
|
|
||||||
}
|
|
||||||
fmt.Println("quit dialloop")
|
|
||||||
errc <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if peer address already connected
|
|
||||||
func (self *Server) connected(address net.Addr) (err error) {
|
|
||||||
self.peersLock.RLock()
|
|
||||||
defer self.peersLock.RUnlock()
|
|
||||||
// fmt.Printf("address: %v\n", address)
|
|
||||||
slot, found := self.peersTable[address.String()]
|
|
||||||
if found {
|
|
||||||
err = fmt.Errorf("already connected as peer %v (%v)", slot, address)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect to peer via listener.Accept()
|
|
||||||
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
|
|
||||||
var address net.Addr
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err == nil {
|
|
||||||
address = conn.RemoteAddr()
|
|
||||||
err = self.connected(address)
|
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
logger.Debugln(err)
|
|
||||||
self.peerSlots <- slot
|
|
||||||
} else {
|
|
||||||
fmt.Printf("adding %v\n", address)
|
|
||||||
go self.addPeer(conn, address, true, slot)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect to peer via dial out
|
|
||||||
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
|
|
||||||
var conn net.Conn
|
|
||||||
err := self.connected(address)
|
|
||||||
if err == nil {
|
|
||||||
conn, err = dialer.Dial(address.Network(), address.String())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
logger.Debugln(err)
|
|
||||||
self.peerSlots <- slot
|
|
||||||
} else {
|
|
||||||
go self.addPeer(conn, address, false, slot)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// creates the new peer object and inserts it into its slot
|
|
||||||
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) {
|
|
||||||
self.peersLock.Lock()
|
|
||||||
defer self.peersLock.Unlock()
|
|
||||||
if self.closed {
|
|
||||||
fmt.Println("oopsy, not no longer need peer")
|
|
||||||
conn.Close() //oopsy our bad
|
|
||||||
self.peerSlots <- slot // release slot
|
|
||||||
} else {
|
|
||||||
peer := NewPeer(conn, address, inbound, self)
|
|
||||||
self.peers[slot] = peer
|
|
||||||
self.peersTable[address.String()] = slot
|
|
||||||
self.peerCount++
|
|
||||||
// reset peersmsg
|
|
||||||
self.peersMsg = nil
|
|
||||||
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
|
|
||||||
peer.Start()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
|
||||||
func (self *Server) removePeer(request DisconnectRequest) {
|
|
||||||
self.peersLock.Lock()
|
|
||||||
|
|
||||||
address := request.addr
|
|
||||||
slot := self.peersTable[address.String()]
|
|
||||||
peer := self.peers[slot]
|
|
||||||
fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
|
|
||||||
if peer == nil {
|
|
||||||
logger.Debugf("already removed peer on %v", address)
|
|
||||||
self.peersLock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// remove from list and index
|
|
||||||
self.peerCount--
|
|
||||||
self.peers[slot] = nil
|
|
||||||
delete(self.peersTable, address.String())
|
|
||||||
// reset peersmsg
|
|
||||||
self.peersMsg = nil
|
|
||||||
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
|
|
||||||
self.peersLock.Unlock()
|
|
||||||
|
|
||||||
// sending disconnect message
|
|
||||||
disconnectMsg, _ := NewMsg(DiscMsg, request.reason)
|
|
||||||
peer.Write("", disconnectMsg)
|
|
||||||
// be nice and wait
|
|
||||||
time.Sleep(disconnectGracePeriod * time.Second)
|
|
||||||
// switch off peer and close connections etc.
|
|
||||||
fmt.Println("stopping peer")
|
|
||||||
peer.Stop()
|
|
||||||
fmt.Println("stopped peer")
|
|
||||||
// release slot to signal need for a new peer, last!
|
|
||||||
self.peerSlots <- slot
|
|
||||||
}
|
|
||||||
|
|
||||||
// fix handshake message to push to peers
|
|
||||||
func (self *Server) Handshake() *Msg {
|
|
||||||
fmt.Println(self.identity.Pubkey()[1:])
|
|
||||||
msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:])
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
|
|
||||||
// Check for blacklisting
|
|
||||||
if self.blacklist.Exists(pubkey) {
|
|
||||||
return fmt.Errorf("blacklisted")
|
|
||||||
}
|
|
||||||
|
|
||||||
self.peersLock.RLock()
|
|
||||||
defer self.peersLock.RUnlock()
|
|
||||||
for _, peer := range self.peers {
|
|
||||||
if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 {
|
|
||||||
return fmt.Errorf("already connected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
candidate.Pubkey = pubkey
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -2,207 +2,160 @@ package p2p
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TestNetwork struct {
|
func startTestServer(t *testing.T, pf peerFunc) *Server {
|
||||||
connections map[string]*TestNetworkConnection
|
server := &Server{
|
||||||
dialer Dialer
|
Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
|
||||||
maxinbound int
|
MaxPeers: 10,
|
||||||
}
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
newPeerFunc: pf,
|
||||||
func NewTestNetwork(maxinbound int) *TestNetwork {
|
|
||||||
connections := make(map[string]*TestNetworkConnection)
|
|
||||||
return &TestNetwork{
|
|
||||||
connections: connections,
|
|
||||||
dialer: &TestDialer{connections},
|
|
||||||
maxinbound: maxinbound,
|
|
||||||
}
|
}
|
||||||
}
|
if err := server.Start(); err != nil {
|
||||||
|
t.Fatalf("Could not start server: %v", err)
|
||||||
func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) {
|
|
||||||
return self.dialer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
|
|
||||||
return &TestListener{
|
|
||||||
connections: self.connections,
|
|
||||||
addr: addr,
|
|
||||||
max: self.maxinbound,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetwork) Start() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type TestAddr struct {
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestAddr) String() string {
|
|
||||||
return self.name
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*TestAddr) Network() string {
|
|
||||||
return "test"
|
|
||||||
}
|
|
||||||
|
|
||||||
type TestDialer struct {
|
|
||||||
connections map[string]*TestNetworkConnection
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
|
|
||||||
address := &TestAddr{addr}
|
|
||||||
tconn := NewTestNetworkConnection(address)
|
|
||||||
self.connections[addr] = tconn
|
|
||||||
conn = net.Conn(tconn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type TestListener struct {
|
|
||||||
connections map[string]*TestNetworkConnection
|
|
||||||
addr net.Addr
|
|
||||||
max int
|
|
||||||
i int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestListener) Accept() (conn net.Conn, err error) {
|
|
||||||
self.i++
|
|
||||||
if self.i > self.max {
|
|
||||||
err = fmt.Errorf("no more")
|
|
||||||
} else {
|
|
||||||
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
|
|
||||||
tconn := NewTestNetworkConnection(addr)
|
|
||||||
key := tconn.RemoteAddr().String()
|
|
||||||
self.connections[key] = tconn
|
|
||||||
conn = net.Conn(tconn)
|
|
||||||
fmt.Printf("accepted connection from: %v \n", addr)
|
|
||||||
}
|
}
|
||||||
return
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *TestListener) Close() error {
|
func TestServerListen(t *testing.T) {
|
||||||
return nil
|
defer testlog(t).detach()
|
||||||
}
|
|
||||||
|
|
||||||
func (self *TestListener) Addr() net.Addr {
|
// start the test server
|
||||||
return self.addr
|
connected := make(chan *Peer)
|
||||||
}
|
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
||||||
|
if conn == nil {
|
||||||
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
|
t.Error("peer func called with nil conn")
|
||||||
network = NewTestNetwork(1)
|
|
||||||
addr := &TestAddr{"test:30303"}
|
|
||||||
identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
|
|
||||||
maxPeers := 2
|
|
||||||
if handlers == nil {
|
|
||||||
handlers = make(Handlers)
|
|
||||||
}
|
|
||||||
blackist := NewBlacklist()
|
|
||||||
server = New(network, addr, identity, handlers, maxPeers, blackist)
|
|
||||||
fmt.Println(server.identity.Pubkey())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerListener(t *testing.T) {
|
|
||||||
network, server := SetupTestServer(nil)
|
|
||||||
server.Start(true, false)
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
server.Stop()
|
|
||||||
peer1, ok := network.connections["inboundpeer-1"]
|
|
||||||
if !ok {
|
|
||||||
t.Error("not found inbound peer 1")
|
|
||||||
} else {
|
|
||||||
fmt.Printf("out: %v\n", peer1.Out)
|
|
||||||
if len(peer1.Out) != 2 {
|
|
||||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
|
||||||
}
|
}
|
||||||
}
|
if dialAddr != nil {
|
||||||
|
t.Error("peer func called with non-nil dialAddr")
|
||||||
|
}
|
||||||
|
peer := newPeer(conn, nil, dialAddr)
|
||||||
|
connected <- peer
|
||||||
|
return peer
|
||||||
|
})
|
||||||
|
defer close(connected)
|
||||||
|
defer srv.Stop()
|
||||||
|
|
||||||
|
// dial the test server
|
||||||
|
conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case peer := <-connected:
|
||||||
|
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
|
||||||
|
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||||
|
peer.conn.LocalAddr(), conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Error("server did not accept within one second")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerDialer(t *testing.T) {
|
func TestServerDial(t *testing.T) {
|
||||||
network, server := SetupTestServer(nil)
|
defer testlog(t).detach()
|
||||||
server.Start(false, true)
|
|
||||||
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
// run a fake TCP server to handle the connection.
|
||||||
time.Sleep(10 * time.Millisecond)
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
server.Stop()
|
if err != nil {
|
||||||
peer1, ok := network.connections["outboundpeer-1"]
|
t.Fatalf("could not setup listener: %v")
|
||||||
if !ok {
|
}
|
||||||
t.Error("not found outbound peer 1")
|
defer listener.Close()
|
||||||
} else {
|
accepted := make(chan net.Conn)
|
||||||
fmt.Printf("out: %v\n", peer1.Out)
|
go func() {
|
||||||
if len(peer1.Out) != 2 {
|
conn, err := listener.Accept()
|
||||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
if err != nil {
|
||||||
|
t.Error("acccept error:", err)
|
||||||
}
|
}
|
||||||
|
conn.Close()
|
||||||
|
accepted <- conn
|
||||||
|
}()
|
||||||
|
|
||||||
|
// start the test server
|
||||||
|
connected := make(chan *Peer)
|
||||||
|
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
||||||
|
if conn == nil {
|
||||||
|
t.Error("peer func called with nil conn")
|
||||||
|
}
|
||||||
|
peer := newPeer(conn, nil, dialAddr)
|
||||||
|
connected <- peer
|
||||||
|
return peer
|
||||||
|
})
|
||||||
|
defer close(connected)
|
||||||
|
defer srv.Stop()
|
||||||
|
|
||||||
|
// tell the server to connect.
|
||||||
|
connAddr := newPeerAddr(listener.Addr(), nil)
|
||||||
|
srv.peerConnect <- connAddr
|
||||||
|
|
||||||
|
select {
|
||||||
|
case conn := <-accepted:
|
||||||
|
select {
|
||||||
|
case peer := <-connected:
|
||||||
|
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
|
||||||
|
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||||
|
peer.conn.RemoteAddr(), conn.LocalAddr())
|
||||||
|
}
|
||||||
|
if peer.dialAddr != connAddr {
|
||||||
|
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
|
||||||
|
peer.dialAddr, connAddr)
|
||||||
|
}
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Error("server did not launch peer within one second")
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Error("server did not connect within one second")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerBroadcast(t *testing.T) {
|
func TestServerBroadcast(t *testing.T) {
|
||||||
handlers := make(Handlers)
|
defer testlog(t).detach()
|
||||||
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
var connected sync.WaitGroup
|
||||||
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
|
||||||
network, server := SetupTestServer(handlers)
|
peer := newPeer(c, []Protocol{discard}, dialAddr)
|
||||||
server.Start(true, true)
|
peer.startSubprotocols([]Cap{discard.cap()})
|
||||||
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
connected.Done()
|
||||||
time.Sleep(10 * time.Millisecond)
|
return peer
|
||||||
msg, _ := NewMsg(0)
|
})
|
||||||
server.Broadcast("", msg)
|
defer srv.Stop()
|
||||||
packet := Packet(0, 0)
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
server.Stop()
|
|
||||||
peer1, ok := network.connections["outboundpeer-1"]
|
|
||||||
if !ok {
|
|
||||||
t.Error("not found outbound peer 1")
|
|
||||||
} else {
|
|
||||||
fmt.Printf("out: %v\n", peer1.Out)
|
|
||||||
if len(peer1.Out) != 3 {
|
|
||||||
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
|
||||||
} else {
|
|
||||||
if bytes.Compare(peer1.Out[1], packet) != 0 {
|
|
||||||
t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
peer2, ok := network.connections["inboundpeer-1"]
|
|
||||||
if !ok {
|
|
||||||
t.Error("not found inbound peer 2")
|
|
||||||
} else {
|
|
||||||
fmt.Printf("out: %v\n", peer2.Out)
|
|
||||||
if len(peer1.Out) != 3 {
|
|
||||||
t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
|
|
||||||
} else {
|
|
||||||
if bytes.Compare(peer2.Out[1], packet) != 0 {
|
|
||||||
t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerPeersMessage(t *testing.T) {
|
// dial a bunch of conns
|
||||||
handlers := make(Handlers)
|
var conns = make([]net.Conn, 8)
|
||||||
_, server := SetupTestServer(handlers)
|
connected.Add(len(conns))
|
||||||
server.Start(true, true)
|
deadline := time.Now().Add(3 * time.Second)
|
||||||
defer server.Stop()
|
dialer := &net.Dialer{Deadline: deadline}
|
||||||
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
for i := range conns {
|
||||||
time.Sleep(10 * time.Millisecond)
|
conn, err := dialer.Dial("tcp", srv.ListenAddr)
|
||||||
peersMsg, err := server.PeersMessage()
|
if err != nil {
|
||||||
fmt.Println(peersMsg)
|
t.Fatalf("conn %d: dial error: %v", i, err)
|
||||||
if err != nil {
|
}
|
||||||
t.Errorf("expect no error, got %v", err)
|
defer conn.Close()
|
||||||
|
conn.SetDeadline(deadline)
|
||||||
|
conns[i] = conn
|
||||||
}
|
}
|
||||||
if c := server.PeerCount(); c != 2 {
|
connected.Wait()
|
||||||
t.Errorf("expect 2 peers, got %v", c)
|
|
||||||
|
// broadcast one message
|
||||||
|
srv.Broadcast("discard", 0, "foo")
|
||||||
|
goldbuf := new(bytes.Buffer)
|
||||||
|
writeMsg(goldbuf, NewMsg(16, "foo"))
|
||||||
|
golden := goldbuf.Bytes()
|
||||||
|
|
||||||
|
// check that the message has been written everywhere
|
||||||
|
for i, conn := range conns {
|
||||||
|
buf := make([]byte, len(golden))
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||||
|
t.Errorf("conn %d: read error: %v", i, err)
|
||||||
|
} else if !bytes.Equal(buf, golden) {
|
||||||
|
t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
28
p2p/testlog_test.go
Normal file
28
p2p/testlog_test.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testLogger struct{ t *testing.T }
|
||||||
|
|
||||||
|
func testlog(t *testing.T) testLogger {
|
||||||
|
logger.Reset()
|
||||||
|
l := testLogger{t}
|
||||||
|
logger.AddLogSystem(l)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
|
||||||
|
func (testLogger) SetLogLevel(logger.LogLevel) {}
|
||||||
|
|
||||||
|
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
|
||||||
|
l.t.Logf("%s", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (testLogger) detach() {
|
||||||
|
logger.Flush()
|
||||||
|
logger.Reset()
|
||||||
|
}
|
40
p2p/testpoc7.go
Normal file
40
p2p/testpoc7.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// +build none
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p"
|
||||||
|
"github.com/obscuren/secp256k1-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
|
||||||
|
|
||||||
|
pub, _ := secp256k1.GenerateKeyPair()
|
||||||
|
srv := p2p.Server{
|
||||||
|
MaxPeers: 10,
|
||||||
|
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
|
||||||
|
ListenAddr: ":30303",
|
||||||
|
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
|
||||||
|
}
|
||||||
|
if err := srv.Start(); err != nil {
|
||||||
|
fmt.Println("could not start server:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// add seed peers
|
||||||
|
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("couldn't resolve:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
srv.SuggestPeer(seed.IP, seed.Port, nil)
|
||||||
|
|
||||||
|
select {}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package rlp
|
package rlp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -24,8 +25,9 @@ type Decoder interface {
|
|||||||
DecodeRLP(*Stream) error
|
DecodeRLP(*Stream) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode parses RLP-encoded data from r and stores the result
|
// Decode parses RLP-encoded data from r and stores the result in the
|
||||||
// in the value pointed to by val. Val must be a non-nil pointer.
|
// value pointed to by val. Val must be a non-nil pointer. If r does
|
||||||
|
// not implement ByteReader, Decode will do its own buffering.
|
||||||
//
|
//
|
||||||
// Decode uses the following type-dependent decoding rules:
|
// Decode uses the following type-dependent decoding rules:
|
||||||
//
|
//
|
||||||
@ -66,10 +68,19 @@ type Decoder interface {
|
|||||||
//
|
//
|
||||||
// Non-empty interface types are not supported, nor are bool, float32,
|
// Non-empty interface types are not supported, nor are bool, float32,
|
||||||
// float64, maps, channel types and functions.
|
// float64, maps, channel types and functions.
|
||||||
func Decode(r ByteReader, val interface{}) error {
|
func Decode(r io.Reader, val interface{}) error {
|
||||||
return NewStream(r).Decode(val)
|
return NewStream(r).Decode(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type decodeError struct {
|
||||||
|
msg string
|
||||||
|
typ reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err decodeError) Error() string {
|
||||||
|
return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ)
|
||||||
|
}
|
||||||
|
|
||||||
func makeNumDecoder(typ reflect.Type) decoder {
|
func makeNumDecoder(typ reflect.Type) decoder {
|
||||||
kind := typ.Kind()
|
kind := typ.Kind()
|
||||||
switch {
|
switch {
|
||||||
@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeInt(s *Stream, val reflect.Value) error {
|
func decodeInt(s *Stream, val reflect.Value) error {
|
||||||
num, err := s.uint(val.Type().Bits())
|
typ := val.Type()
|
||||||
if err != nil {
|
num, err := s.uint(typ.Bits())
|
||||||
|
if err == errUintOverflow {
|
||||||
|
return decodeError{"input string too long", typ}
|
||||||
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
val.SetInt(int64(num))
|
val.SetInt(int64(num))
|
||||||
@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodeUint(s *Stream, val reflect.Value) error {
|
func decodeUint(s *Stream, val reflect.Value) error {
|
||||||
num, err := s.uint(val.Type().Bits())
|
typ := val.Type()
|
||||||
if err != nil {
|
num, err := s.uint(typ.Bits())
|
||||||
|
if err == errUintOverflow {
|
||||||
|
return decodeError{"input string too big", typ}
|
||||||
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
val.SetUint(num)
|
val.SetUint(num)
|
||||||
@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro
|
|||||||
i := 0
|
i := 0
|
||||||
for {
|
for {
|
||||||
if i > maxelem {
|
if i > maxelem {
|
||||||
return fmt.Errorf("rlp: input List has more than %d elements", maxelem)
|
return decodeError{"input list has too many elements", val.Type()}
|
||||||
}
|
}
|
||||||
if val.Kind() == reflect.Slice {
|
if val.Kind() == reflect.Slice {
|
||||||
// grow slice if necessary
|
// grow slice if necessary
|
||||||
@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array")
|
|
||||||
|
|
||||||
func decodeByteArray(s *Stream, val reflect.Value) error {
|
func decodeByteArray(s *Stream, val reflect.Value) error {
|
||||||
kind, size, err := s.Kind()
|
kind, size, err := s.Kind()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
|
|||||||
switch kind {
|
switch kind {
|
||||||
case Byte:
|
case Byte:
|
||||||
if val.Len() == 0 {
|
if val.Len() == 0 {
|
||||||
return errStringDoesntFitArray
|
return decodeError{"input string too big", val.Type()}
|
||||||
}
|
}
|
||||||
bv, _ := s.Uint()
|
bv, _ := s.Uint()
|
||||||
val.Index(0).SetUint(bv)
|
val.Index(0).SetUint(bv)
|
||||||
zero(val, 1)
|
zero(val, 1)
|
||||||
case String:
|
case String:
|
||||||
if uint64(val.Len()) < size {
|
if uint64(val.Len()) < size {
|
||||||
return errStringDoesntFitArray
|
return decodeError{"input string too big", val.Type()}
|
||||||
}
|
}
|
||||||
slice := val.Slice(0, int(size)).Interface().([]byte)
|
slice := val.Slice(0, int(size)).Interface().([]byte)
|
||||||
if err := s.readFull(slice); err != nil {
|
if err := s.readFull(slice); err != nil {
|
||||||
@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = s.ListEnd(); err == errNotAtEOL {
|
if err = s.ListEnd(); err == errNotAtEOL {
|
||||||
err = errors.New("rlp: input List has too many elements")
|
err = decodeError{"input list has too many elements", typ}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -432,8 +447,23 @@ type Stream struct {
|
|||||||
|
|
||||||
type listpos struct{ pos, size uint64 }
|
type listpos struct{ pos, size uint64 }
|
||||||
|
|
||||||
func NewStream(r ByteReader) *Stream {
|
// NewStream creates a new stream reading from r.
|
||||||
return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1}
|
// If r does not implement ByteReader, the Stream will
|
||||||
|
// introduce its own buffering.
|
||||||
|
func NewStream(r io.Reader) *Stream {
|
||||||
|
s := new(Stream)
|
||||||
|
s.Reset(r)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewListStream creates a new stream that pretends to be positioned
|
||||||
|
// at an encoded list of the given length.
|
||||||
|
func NewListStream(r io.Reader, len uint64) *Stream {
|
||||||
|
s := new(Stream)
|
||||||
|
s.Reset(r)
|
||||||
|
s.kind = List
|
||||||
|
s.size = len
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bytes reads an RLP string and returns its contents as a byte slice.
|
// Bytes reads an RLP string and returns its contents as a byte slice.
|
||||||
@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errUintOverflow = errors.New("rlp: uint overflow")
|
||||||
|
|
||||||
// Uint reads an RLP string of up to 8 bytes and returns its contents
|
// Uint reads an RLP string of up to 8 bytes and returns its contents
|
||||||
// as an unsigned integer. If the input does not contain an RLP string, the
|
// as an unsigned integer. If the input does not contain an RLP string, the
|
||||||
// returned error will be ErrExpectedString.
|
// returned error will be ErrExpectedString.
|
||||||
@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) {
|
|||||||
return uint64(s.byteval), nil
|
return uint64(s.byteval), nil
|
||||||
case String:
|
case String:
|
||||||
if size > uint64(maxbits/8) {
|
if size > uint64(maxbits/8) {
|
||||||
return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits)
|
return 0, errUintOverflow
|
||||||
}
|
}
|
||||||
return s.readUint(byte(size))
|
return s.readUint(byte(size))
|
||||||
default:
|
default:
|
||||||
@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error {
|
|||||||
return info.decoder(s, rval.Elem())
|
return info.decoder(s, rval.Elem())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset discards any information about the current decoding context
|
||||||
|
// and starts reading from r. If r does not also implement ByteReader,
|
||||||
|
// Stream will do its own buffering.
|
||||||
|
func (s *Stream) Reset(r io.Reader) {
|
||||||
|
bufr, ok := r.(ByteReader)
|
||||||
|
if !ok {
|
||||||
|
bufr = bufio.NewReader(r)
|
||||||
|
}
|
||||||
|
s.r = bufr
|
||||||
|
s.stack = s.stack[:0]
|
||||||
|
s.size = 0
|
||||||
|
s.kind = -1
|
||||||
|
if s.uintbuf == nil {
|
||||||
|
s.uintbuf = make([]byte, 8)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Kind returns the kind and size of the next value in the
|
// Kind returns the kind and size of the next value in the
|
||||||
// input stream.
|
// input stream.
|
||||||
//
|
//
|
||||||
|
@ -3,7 +3,6 @@ package rlp
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
@ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewListStream(t *testing.T) {
|
||||||
|
ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3)
|
||||||
|
if k, size, err := ls.Kind(); k != List || size != 3 || err != nil {
|
||||||
|
t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err)
|
||||||
|
}
|
||||||
|
if size, err := ls.List(); size != 3 || err != nil {
|
||||||
|
t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err)
|
||||||
|
}
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if val, err := ls.Uint(); val != 1 || err != nil {
|
||||||
|
t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := ls.ListEnd(); err != nil {
|
||||||
|
t.Errorf("ListEnd() returned %v, expected (3, nil)", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamErrors(t *testing.T) {
|
func TestStreamErrors(t *testing.T) {
|
||||||
type calls []string
|
type calls []string
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) {
|
|||||||
{"81", calls{"Bytes"}, io.ErrUnexpectedEOF},
|
{"81", calls{"Bytes"}, io.ErrUnexpectedEOF},
|
||||||
{"81", calls{"Uint"}, io.ErrUnexpectedEOF},
|
{"81", calls{"Uint"}, io.ErrUnexpectedEOF},
|
||||||
{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF},
|
{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF},
|
||||||
{"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")},
|
{"89000000000000000001", calls{"Uint"}, errUintOverflow},
|
||||||
{"00", calls{"List"}, ErrExpectedList},
|
{"00", calls{"List"}, ErrExpectedList},
|
||||||
{"80", calls{"List"}, ErrExpectedList},
|
{"80", calls{"List"}, ErrExpectedList},
|
||||||
{"C0", calls{"List", "Uint"}, EOL},
|
{"C0", calls{"List", "Uint"}, EOL},
|
||||||
@ -163,7 +180,7 @@ type decodeTest struct {
|
|||||||
input string
|
input string
|
||||||
ptr interface{}
|
ptr interface{}
|
||||||
value interface{}
|
value interface{}
|
||||||
error error
|
error string
|
||||||
}
|
}
|
||||||
|
|
||||||
type simplestruct struct {
|
type simplestruct struct {
|
||||||
@ -196,8 +213,8 @@ var decodeTests = []decodeTest{
|
|||||||
{input: "820505", ptr: new(uint32), value: uint32(0x0505)},
|
{input: "820505", ptr: new(uint32), value: uint32(0x0505)},
|
||||||
{input: "83050505", ptr: new(uint32), value: uint32(0x050505)},
|
{input: "83050505", ptr: new(uint32), value: uint32(0x050505)},
|
||||||
{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)},
|
{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)},
|
||||||
{input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")},
|
{input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"},
|
||||||
{input: "C0", ptr: new(uint32), error: ErrExpectedString},
|
{input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()},
|
||||||
|
|
||||||
// slices
|
// slices
|
||||||
{input: "C0", ptr: new([]int), value: []int{}},
|
{input: "C0", ptr: new([]int), value: []int{}},
|
||||||
@ -206,7 +223,7 @@ var decodeTests = []decodeTest{
|
|||||||
// arrays
|
// arrays
|
||||||
{input: "C0", ptr: new([5]int), value: [5]int{}},
|
{input: "C0", ptr: new([5]int), value: [5]int{}},
|
||||||
{input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}},
|
{input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}},
|
||||||
{input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")},
|
{input: "C6010203040506", ptr: new([5]int), error: "rlp: input list has too many elements for [5]int"},
|
||||||
|
|
||||||
// byte slices
|
// byte slices
|
||||||
{input: "01", ptr: new([]byte), value: []byte{1}},
|
{input: "01", ptr: new([]byte), value: []byte{1}},
|
||||||
@ -214,7 +231,7 @@ var decodeTests = []decodeTest{
|
|||||||
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
|
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
|
||||||
{input: "C0", ptr: new([]byte), value: []byte{}},
|
{input: "C0", ptr: new([]byte), value: []byte{}},
|
||||||
{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
|
{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
|
||||||
{input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")},
|
{input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"},
|
||||||
|
|
||||||
// byte arrays
|
// byte arrays
|
||||||
{input: "01", ptr: new([5]byte), value: [5]byte{1}},
|
{input: "01", ptr: new([5]byte), value: [5]byte{1}},
|
||||||
@ -222,9 +239,9 @@ var decodeTests = []decodeTest{
|
|||||||
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
|
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
|
||||||
{input: "C0", ptr: new([5]byte), value: [5]byte{}},
|
{input: "C0", ptr: new([5]byte), value: [5]byte{}},
|
||||||
{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
|
{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
|
||||||
{input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")},
|
{input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"},
|
||||||
{input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray},
|
{input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"},
|
||||||
{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF},
|
{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()},
|
||||||
|
|
||||||
// byte array reuse (should be zeroed)
|
// byte array reuse (should be zeroed)
|
||||||
{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
|
{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
|
||||||
@ -237,25 +254,25 @@ var decodeTests = []decodeTest{
|
|||||||
// zero sized byte arrays
|
// zero sized byte arrays
|
||||||
{input: "80", ptr: new([0]byte), value: [0]byte{}},
|
{input: "80", ptr: new([0]byte), value: [0]byte{}},
|
||||||
{input: "C0", ptr: new([0]byte), value: [0]byte{}},
|
{input: "C0", ptr: new([0]byte), value: [0]byte{}},
|
||||||
{input: "01", ptr: new([0]byte), error: errStringDoesntFitArray},
|
{input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"},
|
||||||
{input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray},
|
{input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"},
|
||||||
|
|
||||||
// strings
|
// strings
|
||||||
{input: "00", ptr: new(string), value: "\000"},
|
{input: "00", ptr: new(string), value: "\000"},
|
||||||
{input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"},
|
{input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"},
|
||||||
{input: "C0", ptr: new(string), error: ErrExpectedString},
|
{input: "C0", ptr: new(string), error: ErrExpectedString.Error()},
|
||||||
|
|
||||||
// big ints
|
// big ints
|
||||||
{input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
|
{input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
|
||||||
{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
|
{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
|
||||||
{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
|
{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
|
||||||
{input: "C0", ptr: new(*big.Int), error: ErrExpectedString},
|
{input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()},
|
||||||
|
|
||||||
// structs
|
// structs
|
||||||
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
|
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
|
||||||
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
|
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
|
||||||
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
|
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
|
||||||
{input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")},
|
{input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"},
|
||||||
{
|
{
|
||||||
input: "C501C302C103",
|
input: "C501C302C103",
|
||||||
ptr: new(recstruct),
|
ptr: new(recstruct),
|
||||||
@ -286,20 +303,20 @@ var decodeTests = []decodeTest{
|
|||||||
|
|
||||||
func intp(i int) *int { return &i }
|
func intp(i int) *int { return &i }
|
||||||
|
|
||||||
func TestDecode(t *testing.T) {
|
func runTests(t *testing.T, decode func([]byte, interface{}) error) {
|
||||||
for i, test := range decodeTests {
|
for i, test := range decodeTests {
|
||||||
input, err := hex.DecodeString(test.input)
|
input, err := hex.DecodeString(test.input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("test %d: invalid hex input %q", i, test.input)
|
t.Errorf("test %d: invalid hex input %q", i, test.input)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = Decode(bytes.NewReader(input), test.ptr)
|
err = decode(input, test.ptr)
|
||||||
if err != nil && test.error == nil {
|
if err != nil && test.error == "" {
|
||||||
t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q",
|
t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q",
|
||||||
i, err, test.ptr, test.input)
|
i, err, test.ptr, test.input)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) {
|
if test.error != "" && fmt.Sprint(err) != test.error {
|
||||||
t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q",
|
t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q",
|
||||||
i, err, test.error, test.ptr, test.input)
|
i, err, test.error, test.ptr, test.input)
|
||||||
continue
|
continue
|
||||||
@ -312,6 +329,40 @@ func TestDecode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecodeWithByteReader(t *testing.T) {
|
||||||
|
runTests(t, func(input []byte, into interface{}) error {
|
||||||
|
return Decode(bytes.NewReader(input), into)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// dumbReader reads from a byte slice but does not
|
||||||
|
// implement ReadByte.
|
||||||
|
type dumbReader []byte
|
||||||
|
|
||||||
|
func (r *dumbReader) Read(buf []byte) (n int, err error) {
|
||||||
|
if len(*r) == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n = copy(buf, *r)
|
||||||
|
*r = (*r)[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeWithNonByteReader(t *testing.T) {
|
||||||
|
runTests(t, func(input []byte, into interface{}) error {
|
||||||
|
r := dumbReader(input)
|
||||||
|
return Decode(&r, into)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeStreamReset(t *testing.T) {
|
||||||
|
s := NewStream(nil)
|
||||||
|
runTests(t, func(input []byte, into interface{}) error {
|
||||||
|
s.Reset(bytes.NewReader(input))
|
||||||
|
return s.Decode(into)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type testDecoder struct{ called bool }
|
type testDecoder struct{ called bool }
|
||||||
|
|
||||||
func (t *testDecoder) DecodeRLP(s *Stream) error {
|
func (t *testDecoder) DecodeRLP(s *Stream) error {
|
||||||
|
Loading…
Reference in New Issue
Block a user