initial commit of p2p package
This commit is contained in:
parent
119c5b40a7
commit
771fbcc02e
63
p2p/client_identity.go
Normal file
63
p2p/client_identity.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc.
|
||||||
|
type ClientIdentity interface {
|
||||||
|
String() string
|
||||||
|
Pubkey() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type SimpleClientIdentity struct {
|
||||||
|
clientIdentifier string
|
||||||
|
version string
|
||||||
|
customIdentifier string
|
||||||
|
os string
|
||||||
|
implementation string
|
||||||
|
pubkey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey string) *SimpleClientIdentity {
|
||||||
|
clientIdentity := &SimpleClientIdentity{
|
||||||
|
clientIdentifier: clientIdentifier,
|
||||||
|
version: version,
|
||||||
|
customIdentifier: customIdentifier,
|
||||||
|
os: runtime.GOOS,
|
||||||
|
implementation: runtime.Version(),
|
||||||
|
pubkey: pubkey,
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientIdentity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SimpleClientIdentity) init() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SimpleClientIdentity) String() string {
|
||||||
|
var id string
|
||||||
|
if len(c.customIdentifier) > 0 {
|
||||||
|
id = "/" + c.customIdentifier
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s/v%s%s/%s/%s",
|
||||||
|
c.clientIdentifier,
|
||||||
|
c.version,
|
||||||
|
id,
|
||||||
|
c.os,
|
||||||
|
c.implementation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SimpleClientIdentity) Pubkey() []byte {
|
||||||
|
return []byte(c.pubkey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
|
||||||
|
c.customIdentifier = customIdentifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
|
||||||
|
return c.customIdentifier
|
||||||
|
}
|
30
p2p/client_identity_test.go
Normal file
30
p2p/client_identity_test.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientIdentity(t *testing.T) {
|
||||||
|
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey")
|
||||||
|
clientString := clientIdentity.String()
|
||||||
|
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
|
||||||
|
if clientString != expected {
|
||||||
|
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
|
||||||
|
}
|
||||||
|
customIdentifier := clientIdentity.GetCustomIdentifier()
|
||||||
|
if customIdentifier != "test" {
|
||||||
|
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
|
||||||
|
}
|
||||||
|
clientIdentity.SetCustomIdentifier("test2")
|
||||||
|
customIdentifier = clientIdentity.GetCustomIdentifier()
|
||||||
|
if customIdentifier != "test2" {
|
||||||
|
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
|
||||||
|
}
|
||||||
|
clientString = clientIdentity.String()
|
||||||
|
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
|
||||||
|
if clientString != expected {
|
||||||
|
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
|
||||||
|
}
|
||||||
|
}
|
275
p2p/connection.go
Normal file
275
p2p/connection.go
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
// "fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/eth-go/ethutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Connection struct {
|
||||||
|
conn net.Conn
|
||||||
|
// conn NetworkConnection
|
||||||
|
timeout time.Duration
|
||||||
|
in chan []byte
|
||||||
|
out chan []byte
|
||||||
|
err chan *PeerError
|
||||||
|
closingIn chan chan bool
|
||||||
|
closingOut chan chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// const readBufferLength = 2 //for testing
|
||||||
|
|
||||||
|
const readBufferLength = 1440
|
||||||
|
const partialsQueueSize = 10
|
||||||
|
const maxPendingQueueSize = 1
|
||||||
|
const defaultTimeout = 500
|
||||||
|
|
||||||
|
var magicToken = []byte{34, 64, 8, 145}
|
||||||
|
|
||||||
|
func (self *Connection) Open() {
|
||||||
|
go self.startRead()
|
||||||
|
go self.startWrite()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) Close() {
|
||||||
|
self.closeIn()
|
||||||
|
self.closeOut()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) closeIn() {
|
||||||
|
errc := make(chan bool)
|
||||||
|
self.closingIn <- errc
|
||||||
|
<-errc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) closeOut() {
|
||||||
|
errc := make(chan bool)
|
||||||
|
self.closingOut <- errc
|
||||||
|
<-errc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
|
||||||
|
return &Connection{
|
||||||
|
conn: conn,
|
||||||
|
timeout: defaultTimeout,
|
||||||
|
in: make(chan []byte),
|
||||||
|
out: make(chan []byte),
|
||||||
|
err: errchan,
|
||||||
|
closingIn: make(chan chan bool, 1),
|
||||||
|
closingOut: make(chan chan bool, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) Read() <-chan []byte {
|
||||||
|
return self.in
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) Write() chan<- []byte {
|
||||||
|
return self.out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) Error() <-chan *PeerError {
|
||||||
|
return self.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) startRead() {
|
||||||
|
payloads := make(chan []byte)
|
||||||
|
done := make(chan *PeerError)
|
||||||
|
pending := [][]byte{}
|
||||||
|
var head []byte
|
||||||
|
var wait time.Duration // initally 0 (no delay)
|
||||||
|
read := time.After(wait * time.Millisecond)
|
||||||
|
|
||||||
|
for {
|
||||||
|
// if pending empty, nil channel blocks
|
||||||
|
var in chan []byte
|
||||||
|
if len(pending) > 0 {
|
||||||
|
in = self.in // enable send case
|
||||||
|
head = pending[0]
|
||||||
|
} else {
|
||||||
|
in = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-read:
|
||||||
|
go self.read(payloads, done)
|
||||||
|
case err := <-done:
|
||||||
|
if err == nil { // no error but nothing to read
|
||||||
|
if len(pending) < maxPendingQueueSize {
|
||||||
|
wait = 100
|
||||||
|
} else if wait == 0 {
|
||||||
|
wait = 100
|
||||||
|
} else {
|
||||||
|
wait = 2 * wait
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.err <- err // report error
|
||||||
|
wait = 100
|
||||||
|
}
|
||||||
|
read = time.After(wait * time.Millisecond)
|
||||||
|
case payload := <-payloads:
|
||||||
|
pending = append(pending, payload)
|
||||||
|
if len(pending) < maxPendingQueueSize {
|
||||||
|
wait = 0
|
||||||
|
} else {
|
||||||
|
wait = 100
|
||||||
|
}
|
||||||
|
read = time.After(wait * time.Millisecond)
|
||||||
|
case in <- head:
|
||||||
|
pending = pending[1:]
|
||||||
|
case errc := <-self.closingIn:
|
||||||
|
errc <- true
|
||||||
|
close(self.in)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) startWrite() {
|
||||||
|
pending := [][]byte{}
|
||||||
|
done := make(chan *PeerError)
|
||||||
|
writing := false
|
||||||
|
for {
|
||||||
|
if len(pending) > 0 && !writing {
|
||||||
|
writing = true
|
||||||
|
go self.write(pending[0], done)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case payload := <-self.out:
|
||||||
|
pending = append(pending, payload)
|
||||||
|
case err := <-done:
|
||||||
|
if err == nil {
|
||||||
|
pending = pending[1:]
|
||||||
|
writing = false
|
||||||
|
} else {
|
||||||
|
self.err <- err // report error
|
||||||
|
}
|
||||||
|
case errc := <-self.closingOut:
|
||||||
|
errc <- true
|
||||||
|
close(self.out)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pack(payload []byte) (packet []byte) {
|
||||||
|
length := ethutil.NumberToBytes(uint32(len(payload)), 32)
|
||||||
|
// return error if too long?
|
||||||
|
// Write magic token and payload length (first 8 bytes)
|
||||||
|
packet = append(magicToken, length...)
|
||||||
|
packet = append(packet, payload...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func avoidPanic(done chan *PeerError) {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
err := NewPeerError(MiscError, " %v", rec)
|
||||||
|
logger.Debugln(err)
|
||||||
|
done <- err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) write(payload []byte, done chan *PeerError) {
|
||||||
|
defer avoidPanic(done)
|
||||||
|
var err *PeerError
|
||||||
|
_, ok := self.conn.Write(pack(payload))
|
||||||
|
if ok != nil {
|
||||||
|
err = NewPeerError(WriteError, " %v", ok)
|
||||||
|
logger.Debugln(err)
|
||||||
|
}
|
||||||
|
done <- err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
|
||||||
|
//defer avoidPanic(done)
|
||||||
|
|
||||||
|
partials := make(chan []byte, partialsQueueSize)
|
||||||
|
errc := make(chan *PeerError)
|
||||||
|
go self.readPartials(partials, errc)
|
||||||
|
|
||||||
|
packet := []byte{}
|
||||||
|
length := 8
|
||||||
|
start := true
|
||||||
|
var err *PeerError
|
||||||
|
out:
|
||||||
|
for {
|
||||||
|
// appends partials read via connection until packet is
|
||||||
|
// - either parseable (>=8bytes)
|
||||||
|
// - or complete (payload fully consumed)
|
||||||
|
for len(packet) < length {
|
||||||
|
partial, ok := <-partials
|
||||||
|
if !ok { // partials channel is closed
|
||||||
|
err = <-errc
|
||||||
|
if err == nil && len(packet) > 0 {
|
||||||
|
if start {
|
||||||
|
err = NewPeerError(PacketTooShort, "%v", packet)
|
||||||
|
} else {
|
||||||
|
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break out
|
||||||
|
}
|
||||||
|
packet = append(packet, partial...)
|
||||||
|
}
|
||||||
|
if start {
|
||||||
|
// at least 8 bytes read, can validate packet
|
||||||
|
if bytes.Compare(magicToken, packet[:4]) != 0 {
|
||||||
|
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
length = int(ethutil.BytesToNumber(packet[4:8]))
|
||||||
|
packet = packet[8:]
|
||||||
|
|
||||||
|
if length > 0 {
|
||||||
|
start = false // now consuming payload
|
||||||
|
} else { //penalize peer but read on
|
||||||
|
self.err <- NewPeerError(EmptyPayload, "")
|
||||||
|
length = 8
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// packet complete (payload fully consumed)
|
||||||
|
payloads <- packet[:length]
|
||||||
|
packet = packet[length:] // resclice packet
|
||||||
|
start = true
|
||||||
|
length = 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// this stops partials read via the connection, should we?
|
||||||
|
//if err != nil {
|
||||||
|
// select {
|
||||||
|
// case errc <- err
|
||||||
|
// default:
|
||||||
|
//}
|
||||||
|
done <- err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
|
||||||
|
defer close(partials)
|
||||||
|
for {
|
||||||
|
// Give buffering some time
|
||||||
|
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
|
||||||
|
buffer := make([]byte, readBufferLength)
|
||||||
|
// read partial from connection
|
||||||
|
bytesRead, err := self.conn.Read(buffer)
|
||||||
|
if err == nil || err.Error() == "EOF" {
|
||||||
|
if bytesRead > 0 {
|
||||||
|
partials <- buffer[:bytesRead]
|
||||||
|
}
|
||||||
|
if err != nil && err.Error() == "EOF" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// unexpected error, report to errc
|
||||||
|
err := NewPeerError(ReadError, " %v", err)
|
||||||
|
logger.Debugln(err)
|
||||||
|
errc <- err
|
||||||
|
return // will close partials channel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(errc)
|
||||||
|
}
|
222
p2p/connection_test.go
Normal file
222
p2p/connection_test.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestNetworkConnection struct {
|
||||||
|
in chan []byte
|
||||||
|
current []byte
|
||||||
|
Out [][]byte
|
||||||
|
addr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
|
||||||
|
return &TestNetworkConnection{
|
||||||
|
in: make(chan []byte),
|
||||||
|
current: []byte{},
|
||||||
|
Out: [][]byte{},
|
||||||
|
addr: addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
|
||||||
|
time.Sleep(latency)
|
||||||
|
for _, s := range packets {
|
||||||
|
self.in <- s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
|
||||||
|
if len(self.current) == 0 {
|
||||||
|
select {
|
||||||
|
case self.current = <-self.in:
|
||||||
|
default:
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
length := len(self.current)
|
||||||
|
if length > len(buff) {
|
||||||
|
copy(buff[:], self.current[:len(buff)])
|
||||||
|
self.current = self.current[len(buff):]
|
||||||
|
return len(buff), nil
|
||||||
|
} else {
|
||||||
|
copy(buff[:length], self.current[:])
|
||||||
|
self.current = []byte{}
|
||||||
|
return length, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
|
||||||
|
self.Out = append(self.Out, buff)
|
||||||
|
fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
|
||||||
|
return len(buff), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) Close() (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
|
||||||
|
return self.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupConnection() (*Connection, *TestNetworkConnection) {
|
||||||
|
addr := &TestAddr{"test:30303"}
|
||||||
|
net := NewTestNetworkConnection(addr)
|
||||||
|
conn := NewConnection(net, NewPeerErrorChannel())
|
||||||
|
conn.Open()
|
||||||
|
return conn, net
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadingNilPacket(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
go net.In(0, []byte{})
|
||||||
|
// time.Sleep(10 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case packet := <-conn.Read():
|
||||||
|
t.Errorf("read %v", packet)
|
||||||
|
case err := <-conn.Error():
|
||||||
|
t.Errorf("incorrect error %v", err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadingShortPacket(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
go net.In(0, []byte{0})
|
||||||
|
select {
|
||||||
|
case packet := <-conn.Read():
|
||||||
|
t.Errorf("read %v", packet)
|
||||||
|
case err := <-conn.Error():
|
||||||
|
if err.Code != PacketTooShort {
|
||||||
|
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadingInvalidPacket(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
|
||||||
|
select {
|
||||||
|
case packet := <-conn.Read():
|
||||||
|
t.Errorf("read %v", packet)
|
||||||
|
case err := <-conn.Error():
|
||||||
|
if err.Code != MagicTokenMismatch {
|
||||||
|
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadingInvalidPayload(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
|
||||||
|
select {
|
||||||
|
case packet := <-conn.Read():
|
||||||
|
t.Errorf("read %v", packet)
|
||||||
|
case err := <-conn.Error():
|
||||||
|
if err.Code != PayloadTooShort {
|
||||||
|
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadingEmptyPayload(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case packet := <-conn.Read():
|
||||||
|
t.Errorf("read %v", packet)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-conn.Error():
|
||||||
|
code := err.Code
|
||||||
|
if code != EmptyPayload {
|
||||||
|
t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Errorf("no error, expected EmptyPayload")
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadingCompletePacket(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case packet := <-conn.Read():
|
||||||
|
if bytes.Compare(packet, []byte{1}) != 0 {
|
||||||
|
t.Errorf("incorrect payload read")
|
||||||
|
}
|
||||||
|
case err := <-conn.Error():
|
||||||
|
t.Errorf("incorrect error %v", err)
|
||||||
|
default:
|
||||||
|
t.Errorf("nothing read")
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadingTwoCompletePackets(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case packet := <-conn.Read():
|
||||||
|
if bytes.Compare(packet, []byte{byte(i)}) != 0 {
|
||||||
|
t.Errorf("incorrect payload read")
|
||||||
|
}
|
||||||
|
case err := <-conn.Error():
|
||||||
|
t.Errorf("incorrect error %v", err)
|
||||||
|
default:
|
||||||
|
t.Errorf("nothing read")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriting(t *testing.T) {
|
||||||
|
conn, net := setupConnection()
|
||||||
|
conn.Write() <- []byte{0}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
if len(net.Out) == 0 {
|
||||||
|
t.Errorf("no output")
|
||||||
|
} else {
|
||||||
|
out := net.Out[0]
|
||||||
|
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
|
||||||
|
t.Errorf("incorrect packet %v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
|
75
p2p/message.go
Normal file
75
p2p/message.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
// "fmt"
|
||||||
|
"github.com/ethereum/eth-go/ethutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MsgCode uint8
|
||||||
|
|
||||||
|
type Msg struct {
|
||||||
|
code MsgCode // this is the raw code as per adaptive msg code scheme
|
||||||
|
data *ethutil.Value
|
||||||
|
encoded []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Msg) Code() MsgCode {
|
||||||
|
return self.code
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Msg) Data() *ethutil.Value {
|
||||||
|
return self.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) {
|
||||||
|
|
||||||
|
// // data := [][]interface{}{}
|
||||||
|
// data := []interface{}{}
|
||||||
|
// for _, value := range params {
|
||||||
|
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
|
||||||
|
// data = append(data, encodable.RlpValue())
|
||||||
|
// } else if raw, ok := value.([]interface{}); ok {
|
||||||
|
// data = append(data, raw)
|
||||||
|
// } else {
|
||||||
|
// // data = append(data, interface{}(raw))
|
||||||
|
// err = fmt.Errorf("Unable to encode object of type %T", value)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
return &Msg{
|
||||||
|
code: code,
|
||||||
|
data: ethutil.NewValue(interface{}(params)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) {
|
||||||
|
value := ethutil.NewValueFromBytes(encoded)
|
||||||
|
// Type of message
|
||||||
|
code := value.Get(0).Uint()
|
||||||
|
// Actual data
|
||||||
|
data := value.SliceFrom(1)
|
||||||
|
|
||||||
|
msg = &Msg{
|
||||||
|
code: MsgCode(code),
|
||||||
|
data: data,
|
||||||
|
// data: ethutil.NewValue(data),
|
||||||
|
encoded: encoded,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Msg) Decode(offset MsgCode) {
|
||||||
|
self.code = self.code - offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// encode takes an offset argument to implement adaptive message coding
|
||||||
|
// the encoded message is memoized to make msgs relayed to several peers more efficient
|
||||||
|
func (self *Msg) Encode(offset MsgCode) (res []byte) {
|
||||||
|
if len(self.encoded) == 0 {
|
||||||
|
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
|
||||||
|
self.encoded = res
|
||||||
|
} else {
|
||||||
|
res = self.encoded
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
38
p2p/message_test.go
Normal file
38
p2p/message_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewMsg(t *testing.T) {
|
||||||
|
msg, _ := NewMsg(3, 1, "000")
|
||||||
|
if msg.Code() != 3 {
|
||||||
|
t.Errorf("incorrect code %v", msg.Code())
|
||||||
|
}
|
||||||
|
data0 := msg.Data().Get(0).Uint()
|
||||||
|
data1 := string(msg.Data().Get(1).Bytes())
|
||||||
|
if data0 != 1 {
|
||||||
|
t.Errorf("incorrect data %v", data0)
|
||||||
|
}
|
||||||
|
if data1 != "000" {
|
||||||
|
t.Errorf("incorrect data %v", data1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeMsg(t *testing.T) {
|
||||||
|
msg, _ := NewMsg(3, 1, "000")
|
||||||
|
encoded := msg.Encode(3)
|
||||||
|
msg, _ = NewMsgFromBytes(encoded)
|
||||||
|
msg.Decode(3)
|
||||||
|
if msg.Code() != 3 {
|
||||||
|
t.Errorf("incorrect code %v", msg.Code())
|
||||||
|
}
|
||||||
|
data0 := msg.Data().Get(0).Uint()
|
||||||
|
data1 := msg.Data().Get(1).Str()
|
||||||
|
if data0 != 1 {
|
||||||
|
t.Errorf("incorrect data %v", data0)
|
||||||
|
}
|
||||||
|
if data1 != "000" {
|
||||||
|
t.Errorf("incorrect data %v", data1)
|
||||||
|
}
|
||||||
|
}
|
220
p2p/messenger.go
Normal file
220
p2p/messenger.go
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
handlerTimeout = 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
type Handlers map[string](func(p *Peer) Protocol)
|
||||||
|
|
||||||
|
type Messenger struct {
|
||||||
|
conn *Connection
|
||||||
|
peer *Peer
|
||||||
|
handlers Handlers
|
||||||
|
protocolLock sync.RWMutex
|
||||||
|
protocols []Protocol
|
||||||
|
offsets []MsgCode // offsets for adaptive message idss
|
||||||
|
protocolTable map[string]int
|
||||||
|
quit chan chan bool
|
||||||
|
err chan *PeerError
|
||||||
|
pulse chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger {
|
||||||
|
baseProtocol := NewBaseProtocol(peer)
|
||||||
|
return &Messenger{
|
||||||
|
conn: conn,
|
||||||
|
peer: peer,
|
||||||
|
offsets: []MsgCode{baseProtocol.Offset()},
|
||||||
|
handlers: handlers,
|
||||||
|
protocols: []Protocol{baseProtocol},
|
||||||
|
protocolTable: make(map[string]int),
|
||||||
|
err: errchan,
|
||||||
|
pulse: make(chan bool, 1),
|
||||||
|
quit: make(chan chan bool, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Messenger) Start() {
|
||||||
|
self.conn.Open()
|
||||||
|
go self.messenger()
|
||||||
|
self.protocolLock.RLock()
|
||||||
|
defer self.protocolLock.RUnlock()
|
||||||
|
self.protocols[0].Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Messenger) Stop() {
|
||||||
|
// close pulse to stop ping pong monitoring
|
||||||
|
close(self.pulse)
|
||||||
|
self.protocolLock.RLock()
|
||||||
|
defer self.protocolLock.RUnlock()
|
||||||
|
for _, protocol := range self.protocols {
|
||||||
|
protocol.Stop() // could be parallel
|
||||||
|
}
|
||||||
|
q := make(chan bool)
|
||||||
|
self.quit <- q
|
||||||
|
<-q
|
||||||
|
self.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Messenger) messenger() {
|
||||||
|
in := self.conn.Read()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case payload, ok := <-in:
|
||||||
|
//dispatches message to the protocol asynchronously
|
||||||
|
if ok {
|
||||||
|
go self.handle(payload)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case q := <-self.quit:
|
||||||
|
q <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handles each message by dispatching to the appropriate protocol
|
||||||
|
// using adaptive message codes
|
||||||
|
// this function is started as a separate go routine for each message
|
||||||
|
// it waits for the protocol response
|
||||||
|
// then encodes and sends outgoing messages to the connection's write channel
|
||||||
|
func (self *Messenger) handle(payload []byte) {
|
||||||
|
// send ping to heartbeat channel signalling time of last message
|
||||||
|
// select {
|
||||||
|
// case self.pulse <- true:
|
||||||
|
// default:
|
||||||
|
// }
|
||||||
|
self.pulse <- true
|
||||||
|
// initialise message from payload
|
||||||
|
msg, err := NewMsgFromBytes(payload)
|
||||||
|
if err != nil {
|
||||||
|
self.err <- NewPeerError(MiscError, " %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// retrieves protocol based on message Code
|
||||||
|
protocol, offset, peerErr := self.getProtocol(msg.Code())
|
||||||
|
if err != nil {
|
||||||
|
self.err <- peerErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// reset message code based on adaptive offset
|
||||||
|
msg.Decode(offset)
|
||||||
|
// dispatches
|
||||||
|
response := make(chan *Msg)
|
||||||
|
go protocol.HandleIn(msg, response)
|
||||||
|
// protocol reponse timeout to prevent leaks
|
||||||
|
timer := time.After(handlerTimeout * time.Millisecond)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case outgoing, ok := <-response:
|
||||||
|
// we check if response channel is not closed
|
||||||
|
if ok {
|
||||||
|
self.conn.Write() <- outgoing.Encode(offset)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-timer:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// negotiated protocols
|
||||||
|
// stores offsets needed for adaptive message id scheme
|
||||||
|
|
||||||
|
// based on offsets set at handshake
|
||||||
|
// get the right protocol to handle the message
|
||||||
|
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) {
|
||||||
|
self.protocolLock.RLock()
|
||||||
|
defer self.protocolLock.RUnlock()
|
||||||
|
base := MsgCode(0)
|
||||||
|
for index, offset := range self.offsets {
|
||||||
|
if code < offset {
|
||||||
|
return self.protocols[index], base, nil
|
||||||
|
}
|
||||||
|
base = offset
|
||||||
|
}
|
||||||
|
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) {
|
||||||
|
fmt.Printf("pingpong keepalive started at %v", time.Now())
|
||||||
|
|
||||||
|
timer := time.After(timeout)
|
||||||
|
pinged := false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _, ok := <-self.pulse:
|
||||||
|
if ok {
|
||||||
|
pinged = false
|
||||||
|
timer = time.After(timeout)
|
||||||
|
} else {
|
||||||
|
// pulse is closed, stop monitoring
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-timer:
|
||||||
|
if pinged {
|
||||||
|
fmt.Printf("timeout at %v", time.Now())
|
||||||
|
timeoutCallback()
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
fmt.Printf("pinged at %v", time.Now())
|
||||||
|
pingCallback()
|
||||||
|
timer = time.After(gracePeriod)
|
||||||
|
pinged = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Messenger) AddProtocols(protocols []string) {
|
||||||
|
self.protocolLock.Lock()
|
||||||
|
defer self.protocolLock.Unlock()
|
||||||
|
i := len(self.offsets)
|
||||||
|
offset := self.offsets[i-1]
|
||||||
|
for _, name := range protocols {
|
||||||
|
protocolFunc, ok := self.handlers[name]
|
||||||
|
if ok {
|
||||||
|
protocol := protocolFunc(self.peer)
|
||||||
|
self.protocolTable[name] = i
|
||||||
|
i++
|
||||||
|
offset += protocol.Offset()
|
||||||
|
fmt.Println("offset ", name, offset)
|
||||||
|
|
||||||
|
self.offsets = append(self.offsets, offset)
|
||||||
|
self.protocols = append(self.protocols, protocol)
|
||||||
|
protocol.Start()
|
||||||
|
} else {
|
||||||
|
fmt.Println("no ", name)
|
||||||
|
// protocol not handled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Messenger) Write(protocol string, msg *Msg) error {
|
||||||
|
self.protocolLock.RLock()
|
||||||
|
defer self.protocolLock.RUnlock()
|
||||||
|
i := 0
|
||||||
|
offset := MsgCode(0)
|
||||||
|
if len(protocol) > 0 {
|
||||||
|
var ok bool
|
||||||
|
i, ok = self.protocolTable[protocol]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("protocol %v not handled by peer", protocol)
|
||||||
|
}
|
||||||
|
offset = self.offsets[i-1]
|
||||||
|
}
|
||||||
|
handler := self.protocols[i]
|
||||||
|
// checking if protocol status/caps allows the message to be sent out
|
||||||
|
if handler.HandleOut(msg) {
|
||||||
|
self.conn.Write() <- msg.Encode(offset)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
146
p2p/messenger_test.go
Normal file
146
p2p/messenger_test.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
// "fmt"
|
||||||
|
"bytes"
|
||||||
|
"github.com/ethereum/eth-go/ethutil"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
|
||||||
|
errchan := NewPeerErrorChannel()
|
||||||
|
addr := &TestAddr{"test:30303"}
|
||||||
|
net := NewTestNetworkConnection(addr)
|
||||||
|
conn := NewConnection(net, errchan)
|
||||||
|
mess := NewMessenger(nil, conn, errchan, handlers)
|
||||||
|
mess.Start()
|
||||||
|
return net, errchan, mess
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestProtocol struct {
|
||||||
|
Msgs []*Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestProtocol) Start() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestProtocol) Stop() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestProtocol) Offset() MsgCode {
|
||||||
|
return MsgCode(5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
|
||||||
|
self.Msgs = append(self.Msgs, msg)
|
||||||
|
close(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestProtocol) HandleOut(msg *Msg) bool {
|
||||||
|
if msg.Code() > 3 {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestProtocol) Name() string {
|
||||||
|
return "a"
|
||||||
|
}
|
||||||
|
|
||||||
|
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
|
||||||
|
msg, _ := NewMsg(code, params...)
|
||||||
|
encoded := msg.Encode(offset)
|
||||||
|
packet := []byte{34, 64, 8, 145}
|
||||||
|
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
|
||||||
|
return append(packet, encoded...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRead(t *testing.T) {
|
||||||
|
handlers := make(Handlers)
|
||||||
|
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
||||||
|
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
|
net, _, mess := setupMessenger(handlers)
|
||||||
|
mess.AddProtocols([]string{"a"})
|
||||||
|
defer mess.Stop()
|
||||||
|
wait := 1 * time.Millisecond
|
||||||
|
packet := Packet(16, 1, uint32(1), "000")
|
||||||
|
go net.In(0, packet)
|
||||||
|
time.Sleep(wait)
|
||||||
|
if len(testProtocol.Msgs) != 1 {
|
||||||
|
t.Errorf("msg not relayed to correct protocol")
|
||||||
|
} else {
|
||||||
|
if testProtocol.Msgs[0].Code() != 1 {
|
||||||
|
t.Errorf("incorrect msg code relayed to protocol")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrite(t *testing.T) {
|
||||||
|
handlers := make(Handlers)
|
||||||
|
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
||||||
|
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
|
net, _, mess := setupMessenger(handlers)
|
||||||
|
mess.AddProtocols([]string{"a"})
|
||||||
|
defer mess.Stop()
|
||||||
|
wait := 1 * time.Millisecond
|
||||||
|
msg, _ := NewMsg(3, uint32(1), "000")
|
||||||
|
err := mess.Write("b", msg)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expect error for unknown protocol")
|
||||||
|
}
|
||||||
|
err = mess.Write("a", msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
|
} else {
|
||||||
|
time.Sleep(wait)
|
||||||
|
if len(net.Out) != 1 {
|
||||||
|
t.Errorf("msg not written")
|
||||||
|
} else {
|
||||||
|
out := net.Out[0]
|
||||||
|
packet := Packet(16, 3, uint32(1), "000")
|
||||||
|
if bytes.Compare(out, packet) != 0 {
|
||||||
|
t.Errorf("incorrect packet %v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPulse(t *testing.T) {
|
||||||
|
net, _, mess := setupMessenger(make(Handlers))
|
||||||
|
defer mess.Stop()
|
||||||
|
ping := false
|
||||||
|
timeout := false
|
||||||
|
pingTimeout := 10 * time.Millisecond
|
||||||
|
gracePeriod := 200 * time.Millisecond
|
||||||
|
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
|
||||||
|
net.In(0, Packet(0, 1))
|
||||||
|
if ping {
|
||||||
|
t.Errorf("ping sent too early")
|
||||||
|
}
|
||||||
|
time.Sleep(pingTimeout + 100*time.Millisecond)
|
||||||
|
if !ping {
|
||||||
|
t.Errorf("no ping sent after timeout")
|
||||||
|
}
|
||||||
|
if timeout {
|
||||||
|
t.Errorf("timeout too early")
|
||||||
|
}
|
||||||
|
ping = false
|
||||||
|
net.In(0, Packet(0, 1))
|
||||||
|
time.Sleep(pingTimeout + 100*time.Millisecond)
|
||||||
|
if !ping {
|
||||||
|
t.Errorf("no ping sent after timeout")
|
||||||
|
}
|
||||||
|
if timeout {
|
||||||
|
t.Errorf("timeout too early")
|
||||||
|
}
|
||||||
|
ping = false
|
||||||
|
time.Sleep(gracePeriod)
|
||||||
|
if ping {
|
||||||
|
t.Errorf("ping called twice")
|
||||||
|
}
|
||||||
|
if !timeout {
|
||||||
|
t.Errorf("no timeout after grace period")
|
||||||
|
}
|
||||||
|
}
|
55
p2p/natpmp.go
Normal file
55
p2p/natpmp.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
natpmp "github.com/jackpal/go-nat-pmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Adapt the NAT-PMP protocol to the NAT interface
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// + Register for changes to the external address.
|
||||||
|
// + Re-register port mapping when router reboots.
|
||||||
|
// + A mechanism for keeping a port mapping registered.
|
||||||
|
|
||||||
|
type natPMPClient struct {
|
||||||
|
client *natpmp.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNatPMP(gateway net.IP) (nat NAT) {
|
||||||
|
return &natPMPClient{natpmp.NewClient(gateway)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) {
|
||||||
|
response, err := n.client.GetExternalAddress()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ip := response.ExternalIPAddress
|
||||||
|
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int,
|
||||||
|
description string, timeout int) (mappedExternalPort int, err error) {
|
||||||
|
if timeout <= 0 {
|
||||||
|
err = fmt.Errorf("timeout must not be <= 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
|
||||||
|
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mappedExternalPort = int(response.MappedExternalPort)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
|
||||||
|
// To destroy a mapping, send an add-port with
|
||||||
|
// an internalPort of the internal port to destroy, an external port of zero and a time of zero.
|
||||||
|
_, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
|
||||||
|
return
|
||||||
|
}
|
335
p2p/natupnp.go
Normal file
335
p2p/natupnp.go
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
// Just enough UPnP to be able to forward ports
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/xml"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type upnpNAT struct {
|
||||||
|
serviceURL string
|
||||||
|
ourIP string
|
||||||
|
}
|
||||||
|
|
||||||
|
func upnpDiscover(attempts int) (nat NAT, err error) {
|
||||||
|
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, err := net.ListenPacket("udp4", ":0")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
socket := conn.(*net.UDPConn)
|
||||||
|
defer socket.Close()
|
||||||
|
|
||||||
|
err = socket.SetDeadline(time.Now().Add(10 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
|
||||||
|
buf := bytes.NewBufferString(
|
||||||
|
"M-SEARCH * HTTP/1.1\r\n" +
|
||||||
|
"HOST: 239.255.255.250:1900\r\n" +
|
||||||
|
st +
|
||||||
|
"MAN: \"ssdp:discover\"\r\n" +
|
||||||
|
"MX: 2\r\n\r\n")
|
||||||
|
message := buf.Bytes()
|
||||||
|
answerBytes := make([]byte, 1024)
|
||||||
|
for i := 0; i < attempts; i++ {
|
||||||
|
_, err = socket.WriteToUDP(message, ssdp)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var n int
|
||||||
|
n, _, err = socket.ReadFromUDP(answerBytes)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
// socket.Close()
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
answer := string(answerBytes[0:n])
|
||||||
|
if strings.Index(answer, "\r\n"+st) < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// HTTP header field names are case-insensitive.
|
||||||
|
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
|
||||||
|
locString := "\r\nlocation: "
|
||||||
|
answer = strings.ToLower(answer)
|
||||||
|
locIndex := strings.Index(answer, locString)
|
||||||
|
if locIndex < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
loc := answer[locIndex+len(locString):]
|
||||||
|
endIndex := strings.Index(loc, "\r\n")
|
||||||
|
if endIndex < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
locURL := loc[0:endIndex]
|
||||||
|
var serviceURL string
|
||||||
|
serviceURL, err = getServiceURL(locURL)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var ourIP string
|
||||||
|
ourIP, err = getOurIP()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = errors.New("UPnP port discovery failed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// service represents the Service type in an UPnP xml description.
|
||||||
|
// Only the parts we care about are present and thus the xml may have more
|
||||||
|
// fields than present in the structure.
|
||||||
|
type service struct {
|
||||||
|
ServiceType string `xml:"serviceType"`
|
||||||
|
ControlURL string `xml:"controlURL"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// deviceList represents the deviceList type in an UPnP xml description.
|
||||||
|
// Only the parts we care about are present and thus the xml may have more
|
||||||
|
// fields than present in the structure.
|
||||||
|
type deviceList struct {
|
||||||
|
XMLName xml.Name `xml:"deviceList"`
|
||||||
|
Device []device `xml:"device"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// serviceList represents the serviceList type in an UPnP xml description.
|
||||||
|
// Only the parts we care about are present and thus the xml may have more
|
||||||
|
// fields than present in the structure.
|
||||||
|
type serviceList struct {
|
||||||
|
XMLName xml.Name `xml:"serviceList"`
|
||||||
|
Service []service `xml:"service"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// device represents the device type in an UPnP xml description.
|
||||||
|
// Only the parts we care about are present and thus the xml may have more
|
||||||
|
// fields than present in the structure.
|
||||||
|
type device struct {
|
||||||
|
XMLName xml.Name `xml:"device"`
|
||||||
|
DeviceType string `xml:"deviceType"`
|
||||||
|
DeviceList deviceList `xml:"deviceList"`
|
||||||
|
ServiceList serviceList `xml:"serviceList"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// specVersion represents the specVersion in a UPnP xml description.
|
||||||
|
// Only the parts we care about are present and thus the xml may have more
|
||||||
|
// fields than present in the structure.
|
||||||
|
type specVersion struct {
|
||||||
|
XMLName xml.Name `xml:"specVersion"`
|
||||||
|
Major int `xml:"major"`
|
||||||
|
Minor int `xml:"minor"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// root represents the Root document for a UPnP xml description.
|
||||||
|
// Only the parts we care about are present and thus the xml may have more
|
||||||
|
// fields than present in the structure.
|
||||||
|
type root struct {
|
||||||
|
XMLName xml.Name `xml:"root"`
|
||||||
|
SpecVersion specVersion
|
||||||
|
Device device
|
||||||
|
}
|
||||||
|
|
||||||
|
func getChildDevice(d *device, deviceType string) *device {
|
||||||
|
dl := d.DeviceList.Device
|
||||||
|
for i := 0; i < len(dl); i++ {
|
||||||
|
if dl[i].DeviceType == deviceType {
|
||||||
|
return &dl[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getChildService(d *device, serviceType string) *service {
|
||||||
|
sl := d.ServiceList.Service
|
||||||
|
for i := 0; i < len(sl); i++ {
|
||||||
|
if sl[i].ServiceType == serviceType {
|
||||||
|
return &sl[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOurIP() (ip string, err error) {
|
||||||
|
hostname, err := os.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p, err := net.LookupIP(hostname)
|
||||||
|
if err != nil && len(p) > 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return p[0].String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getServiceURL(rootURL string) (url string, err error) {
|
||||||
|
r, err := http.Get(rootURL)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
if r.StatusCode >= 400 {
|
||||||
|
err = errors.New(string(r.StatusCode))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var root root
|
||||||
|
err = xml.NewDecoder(r.Body).Decode(&root)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a := &root.Device
|
||||||
|
if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
|
||||||
|
err = errors.New("No InternetGatewayDevice")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
|
||||||
|
if b == nil {
|
||||||
|
err = errors.New("No WANDevice")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
|
||||||
|
if c == nil {
|
||||||
|
err = errors.New("No WANConnectionDevice")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
|
||||||
|
if d == nil {
|
||||||
|
err = errors.New("No WANIPConnection")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
url = combineURL(rootURL, d.ControlURL)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func combineURL(rootURL, subURL string) string {
|
||||||
|
protocolEnd := "://"
|
||||||
|
protoEndIndex := strings.Index(rootURL, protocolEnd)
|
||||||
|
a := rootURL[protoEndIndex+len(protocolEnd):]
|
||||||
|
rootIndex := strings.Index(a, "/")
|
||||||
|
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func soapRequest(url, function, message string) (r *http.Response, err error) {
|
||||||
|
fullMessage := "<?xml version=\"1.0\" ?>" +
|
||||||
|
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
|
||||||
|
"<s:Body>" + message + "</s:Body></s:Envelope>"
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
|
||||||
|
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
|
||||||
|
//req.Header.Set("Transfer-Encoding", "chunked")
|
||||||
|
req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
|
||||||
|
req.Header.Set("Connection", "Close")
|
||||||
|
req.Header.Set("Cache-Control", "no-cache")
|
||||||
|
req.Header.Set("Pragma", "no-cache")
|
||||||
|
|
||||||
|
r, err = http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Body != nil {
|
||||||
|
defer r.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode >= 400 {
|
||||||
|
// log.Stderr(function, r.StatusCode)
|
||||||
|
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
|
||||||
|
r = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type statusInfo struct {
|
||||||
|
externalIpAddress string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
|
||||||
|
|
||||||
|
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||||
|
"</u:GetStatusInfo>"
|
||||||
|
|
||||||
|
var response *http.Response
|
||||||
|
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
|
||||||
|
|
||||||
|
response.Body.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
|
||||||
|
info, err := n.getStatusInfo()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr = net.ParseIP(info.externalIpAddress)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
|
||||||
|
// A single concatenation would break ARM compilation.
|
||||||
|
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||||
|
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
|
||||||
|
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
|
||||||
|
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
|
||||||
|
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
|
||||||
|
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
|
||||||
|
message += description +
|
||||||
|
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
|
||||||
|
"</NewLeaseDuration></u:AddPortMapping>"
|
||||||
|
|
||||||
|
var response *http.Response
|
||||||
|
response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: check response to see if the port was forwarded
|
||||||
|
// log.Println(message, response)
|
||||||
|
mappedExternalPort = externalPort
|
||||||
|
_ = response
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
|
||||||
|
|
||||||
|
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||||
|
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
|
||||||
|
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
|
||||||
|
"</u:DeletePortMapping>"
|
||||||
|
|
||||||
|
var response *http.Response
|
||||||
|
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: check response to see if the port was deleted
|
||||||
|
// log.Println(message, response)
|
||||||
|
_ = response
|
||||||
|
return
|
||||||
|
}
|
196
p2p/network.go
Normal file
196
p2p/network.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DialerTimeout = 180 //seconds
|
||||||
|
KeepAlivePeriod = 60 //minutes
|
||||||
|
portMappingUpdateInterval = 900 // seconds = 15 mins
|
||||||
|
upnpDiscoverAttempts = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialer is not an interface in net, so we define one
|
||||||
|
// *net.Dialer conforms to this
|
||||||
|
type Dialer interface {
|
||||||
|
Dial(network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Network interface {
|
||||||
|
Start() error
|
||||||
|
Listener(net.Addr) (net.Listener, error)
|
||||||
|
Dialer(net.Addr) (Dialer, error)
|
||||||
|
NewAddr(string, int) (addr net.Addr, err error)
|
||||||
|
ParseAddr(string) (addr net.Addr, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type NAT interface {
|
||||||
|
GetExternalAddress() (addr net.IP, err error)
|
||||||
|
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
|
||||||
|
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPNetwork struct {
|
||||||
|
nat NAT
|
||||||
|
natType NATType
|
||||||
|
quit chan chan bool
|
||||||
|
ports chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
type NATType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
NONE = iota
|
||||||
|
UPNP
|
||||||
|
PMP
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
portMappingTimeout = 1200 // 20 mins
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
|
||||||
|
return &TCPNetwork{
|
||||||
|
natType: natType,
|
||||||
|
ports: make(chan string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
|
||||||
|
return &net.Dialer{
|
||||||
|
Timeout: DialerTimeout * time.Second,
|
||||||
|
// KeepAlive: KeepAlivePeriod * time.Minute,
|
||||||
|
LocalAddr: addr,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
|
||||||
|
if self.natType == UPNP {
|
||||||
|
_, port, _ := net.SplitHostPort(addr.String())
|
||||||
|
if self.quit == nil {
|
||||||
|
self.quit = make(chan chan bool)
|
||||||
|
go self.updatePortMappings()
|
||||||
|
}
|
||||||
|
self.ports <- port
|
||||||
|
}
|
||||||
|
return net.Listen(addr.Network(), addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) Start() (err error) {
|
||||||
|
switch self.natType {
|
||||||
|
case NONE:
|
||||||
|
case UPNP:
|
||||||
|
nat, uerr := upnpDiscover(upnpDiscoverAttempts)
|
||||||
|
if uerr != nil {
|
||||||
|
err = fmt.Errorf("UPNP failed: ", uerr)
|
||||||
|
} else {
|
||||||
|
self.nat = nat
|
||||||
|
}
|
||||||
|
case PMP:
|
||||||
|
err = fmt.Errorf("PMP not implemented")
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Invalid NAT type: %v", self.natType)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) Stop() {
|
||||||
|
q := make(chan bool)
|
||||||
|
self.quit <- q
|
||||||
|
<-q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
|
||||||
|
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("unable to add port mapping on %v: %v", lport, err)
|
||||||
|
} else {
|
||||||
|
logger.Debugf("succesfully added port mapping on %v", lport)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) updatePortMappings() {
|
||||||
|
timer := time.NewTimer(portMappingUpdateInterval * time.Second)
|
||||||
|
lports := []int{}
|
||||||
|
out:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case port := <-self.ports:
|
||||||
|
int64lport, _ := strconv.ParseInt(port, 10, 16)
|
||||||
|
lport := int(int64lport)
|
||||||
|
if err := self.addPortMapping(lport); err != nil {
|
||||||
|
lports = append(lports, lport)
|
||||||
|
}
|
||||||
|
case <-timer.C:
|
||||||
|
for lport := range lports {
|
||||||
|
if err := self.addPortMapping(lport); err != nil {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case errc := <-self.quit:
|
||||||
|
errc <- true
|
||||||
|
break out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
timer.Stop()
|
||||||
|
for lport := range lports {
|
||||||
|
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
|
||||||
|
logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
|
||||||
|
} else {
|
||||||
|
logger.Debugf("succesfully removed port mapping on %v", lport)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
|
||||||
|
ip, err := self.lookupIP(host)
|
||||||
|
if err == nil {
|
||||||
|
return &net.TCPAddr{
|
||||||
|
IP: ip,
|
||||||
|
Port: port,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
|
||||||
|
host, port, err := net.SplitHostPort(address)
|
||||||
|
if err == nil {
|
||||||
|
iport, _ := strconv.Atoi(port)
|
||||||
|
addr, e := self.NewAddr(host, iport)
|
||||||
|
return addr, e
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
|
||||||
|
if ip = net.ParseIP(host); ip != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var ips []net.IP
|
||||||
|
ips, err = net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnln(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(ips) == 0 {
|
||||||
|
err = fmt.Errorf("No IP addresses available for %v", host)
|
||||||
|
logger.Warnln(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(ips) > 1 {
|
||||||
|
// Pick a random IP address, simulating round-robin DNS.
|
||||||
|
rand.Seed(time.Now().UTC().UnixNano())
|
||||||
|
ip = ips[rand.Intn(len(ips))]
|
||||||
|
} else {
|
||||||
|
ip = ips[0]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
83
p2p/peer.go
Normal file
83
p2p/peer.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
// quit chan chan bool
|
||||||
|
Inbound bool // inbound (via listener) or outbound (via dialout)
|
||||||
|
Address net.Addr
|
||||||
|
Host []byte
|
||||||
|
Port uint16
|
||||||
|
Pubkey []byte
|
||||||
|
Id string
|
||||||
|
Caps []string
|
||||||
|
peerErrorChan chan *PeerError
|
||||||
|
messenger *Messenger
|
||||||
|
peerErrorHandler *PeerErrorHandler
|
||||||
|
server *Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Peer) Messenger() *Messenger {
|
||||||
|
return self.messenger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Peer) PeerErrorChan() chan *PeerError {
|
||||||
|
return self.peerErrorChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Peer) Server() *Server {
|
||||||
|
return self.server
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
|
||||||
|
peerErrorChan := NewPeerErrorChannel()
|
||||||
|
host, port, _ := net.SplitHostPort(address.String())
|
||||||
|
intport, _ := strconv.Atoi(port)
|
||||||
|
peer := &Peer{
|
||||||
|
Inbound: inbound,
|
||||||
|
Address: address,
|
||||||
|
Port: uint16(intport),
|
||||||
|
Host: net.ParseIP(host),
|
||||||
|
peerErrorChan: peerErrorChan,
|
||||||
|
server: server,
|
||||||
|
}
|
||||||
|
connection := NewConnection(conn, peerErrorChan)
|
||||||
|
peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers())
|
||||||
|
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist())
|
||||||
|
return peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Peer) String() string {
|
||||||
|
var kind string
|
||||||
|
if self.Inbound {
|
||||||
|
kind = "inbound"
|
||||||
|
} else {
|
||||||
|
kind = "outbound"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Peer) Write(protocol string, msg *Msg) error {
|
||||||
|
return self.messenger.Write(protocol, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Peer) Start() {
|
||||||
|
self.peerErrorHandler.Start()
|
||||||
|
self.messenger.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Peer) Stop() {
|
||||||
|
self.peerErrorHandler.Stop()
|
||||||
|
self.messenger.Stop()
|
||||||
|
// q := make(chan bool)
|
||||||
|
// self.quit <- q
|
||||||
|
// <-q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) Encode() []interface{} {
|
||||||
|
return []interface{}{p.Host, p.Port, p.Pubkey}
|
||||||
|
}
|
76
p2p/peer_error.go
Normal file
76
p2p/peer_error.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ErrorCode int
|
||||||
|
|
||||||
|
const errorChanCapacity = 10
|
||||||
|
|
||||||
|
const (
|
||||||
|
PacketTooShort = iota
|
||||||
|
PayloadTooShort
|
||||||
|
MagicTokenMismatch
|
||||||
|
EmptyPayload
|
||||||
|
ReadError
|
||||||
|
WriteError
|
||||||
|
MiscError
|
||||||
|
InvalidMsgCode
|
||||||
|
InvalidMsg
|
||||||
|
P2PVersionMismatch
|
||||||
|
PubkeyMissing
|
||||||
|
PubkeyInvalid
|
||||||
|
PubkeyForbidden
|
||||||
|
ProtocolBreach
|
||||||
|
PortMismatch
|
||||||
|
PingTimeout
|
||||||
|
InvalidGenesis
|
||||||
|
InvalidNetworkId
|
||||||
|
InvalidProtocolVersion
|
||||||
|
)
|
||||||
|
|
||||||
|
var errorToString = map[ErrorCode]string{
|
||||||
|
PacketTooShort: "Packet too short",
|
||||||
|
PayloadTooShort: "Payload too short",
|
||||||
|
MagicTokenMismatch: "Magic token mismatch",
|
||||||
|
EmptyPayload: "Empty payload",
|
||||||
|
ReadError: "Read error",
|
||||||
|
WriteError: "Write error",
|
||||||
|
MiscError: "Misc error",
|
||||||
|
InvalidMsgCode: "Invalid message code",
|
||||||
|
InvalidMsg: "Invalid message",
|
||||||
|
P2PVersionMismatch: "P2P Version Mismatch",
|
||||||
|
PubkeyMissing: "Public key missing",
|
||||||
|
PubkeyInvalid: "Public key invalid",
|
||||||
|
PubkeyForbidden: "Public key forbidden",
|
||||||
|
ProtocolBreach: "Protocol Breach",
|
||||||
|
PortMismatch: "Port mismatch",
|
||||||
|
PingTimeout: "Ping timeout",
|
||||||
|
InvalidGenesis: "Invalid genesis block",
|
||||||
|
InvalidNetworkId: "Invalid network id",
|
||||||
|
InvalidProtocolVersion: "Invalid protocol version",
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerError struct {
|
||||||
|
Code ErrorCode
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError {
|
||||||
|
desc, ok := errorToString[code]
|
||||||
|
if !ok {
|
||||||
|
panic("invalid error code")
|
||||||
|
}
|
||||||
|
format = desc + ": " + format
|
||||||
|
message := fmt.Sprintf(format, v...)
|
||||||
|
return &PeerError{code, message}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *PeerError) Error() string {
|
||||||
|
return self.message
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerErrorChannel() chan *PeerError {
|
||||||
|
return make(chan *PeerError, errorChanCapacity)
|
||||||
|
}
|
101
p2p/peer_error_handler.go
Normal file
101
p2p/peer_error_handler.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
severityThreshold = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
type DisconnectRequest struct {
|
||||||
|
addr net.Addr
|
||||||
|
reason DiscReason
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerErrorHandler struct {
|
||||||
|
quit chan chan bool
|
||||||
|
address net.Addr
|
||||||
|
peerDisconnect chan DisconnectRequest
|
||||||
|
severity int
|
||||||
|
peerErrorChan chan *PeerError
|
||||||
|
blacklist Blacklist
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
|
||||||
|
return &PeerErrorHandler{
|
||||||
|
quit: make(chan chan bool),
|
||||||
|
address: address,
|
||||||
|
peerDisconnect: peerDisconnect,
|
||||||
|
peerErrorChan: peerErrorChan,
|
||||||
|
blacklist: blacklist,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *PeerErrorHandler) Start() {
|
||||||
|
go self.listen()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *PeerErrorHandler) Stop() {
|
||||||
|
q := make(chan bool)
|
||||||
|
self.quit <- q
|
||||||
|
<-q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *PeerErrorHandler) listen() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case peerError, ok := <-self.peerErrorChan:
|
||||||
|
if ok {
|
||||||
|
logger.Debugf("error %v\n", peerError)
|
||||||
|
go self.handle(peerError)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case q := <-self.quit:
|
||||||
|
q <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *PeerErrorHandler) handle(peerError *PeerError) {
|
||||||
|
reason := DiscReason(' ')
|
||||||
|
switch peerError.Code {
|
||||||
|
case P2PVersionMismatch:
|
||||||
|
reason = DiscIncompatibleVersion
|
||||||
|
case PubkeyMissing, PubkeyInvalid:
|
||||||
|
reason = DiscInvalidIdentity
|
||||||
|
case PubkeyForbidden:
|
||||||
|
reason = DiscUselessPeer
|
||||||
|
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
|
||||||
|
reason = DiscProtocolError
|
||||||
|
case PingTimeout:
|
||||||
|
reason = DiscReadTimeout
|
||||||
|
case WriteError, MiscError:
|
||||||
|
reason = DiscNetworkError
|
||||||
|
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
|
||||||
|
reason = DiscSubprotocolError
|
||||||
|
default:
|
||||||
|
self.severity += self.getSeverity(peerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.severity >= severityThreshold {
|
||||||
|
reason = DiscSubprotocolError
|
||||||
|
}
|
||||||
|
if reason != DiscReason(' ') {
|
||||||
|
self.peerDisconnect <- DisconnectRequest{
|
||||||
|
addr: self.address,
|
||||||
|
reason: reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
|
||||||
|
switch peerError.Code {
|
||||||
|
case ReadError:
|
||||||
|
return 4 //tolerate 3 :)
|
||||||
|
default:
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
34
p2p/peer_error_handler_test.go
Normal file
34
p2p/peer_error_handler_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
// "fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPeerErrorHandler(t *testing.T) {
|
||||||
|
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
|
||||||
|
peerDisconnect := make(chan DisconnectRequest)
|
||||||
|
peerErrorChan := NewPeerErrorChannel()
|
||||||
|
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
|
||||||
|
peh.Start()
|
||||||
|
defer peh.Stop()
|
||||||
|
for i := 0; i < 11; i++ {
|
||||||
|
select {
|
||||||
|
case <-peerDisconnect:
|
||||||
|
t.Errorf("expected no disconnect request")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
peerErrorChan <- NewPeerError(MiscError, "")
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case request := <-peerDisconnect:
|
||||||
|
if request.addr.String() != address.String() {
|
||||||
|
t.Errorf("incorrect address %v != %v", request.addr, address)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Errorf("expected disconnect request")
|
||||||
|
}
|
||||||
|
}
|
96
p2p/peer_test.go
Normal file
96
p2p/peer_test.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
// "net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPeer(t *testing.T) {
|
||||||
|
handlers := make(Handlers)
|
||||||
|
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
||||||
|
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
|
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
|
addr := &TestAddr{"test:30"}
|
||||||
|
conn := NewTestNetworkConnection(addr)
|
||||||
|
_, server := SetupTestServer(handlers)
|
||||||
|
server.Handshake()
|
||||||
|
peer := NewPeer(conn, addr, true, server)
|
||||||
|
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
|
||||||
|
peer.Start()
|
||||||
|
defer peer.Stop()
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
if len(conn.Out) != 1 {
|
||||||
|
t.Errorf("handshake not sent")
|
||||||
|
} else {
|
||||||
|
out := conn.Out[0]
|
||||||
|
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
|
||||||
|
if bytes.Compare(out, packet) != 0 {
|
||||||
|
t.Errorf("incorrect handshake packet %v != %v", out, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
|
||||||
|
conn.In(0, packet)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
|
||||||
|
if pro.state != handshakeReceived {
|
||||||
|
t.Errorf("handshake not received")
|
||||||
|
}
|
||||||
|
if peer.Port != 30 {
|
||||||
|
t.Errorf("port incorrectly set")
|
||||||
|
}
|
||||||
|
if peer.Id != "peer" {
|
||||||
|
t.Errorf("id incorrectly set")
|
||||||
|
}
|
||||||
|
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||||
|
t.Errorf("pubkey incorrectly set")
|
||||||
|
}
|
||||||
|
fmt.Println(peer.Caps)
|
||||||
|
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
|
||||||
|
t.Errorf("protocols incorrectly set")
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, _ := NewMsg(3)
|
||||||
|
err := peer.Write("aaa", msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
|
} else {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
if len(conn.Out) != 2 {
|
||||||
|
t.Errorf("msg not written")
|
||||||
|
} else {
|
||||||
|
out := conn.Out[1]
|
||||||
|
packet := Packet(16, 3)
|
||||||
|
if bytes.Compare(out, packet) != 0 {
|
||||||
|
t.Errorf("incorrect packet %v != %v", out, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, _ = NewMsg(2)
|
||||||
|
err = peer.Write("ccc", msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
|
} else {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
if len(conn.Out) != 3 {
|
||||||
|
t.Errorf("msg not written")
|
||||||
|
} else {
|
||||||
|
out := conn.Out[2]
|
||||||
|
packet := Packet(21, 2)
|
||||||
|
if bytes.Compare(out, packet) != 0 {
|
||||||
|
t.Errorf("incorrect packet %v != %v", out, packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = peer.Write("bbb", msg)
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expect error for unknown protocol")
|
||||||
|
}
|
||||||
|
}
|
278
p2p/protocol.go
Normal file
278
p2p/protocol.go
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Protocol interface {
|
||||||
|
Start()
|
||||||
|
Stop()
|
||||||
|
HandleIn(*Msg, chan *Msg)
|
||||||
|
HandleOut(*Msg) bool
|
||||||
|
Offset() MsgCode
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
P2PVersion = 0
|
||||||
|
pingTimeout = 2
|
||||||
|
pingGracePeriod = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HandshakeMsg = iota
|
||||||
|
DiscMsg
|
||||||
|
PingMsg
|
||||||
|
PongMsg
|
||||||
|
GetPeersMsg
|
||||||
|
PeersMsg
|
||||||
|
offset = 16
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProtocolState uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
nullState = iota
|
||||||
|
handshakeReceived
|
||||||
|
)
|
||||||
|
|
||||||
|
type DiscReason byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Values are given explicitly instead of by iota because these values are
|
||||||
|
// defined by the wire protocol spec; it is easier for humans to ensure
|
||||||
|
// correctness when values are explicit.
|
||||||
|
DiscRequested = 0x00
|
||||||
|
DiscNetworkError = 0x01
|
||||||
|
DiscProtocolError = 0x02
|
||||||
|
DiscUselessPeer = 0x03
|
||||||
|
DiscTooManyPeers = 0x04
|
||||||
|
DiscAlreadyConnected = 0x05
|
||||||
|
DiscIncompatibleVersion = 0x06
|
||||||
|
DiscInvalidIdentity = 0x07
|
||||||
|
DiscQuitting = 0x08
|
||||||
|
DiscUnexpectedIdentity = 0x09
|
||||||
|
DiscSelf = 0x0a
|
||||||
|
DiscReadTimeout = 0x0b
|
||||||
|
DiscSubprotocolError = 0x10
|
||||||
|
)
|
||||||
|
|
||||||
|
var discReasonToString = map[DiscReason]string{
|
||||||
|
DiscRequested: "Disconnect requested",
|
||||||
|
DiscNetworkError: "Network error",
|
||||||
|
DiscProtocolError: "Breach of protocol",
|
||||||
|
DiscUselessPeer: "Useless peer",
|
||||||
|
DiscTooManyPeers: "Too many peers",
|
||||||
|
DiscAlreadyConnected: "Already connected",
|
||||||
|
DiscIncompatibleVersion: "Incompatible P2P protocol version",
|
||||||
|
DiscInvalidIdentity: "Invalid node identity",
|
||||||
|
DiscQuitting: "Client quitting",
|
||||||
|
DiscUnexpectedIdentity: "Unexpected identity",
|
||||||
|
DiscSelf: "Connected to self",
|
||||||
|
DiscReadTimeout: "Read timeout",
|
||||||
|
DiscSubprotocolError: "Subprotocol error",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DiscReason) String() string {
|
||||||
|
if len(discReasonToString) < int(d) {
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
return discReasonToString[d]
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseProtocol struct {
|
||||||
|
peer *Peer
|
||||||
|
state ProtocolState
|
||||||
|
stateLock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBaseProtocol(peer *Peer) *BaseProtocol {
|
||||||
|
self := &BaseProtocol{
|
||||||
|
peer: peer,
|
||||||
|
}
|
||||||
|
|
||||||
|
return self
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) Start() {
|
||||||
|
if self.peer != nil {
|
||||||
|
self.peer.Write("", self.peer.Server().Handshake())
|
||||||
|
go self.peer.Messenger().PingPong(
|
||||||
|
pingTimeout*time.Second,
|
||||||
|
pingGracePeriod*time.Second,
|
||||||
|
self.Ping,
|
||||||
|
self.Timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) Stop() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) Ping() {
|
||||||
|
msg, _ := NewMsg(PingMsg)
|
||||||
|
self.peer.Write("", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) Timeout() {
|
||||||
|
self.peerError(PingTimeout, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) Name() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) Offset() MsgCode {
|
||||||
|
return offset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) CheckState(state ProtocolState) bool {
|
||||||
|
self.stateLock.RLock()
|
||||||
|
self.stateLock.RUnlock()
|
||||||
|
if self.state != state {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) {
|
||||||
|
if msg.Code() == HandshakeMsg {
|
||||||
|
self.handleHandshake(msg)
|
||||||
|
} else {
|
||||||
|
if !self.CheckState(handshakeReceived) {
|
||||||
|
self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code())
|
||||||
|
close(response)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msg.Code() {
|
||||||
|
case DiscMsg:
|
||||||
|
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint()))
|
||||||
|
self.peer.Server().PeerDisconnect() <- DisconnectRequest{
|
||||||
|
addr: self.peer.Address,
|
||||||
|
reason: DiscRequested,
|
||||||
|
}
|
||||||
|
case PingMsg:
|
||||||
|
out, _ := NewMsg(PongMsg)
|
||||||
|
response <- out
|
||||||
|
case PongMsg:
|
||||||
|
case GetPeersMsg:
|
||||||
|
// Peer asked for list of connected peers
|
||||||
|
if out, err := self.peer.Server().PeersMessage(); err != nil {
|
||||||
|
response <- out
|
||||||
|
}
|
||||||
|
case PeersMsg:
|
||||||
|
self.handlePeers(msg)
|
||||||
|
default:
|
||||||
|
self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) {
|
||||||
|
// somewhat overly paranoid
|
||||||
|
allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) {
|
||||||
|
err := NewPeerError(errorCode, format, v...)
|
||||||
|
logger.Warnln(err)
|
||||||
|
fmt.Println(self.peer, err)
|
||||||
|
if self.peer != nil {
|
||||||
|
self.peer.PeerErrorChan() <- err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) handlePeers(msg *Msg) {
|
||||||
|
it := msg.Data().NewIterator()
|
||||||
|
for it.Next() {
|
||||||
|
ip := net.IP(it.Value().Get(0).Bytes())
|
||||||
|
port := it.Value().Get(1).Uint()
|
||||||
|
address := &net.TCPAddr{IP: ip, Port: int(port)}
|
||||||
|
go self.peer.Server().PeerConnect(address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BaseProtocol) handleHandshake(msg *Msg) {
|
||||||
|
self.stateLock.Lock()
|
||||||
|
defer self.stateLock.Unlock()
|
||||||
|
if self.state != nullState {
|
||||||
|
self.peerError(ProtocolBreach, "extra handshake")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c := msg.Data()
|
||||||
|
|
||||||
|
var (
|
||||||
|
p2pVersion = c.Get(0).Uint()
|
||||||
|
id = c.Get(1).Str()
|
||||||
|
caps = c.Get(2)
|
||||||
|
port = c.Get(3).Uint()
|
||||||
|
pubkey = c.Get(4).Bytes()
|
||||||
|
)
|
||||||
|
fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey)
|
||||||
|
|
||||||
|
// Check correctness of p2p protocol version
|
||||||
|
if p2pVersion != P2PVersion {
|
||||||
|
self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the pub key (validation, uniqueness)
|
||||||
|
if len(pubkey) == 0 {
|
||||||
|
self.peerError(PubkeyMissing, "not supplied in handshake.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pubkey) != 64 {
|
||||||
|
self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Self connect detection
|
||||||
|
if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 {
|
||||||
|
self.peerError(PubkeyForbidden, "not allowed to connect to self")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// register pubkey on server. this also sets the pubkey on the peer (need lock)
|
||||||
|
if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil {
|
||||||
|
self.peerError(PubkeyForbidden, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// check port
|
||||||
|
if self.peer.Inbound {
|
||||||
|
uint16port := uint16(port)
|
||||||
|
if self.peer.Port > 0 && self.peer.Port != uint16port {
|
||||||
|
self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
self.peer.Port = uint16port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
capsIt := caps.NewIterator()
|
||||||
|
for capsIt.Next() {
|
||||||
|
cap := capsIt.Value().Str()
|
||||||
|
self.peer.Caps = append(self.peer.Caps, cap)
|
||||||
|
}
|
||||||
|
sort.Strings(self.peer.Caps)
|
||||||
|
self.peer.Messenger().AddProtocols(self.peer.Caps)
|
||||||
|
|
||||||
|
self.peer.Id = id
|
||||||
|
|
||||||
|
self.state = handshakeReceived
|
||||||
|
|
||||||
|
//p.ethereum.PushPeer(p)
|
||||||
|
// p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
|
||||||
|
return
|
||||||
|
}
|
484
p2p/server.go
Normal file
484
p2p/server.go
Normal file
@ -0,0 +1,484 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/eth-go/ethlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
outboundAddressPoolSize = 10
|
||||||
|
disconnectGracePeriod = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Blacklist interface {
|
||||||
|
Get([]byte) (bool, error)
|
||||||
|
Put([]byte) error
|
||||||
|
Delete([]byte) error
|
||||||
|
Exists(pubkey []byte) (ok bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type BlacklistMap struct {
|
||||||
|
blacklist map[string]bool
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBlacklist() *BlacklistMap {
|
||||||
|
return &BlacklistMap{
|
||||||
|
blacklist: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
|
||||||
|
self.lock.RLock()
|
||||||
|
defer self.lock.RUnlock()
|
||||||
|
v, ok := self.blacklist[string(pubkey)]
|
||||||
|
var err error
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("not found")
|
||||||
|
}
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
|
||||||
|
self.lock.RLock()
|
||||||
|
defer self.lock.RUnlock()
|
||||||
|
_, ok = self.blacklist[string(pubkey)]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BlacklistMap) Put(pubkey []byte) error {
|
||||||
|
self.lock.RLock()
|
||||||
|
defer self.lock.RUnlock()
|
||||||
|
self.blacklist[string(pubkey)] = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *BlacklistMap) Delete(pubkey []byte) error {
|
||||||
|
self.lock.RLock()
|
||||||
|
defer self.lock.RUnlock()
|
||||||
|
delete(self.blacklist, string(pubkey))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
network Network
|
||||||
|
listening bool //needed?
|
||||||
|
dialing bool //needed?
|
||||||
|
closed bool
|
||||||
|
identity ClientIdentity
|
||||||
|
addr net.Addr
|
||||||
|
port uint16
|
||||||
|
protocols []string
|
||||||
|
|
||||||
|
quit chan chan bool
|
||||||
|
peersLock sync.RWMutex
|
||||||
|
|
||||||
|
maxPeers int
|
||||||
|
peers []*Peer
|
||||||
|
peerSlots chan int
|
||||||
|
peersTable map[string]int
|
||||||
|
peersMsg *Msg
|
||||||
|
peerCount int
|
||||||
|
|
||||||
|
peerConnect chan net.Addr
|
||||||
|
peerDisconnect chan DisconnectRequest
|
||||||
|
blacklist Blacklist
|
||||||
|
handlers Handlers
|
||||||
|
}
|
||||||
|
|
||||||
|
var logger = ethlog.NewLogger("P2P")
|
||||||
|
|
||||||
|
func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server {
|
||||||
|
// get alphabetical list of protocol names from handlers map
|
||||||
|
protocols := []string{}
|
||||||
|
for protocol := range handlers {
|
||||||
|
protocols = append(protocols, protocol)
|
||||||
|
}
|
||||||
|
sort.Strings(protocols)
|
||||||
|
|
||||||
|
_, port, _ := net.SplitHostPort(addr.String())
|
||||||
|
intport, _ := strconv.Atoi(port)
|
||||||
|
|
||||||
|
self := &Server{
|
||||||
|
// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
|
||||||
|
network: network,
|
||||||
|
identity: identity,
|
||||||
|
addr: addr,
|
||||||
|
port: uint16(intport),
|
||||||
|
protocols: protocols,
|
||||||
|
|
||||||
|
quit: make(chan chan bool),
|
||||||
|
|
||||||
|
maxPeers: maxPeers,
|
||||||
|
peers: make([]*Peer, maxPeers),
|
||||||
|
peerSlots: make(chan int, maxPeers),
|
||||||
|
peersTable: make(map[string]int),
|
||||||
|
|
||||||
|
peerConnect: make(chan net.Addr, outboundAddressPoolSize),
|
||||||
|
peerDisconnect: make(chan DisconnectRequest),
|
||||||
|
blacklist: blacklist,
|
||||||
|
|
||||||
|
handlers: handlers,
|
||||||
|
}
|
||||||
|
for i := 0; i < maxPeers; i++ {
|
||||||
|
self.peerSlots <- i // fill up with indexes
|
||||||
|
}
|
||||||
|
return self
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) {
|
||||||
|
addr, err = self.network.NewAddr(host, port)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) ParseAddr(address string) (addr net.Addr, err error) {
|
||||||
|
addr, err = self.network.ParseAddr(address)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) ClientIdentity() ClientIdentity {
|
||||||
|
return self.identity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) PeersMessage() (msg *Msg, err error) {
|
||||||
|
// TODO: memoize and reset when peers change
|
||||||
|
self.peersLock.RLock()
|
||||||
|
defer self.peersLock.RUnlock()
|
||||||
|
msg = self.peersMsg
|
||||||
|
if msg == nil {
|
||||||
|
var peerData []interface{}
|
||||||
|
for _, i := range self.peersTable {
|
||||||
|
peer := self.peers[i]
|
||||||
|
peerData = append(peerData, peer.Encode())
|
||||||
|
}
|
||||||
|
if len(peerData) == 0 {
|
||||||
|
err = fmt.Errorf("no peers")
|
||||||
|
} else {
|
||||||
|
msg, err = NewMsg(PeersMsg, peerData...)
|
||||||
|
self.peersMsg = msg //memoize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) Peers() (peers []*Peer) {
|
||||||
|
self.peersLock.RLock()
|
||||||
|
defer self.peersLock.RUnlock()
|
||||||
|
for _, peer := range self.peers {
|
||||||
|
if peer != nil {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) PeerCount() int {
|
||||||
|
self.peersLock.RLock()
|
||||||
|
defer self.peersLock.RUnlock()
|
||||||
|
return self.peerCount
|
||||||
|
}
|
||||||
|
|
||||||
|
var getPeersMsg, _ = NewMsg(GetPeersMsg)
|
||||||
|
|
||||||
|
func (self *Server) PeerConnect(addr net.Addr) {
|
||||||
|
// TODO: should buffer, filter and uniq
|
||||||
|
// send GetPeersMsg if not blocking
|
||||||
|
select {
|
||||||
|
case self.peerConnect <- addr: // not enough peers
|
||||||
|
self.Broadcast("", getPeersMsg)
|
||||||
|
default: // we dont care
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) PeerDisconnect() chan DisconnectRequest {
|
||||||
|
return self.peerDisconnect
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) Blacklist() Blacklist {
|
||||||
|
return self.blacklist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) Handlers() Handlers {
|
||||||
|
return self.handlers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) Broadcast(protocol string, msg *Msg) {
|
||||||
|
self.peersLock.RLock()
|
||||||
|
defer self.peersLock.RUnlock()
|
||||||
|
for _, peer := range self.peers {
|
||||||
|
if peer != nil {
|
||||||
|
peer.Write(protocol, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
func (self *Server) Start(listen bool, dial bool) {
|
||||||
|
self.network.Start()
|
||||||
|
if listen {
|
||||||
|
listener, err := self.network.Listener(self.addr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("Error initializing listener: %v", err)
|
||||||
|
logger.Warnf("Connection listening disabled")
|
||||||
|
self.listening = false
|
||||||
|
} else {
|
||||||
|
self.listening = true
|
||||||
|
logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr())
|
||||||
|
go self.inboundPeerHandler(listener)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dial {
|
||||||
|
dialer, err := self.network.Dialer(self.addr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf("Error initializing dialer: %v", err)
|
||||||
|
logger.Warnf("Connection dialout disabled")
|
||||||
|
self.dialing = false
|
||||||
|
} else {
|
||||||
|
self.dialing = true
|
||||||
|
logger.Infoln("Dial peers watching outbound address pool")
|
||||||
|
go self.outboundPeerHandler(dialer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Infoln("server started")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) Stop() {
|
||||||
|
logger.Infoln("server stopping...")
|
||||||
|
// // quit one loop if dialing
|
||||||
|
if self.dialing {
|
||||||
|
logger.Infoln("stop dialout...")
|
||||||
|
dialq := make(chan bool)
|
||||||
|
self.quit <- dialq
|
||||||
|
<-dialq
|
||||||
|
fmt.Println("quit another")
|
||||||
|
}
|
||||||
|
// quit the other loop if listening
|
||||||
|
if self.listening {
|
||||||
|
logger.Infoln("stop listening...")
|
||||||
|
listenq := make(chan bool)
|
||||||
|
self.quit <- listenq
|
||||||
|
<-listenq
|
||||||
|
fmt.Println("quit one")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("quit waited")
|
||||||
|
|
||||||
|
logger.Infoln("stopping peers...")
|
||||||
|
peers := []net.Addr{}
|
||||||
|
self.peersLock.RLock()
|
||||||
|
self.closed = true
|
||||||
|
for _, peer := range self.peers {
|
||||||
|
if peer != nil {
|
||||||
|
peers = append(peers, peer.Address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.peersLock.RUnlock()
|
||||||
|
for _, address := range peers {
|
||||||
|
go self.removePeer(DisconnectRequest{
|
||||||
|
addr: address,
|
||||||
|
reason: DiscQuitting,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// wait till they actually disconnect
|
||||||
|
// this is checked by draining the peerSlots (slots are released back if a peer is removed)
|
||||||
|
i := 0
|
||||||
|
fmt.Println("draining peers")
|
||||||
|
|
||||||
|
FOR:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case slot := <-self.peerSlots:
|
||||||
|
i++
|
||||||
|
fmt.Printf("%v: found slot %v", i, slot)
|
||||||
|
if i == self.maxPeers {
|
||||||
|
break FOR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Infoln("server stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// main loop for adding connections via listening
|
||||||
|
func (self *Server) inboundPeerHandler(listener net.Listener) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case slot := <-self.peerSlots:
|
||||||
|
go self.connectInboundPeer(listener, slot)
|
||||||
|
case errc := <-self.quit:
|
||||||
|
listener.Close()
|
||||||
|
fmt.Println("quit listenloop")
|
||||||
|
errc <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// main loop for adding outbound peers based on peerConnect address pool
|
||||||
|
// this same loop handles peer disconnect requests as well
|
||||||
|
func (self *Server) outboundPeerHandler(dialer Dialer) {
|
||||||
|
// addressChan initially set to nil (only watches peerConnect if we need more peers)
|
||||||
|
var addressChan chan net.Addr
|
||||||
|
slots := self.peerSlots
|
||||||
|
var slot *int
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case i := <-slots:
|
||||||
|
// we need a peer in slot i, slot reserved
|
||||||
|
slot = &i
|
||||||
|
// now we can watch for candidate peers in the next loop
|
||||||
|
addressChan = self.peerConnect
|
||||||
|
// do not consume more until candidate peer is found
|
||||||
|
slots = nil
|
||||||
|
case address := <-addressChan:
|
||||||
|
// candidate peer found, will dial out asyncronously
|
||||||
|
// if connection fails slot will be released
|
||||||
|
go self.connectOutboundPeer(dialer, address, *slot)
|
||||||
|
// we can watch if more peers needed in the next loop
|
||||||
|
slots = self.peerSlots
|
||||||
|
// until then we dont care about candidate peers
|
||||||
|
addressChan = nil
|
||||||
|
case request := <-self.peerDisconnect:
|
||||||
|
go self.removePeer(request)
|
||||||
|
case errc := <-self.quit:
|
||||||
|
if addressChan != nil && slot != nil {
|
||||||
|
self.peerSlots <- *slot
|
||||||
|
}
|
||||||
|
fmt.Println("quit dialloop")
|
||||||
|
errc <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if peer address already connected
|
||||||
|
func (self *Server) connected(address net.Addr) (err error) {
|
||||||
|
self.peersLock.RLock()
|
||||||
|
defer self.peersLock.RUnlock()
|
||||||
|
// fmt.Printf("address: %v\n", address)
|
||||||
|
slot, found := self.peersTable[address.String()]
|
||||||
|
if found {
|
||||||
|
err = fmt.Errorf("already connected as peer %v (%v)", slot, address)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect to peer via listener.Accept()
|
||||||
|
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
|
||||||
|
var address net.Addr
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err == nil {
|
||||||
|
address = conn.RemoteAddr()
|
||||||
|
err = self.connected(address)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.Debugln(err)
|
||||||
|
self.peerSlots <- slot
|
||||||
|
} else {
|
||||||
|
fmt.Printf("adding %v\n", address)
|
||||||
|
go self.addPeer(conn, address, true, slot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect to peer via dial out
|
||||||
|
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
|
||||||
|
var conn net.Conn
|
||||||
|
err := self.connected(address)
|
||||||
|
if err == nil {
|
||||||
|
conn, err = dialer.Dial(address.Network(), address.String())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.Debugln(err)
|
||||||
|
self.peerSlots <- slot
|
||||||
|
} else {
|
||||||
|
go self.addPeer(conn, address, false, slot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// creates the new peer object and inserts it into its slot
|
||||||
|
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) {
|
||||||
|
self.peersLock.Lock()
|
||||||
|
defer self.peersLock.Unlock()
|
||||||
|
if self.closed {
|
||||||
|
fmt.Println("oopsy, not no longer need peer")
|
||||||
|
conn.Close() //oopsy our bad
|
||||||
|
self.peerSlots <- slot // release slot
|
||||||
|
} else {
|
||||||
|
peer := NewPeer(conn, address, inbound, self)
|
||||||
|
self.peers[slot] = peer
|
||||||
|
self.peersTable[address.String()] = slot
|
||||||
|
self.peerCount++
|
||||||
|
// reset peersmsg
|
||||||
|
self.peersMsg = nil
|
||||||
|
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
|
||||||
|
peer.Start()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
||||||
|
func (self *Server) removePeer(request DisconnectRequest) {
|
||||||
|
self.peersLock.Lock()
|
||||||
|
|
||||||
|
address := request.addr
|
||||||
|
slot := self.peersTable[address.String()]
|
||||||
|
peer := self.peers[slot]
|
||||||
|
fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
|
||||||
|
if peer == nil {
|
||||||
|
logger.Debugf("already removed peer on %v", address)
|
||||||
|
self.peersLock.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// remove from list and index
|
||||||
|
self.peerCount--
|
||||||
|
self.peers[slot] = nil
|
||||||
|
delete(self.peersTable, address.String())
|
||||||
|
// reset peersmsg
|
||||||
|
self.peersMsg = nil
|
||||||
|
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
|
||||||
|
self.peersLock.Unlock()
|
||||||
|
|
||||||
|
// sending disconnect message
|
||||||
|
disconnectMsg, _ := NewMsg(DiscMsg, request.reason)
|
||||||
|
peer.Write("", disconnectMsg)
|
||||||
|
// be nice and wait
|
||||||
|
time.Sleep(disconnectGracePeriod * time.Second)
|
||||||
|
// switch off peer and close connections etc.
|
||||||
|
fmt.Println("stopping peer")
|
||||||
|
peer.Stop()
|
||||||
|
fmt.Println("stopped peer")
|
||||||
|
// release slot to signal need for a new peer, last!
|
||||||
|
self.peerSlots <- slot
|
||||||
|
}
|
||||||
|
|
||||||
|
// fix handshake message to push to peers
|
||||||
|
func (self *Server) Handshake() *Msg {
|
||||||
|
fmt.Println(self.identity.Pubkey()[1:])
|
||||||
|
msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:])
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
|
||||||
|
// Check for blacklisting
|
||||||
|
if self.blacklist.Exists(pubkey) {
|
||||||
|
return fmt.Errorf("blacklisted")
|
||||||
|
}
|
||||||
|
|
||||||
|
self.peersLock.RLock()
|
||||||
|
defer self.peersLock.RUnlock()
|
||||||
|
for _, peer := range self.peers {
|
||||||
|
if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 {
|
||||||
|
return fmt.Errorf("already connected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
candidate.Pubkey = pubkey
|
||||||
|
return nil
|
||||||
|
}
|
208
p2p/server_test.go
Normal file
208
p2p/server_test.go
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestNetwork struct {
|
||||||
|
connections map[string]*TestNetworkConnection
|
||||||
|
dialer Dialer
|
||||||
|
maxinbound int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestNetwork(maxinbound int) *TestNetwork {
|
||||||
|
connections := make(map[string]*TestNetworkConnection)
|
||||||
|
return &TestNetwork{
|
||||||
|
connections: connections,
|
||||||
|
dialer: &TestDialer{connections},
|
||||||
|
maxinbound: maxinbound,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) {
|
||||||
|
return self.dialer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
|
||||||
|
return &TestListener{
|
||||||
|
connections: self.connections,
|
||||||
|
addr: addr,
|
||||||
|
max: self.maxinbound,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetwork) Start() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestAddr struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestAddr) String() string {
|
||||||
|
return self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*TestAddr) Network() string {
|
||||||
|
return "test"
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestDialer struct {
|
||||||
|
connections map[string]*TestNetworkConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
|
||||||
|
address := &TestAddr{addr}
|
||||||
|
tconn := NewTestNetworkConnection(address)
|
||||||
|
self.connections[addr] = tconn
|
||||||
|
conn = net.Conn(tconn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestListener struct {
|
||||||
|
connections map[string]*TestNetworkConnection
|
||||||
|
addr net.Addr
|
||||||
|
max int
|
||||||
|
i int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestListener) Accept() (conn net.Conn, err error) {
|
||||||
|
self.i++
|
||||||
|
if self.i > self.max {
|
||||||
|
err = fmt.Errorf("no more")
|
||||||
|
} else {
|
||||||
|
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
|
||||||
|
tconn := NewTestNetworkConnection(addr)
|
||||||
|
key := tconn.RemoteAddr().String()
|
||||||
|
self.connections[key] = tconn
|
||||||
|
conn = net.Conn(tconn)
|
||||||
|
fmt.Printf("accepted connection from: %v \n", addr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestListener) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (self *TestListener) Addr() net.Addr {
|
||||||
|
return self.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
|
||||||
|
network = NewTestNetwork(1)
|
||||||
|
addr := &TestAddr{"test:30303"}
|
||||||
|
identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
|
||||||
|
maxPeers := 2
|
||||||
|
if handlers == nil {
|
||||||
|
handlers = make(Handlers)
|
||||||
|
}
|
||||||
|
blackist := NewBlacklist()
|
||||||
|
server = New(network, addr, identity, handlers, maxPeers, blackist)
|
||||||
|
fmt.Println(server.identity.Pubkey())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerListener(t *testing.T) {
|
||||||
|
network, server := SetupTestServer(nil)
|
||||||
|
server.Start(true, false)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
server.Stop()
|
||||||
|
peer1, ok := network.connections["inboundpeer-1"]
|
||||||
|
if !ok {
|
||||||
|
t.Error("not found inbound peer 1")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("out: %v\n", peer1.Out)
|
||||||
|
if len(peer1.Out) != 2 {
|
||||||
|
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerDialer(t *testing.T) {
|
||||||
|
network, server := SetupTestServer(nil)
|
||||||
|
server.Start(false, true)
|
||||||
|
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
server.Stop()
|
||||||
|
peer1, ok := network.connections["outboundpeer-1"]
|
||||||
|
if !ok {
|
||||||
|
t.Error("not found outbound peer 1")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("out: %v\n", peer1.Out)
|
||||||
|
if len(peer1.Out) != 2 {
|
||||||
|
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerBroadcast(t *testing.T) {
|
||||||
|
handlers := make(Handlers)
|
||||||
|
testProtocol := &TestProtocol{Msgs: []*Msg{}}
|
||||||
|
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
|
||||||
|
network, server := SetupTestServer(handlers)
|
||||||
|
server.Start(true, true)
|
||||||
|
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
msg, _ := NewMsg(0)
|
||||||
|
server.Broadcast("", msg)
|
||||||
|
packet := Packet(0, 0)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
server.Stop()
|
||||||
|
peer1, ok := network.connections["outboundpeer-1"]
|
||||||
|
if !ok {
|
||||||
|
t.Error("not found outbound peer 1")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("out: %v\n", peer1.Out)
|
||||||
|
if len(peer1.Out) != 3 {
|
||||||
|
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
|
||||||
|
} else {
|
||||||
|
if bytes.Compare(peer1.Out[1], packet) != 0 {
|
||||||
|
t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
peer2, ok := network.connections["inboundpeer-1"]
|
||||||
|
if !ok {
|
||||||
|
t.Error("not found inbound peer 2")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("out: %v\n", peer2.Out)
|
||||||
|
if len(peer1.Out) != 3 {
|
||||||
|
t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
|
||||||
|
} else {
|
||||||
|
if bytes.Compare(peer2.Out[1], packet) != 0 {
|
||||||
|
t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerPeersMessage(t *testing.T) {
|
||||||
|
handlers := make(Handlers)
|
||||||
|
_, server := SetupTestServer(handlers)
|
||||||
|
server.Start(true, true)
|
||||||
|
defer server.Stop()
|
||||||
|
server.peerConnect <- &TestAddr{"outboundpeer-1"}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
peersMsg, err := server.PeersMessage()
|
||||||
|
fmt.Println(peersMsg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expect no error, got %v", err)
|
||||||
|
}
|
||||||
|
if c := server.PeerCount(); c != 2 {
|
||||||
|
t.Errorf("expect 2 peers, got %v", c)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user