p2p: API cleanup and PoC 7 compatibility

Whoa, one more big commit. I didn't manage to untangle the
changes while working towards compatibility.
This commit is contained in:
Felix Lange 2014-11-21 21:48:49 +01:00
parent e4a601c644
commit 59b63caf5e
17 changed files with 1720 additions and 1957 deletions

View File

@ -5,10 +5,10 @@ import (
"runtime" "runtime"
) )
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. // ClientIdentity represents the identity of a peer.
type ClientIdentity interface { type ClientIdentity interface {
String() string String() string // human readable identity
Pubkey() []byte Pubkey() []byte // 512-bit public key
} }
type SimpleClientIdentity struct { type SimpleClientIdentity struct {

View File

@ -11,8 +11,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
) )
type MsgCode uint64
// Msg defines the structure of a p2p message. // Msg defines the structure of a p2p message.
// //
// Note that a Msg can only be sent once since the Payload reader is // Note that a Msg can only be sent once since the Payload reader is
@ -21,13 +19,13 @@ type MsgCode uint64
// structure, encode the payload into a byte array and create a // structure, encode the payload into a byte array and create a
// separate Msg with a bytes.Reader as Payload for each send. // separate Msg with a bytes.Reader as Payload for each send.
type Msg struct { type Msg struct {
Code MsgCode Code uint64
Size uint32 // size of the paylod Size uint32 // size of the paylod
Payload io.Reader Payload io.Reader
} }
// NewMsg creates an RLP-encoded message with the given code. // NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code MsgCode, params ...interface{}) Msg { func NewMsg(code uint64, params ...interface{}) Msg {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
for _, p := range params { for _, p := range params {
buf.Write(ethutil.Encode(p)) buf.Write(ethutil.Encode(p))
@ -63,6 +61,52 @@ func (msg Msg) Discard() error {
return err return err
} }
type MsgReader interface {
ReadMsg() (Msg, error)
}
type MsgWriter interface {
// WriteMsg sends an existing message.
// The Payload reader of the message is consumed.
// Note that messages can be sent only once.
WriteMsg(Msg) error
// EncodeMsg writes an RLP-encoded message with the given
// code and data elements.
EncodeMsg(code uint64, data ...interface{}) error
}
// MsgReadWriter provides reading and writing of encoded messages.
type MsgReadWriter interface {
MsgReader
MsgWriter
}
// MsgLoop reads messages off the given reader and
// calls the handler function for each decoded message until
// it returns an error or the peer connection is closed.
//
// If a message is larger than the given maximum size,
// MsgLoop returns an appropriate error.
func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error {
for {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxsize {
return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
}
value, err := msg.Data()
if err != nil {
return err
}
if err := f(msg.Code, value); err != nil {
return err
}
}
}
var magicToken = []byte{34, 64, 8, 145} var magicToken = []byte{34, 64, 8, 145}
func writeMsg(w io.Writer, msg Msg) error { func writeMsg(w io.Writer, msg Msg) error {
@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) {
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil { if _, err = io.ReadFull(r, start); err != nil {
return msg, NewPeerError(ReadError, "%v", err) return msg, newPeerError(errRead, "%v", err)
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken) return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
} }
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])
@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
} }
// readUint reads an RLP-encoded unsigned integer from r. // readUint reads an RLP-encoded unsigned integer from r.
func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) {
b, err := r.ReadByte() b, err := r.ReadByte()
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
if b < 0x80 { if b < 0x80 {
return MsgCode(b), 1, nil return uint64(b), 1, nil
} else if b < 0x89 { // max length for uint64 is 8 bytes } else if b < 0x89 { // max length for uint64 is 8 bytes
codelen = uint32(b - 0x80) codelen = uint32(b - 0x80)
if codelen == 0 { if codelen == 0 {
@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
return 0, 0, err return 0, 0, err
} }
return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil return binary.BigEndian.Uint64(buf), codelen, nil
} }
return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
} }

View File

@ -1,221 +0,0 @@
package p2p
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"time"
)
type Handlers map[string]Protocol
type proto struct {
in chan Msg
maxcode, offset MsgCode
messenger *messenger
}
func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode {
return NewPeerError(InvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.messenger.writeMsg(msg)
}
func (rw *proto) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF
}
msg.Code -= rw.offset
return msg, nil
}
// eofSignal wraps a reader with eof signaling.
// the eof channel is closed when the wrapped reader
// reaches EOF.
type eofSignal struct {
wrapped io.Reader
eof chan struct{}
}
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
if err != nil {
close(r.eof) // tell messenger that msg has been consumed
}
return n, err
}
// messenger represents a message-oriented peer connection.
// It keeps track of the set of protocols understood
// by the remote peer.
type messenger struct {
peer *Peer
handlers Handlers
// the mutex protects the connection
// so only one protocol can write at a time.
writeMu sync.Mutex
conn net.Conn
bufconn *bufio.ReadWriter
protocolLock sync.RWMutex
protocols map[string]*proto
offsets map[MsgCode]*proto
protoWG sync.WaitGroup
err chan error
pulse chan bool
}
func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
return &messenger{
conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
peer: peer,
handlers: handlers,
protocols: make(map[string]*proto),
err: errchan,
pulse: make(chan bool, 1),
}
}
func (m *messenger) Start() {
m.protocols[""] = m.startProto(0, "", &baseProtocol{})
go m.readLoop()
}
func (m *messenger) Stop() {
m.conn.Close()
m.protoWG.Wait()
}
const (
// maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
)
func (m *messenger) readLoop() {
defer m.closeProtocols()
for {
m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
msg, err := readMsg(m.bufconn)
if err != nil {
m.err <- err
return
}
// send ping to heartbeat channel signalling time of last message
m.pulse <- true
proto, err := m.getProto(msg.Code)
if err != nil {
m.err <- err
return
}
if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload)
if err != nil {
m.err <- err
return
}
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
pr := &eofSignal{msg.Payload, make(chan struct{})}
msg.Payload = pr
proto.in <- msg
<-pr.eof
}
}
}
func (m *messenger) closeProtocols() {
m.protocolLock.RLock()
for _, p := range m.protocols {
close(p.in)
}
m.protocolLock.RUnlock()
}
func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
proto := &proto{
in: make(chan Msg),
offset: offset,
maxcode: impl.Offset(),
messenger: m,
}
m.protoWG.Add(1)
go func() {
if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
logger.Errorf("protocol %q error: %v\n", name, err)
m.err <- err
}
m.protoWG.Done()
}()
return proto
}
// getProto finds the protocol responsible for handling
// the given message code.
func (m *messenger) getProto(code MsgCode) (*proto, error) {
m.protocolLock.RLock()
defer m.protocolLock.RUnlock()
for _, proto := range m.protocols {
if code >= proto.offset && code < proto.offset+proto.maxcode {
return proto, nil
}
}
return nil, NewPeerError(InvalidMsgCode, "%d", code)
}
// setProtocols starts all subprotocols shared with the
// remote peer. the protocols must be sorted alphabetically.
func (m *messenger) setRemoteProtocols(protocols []string) {
m.protocolLock.Lock()
defer m.protocolLock.Unlock()
offset := baseProtocolOffset
for _, name := range protocols {
inst, ok := m.handlers[name]
if !ok {
continue // not handled
}
m.protocols[name] = m.startProto(offset, name, inst)
offset += inst.Offset()
}
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
m.protocolLock.RLock()
proto, ok := m.protocols[protoName]
m.protocolLock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
if msg.Code >= proto.maxcode {
return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return m.writeMsg(msg)
}
// writeMsg writes a message to the connection.
func (m *messenger) writeMsg(msg Msg) error {
m.writeMu.Lock()
defer m.writeMu.Unlock()
if err := writeMsg(m.bufconn, msg); err != nil {
return err
}
return m.bufconn.Flush()
}

View File

@ -1,203 +0,0 @@
package p2p
import (
"bufio"
"fmt"
"io"
"log"
"net"
"os"
"reflect"
"testing"
"time"
logpkg "github.com/ethereum/go-ethereum/logger"
)
func init() {
logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
}
func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
return conn2, peer, peer.messenger
}
func performTestHandshake(r *bufio.Reader, w io.Writer) error {
// read remote handshake
msg, err := readMsg(r)
if err != nil {
return fmt.Errorf("read error: %v", err)
}
if msg.Code != handshakeMsg {
return fmt.Errorf("first message should be handshake, got %d", msg.Code)
}
if err := msg.Discard(); err != nil {
return err
}
// send empty handshake
pubkey := make([]byte, 64)
msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
return writeMsg(w, msg)
}
type testProtocol struct {
offset MsgCode
f func(MsgReadWriter)
}
func (p *testProtocol) Offset() MsgCode {
return p.offset
}
func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
p.f(rw)
return nil
}
func TestRead(t *testing.T) {
done := make(chan struct{})
handlers := Handlers{
"a": &testProtocol{5, func(rw MsgReadWriter) {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
}
data, err := msg.Data()
if err != nil {
t.Errorf("data decoding error: %v", err)
}
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", data.Slice())
}
close(done)
}},
}
net, peer, m := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
m.setRemoteProtocols([]string{"a"})
writeMsg(net, NewMsg(18, 1, "000"))
select {
case <-done:
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestWriteFromProto(t *testing.T) {
handlers := Handlers{
"a": &testProtocol{2, func(rw MsgReadWriter) {
if err := rw.WriteMsg(NewMsg(2)); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
}
if err := rw.WriteMsg(NewMsg(1)); err != nil {
t.Errorf("write error: %v", err)
}
}},
}
net, peer, mess := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v")
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
for {
msg, err := rw.ReadMsg()
if err != nil {
return
}
if err = msg.Discard(); err != nil {
return
}
}
}}
func TestMessengerWriteProtoMsg(t *testing.T) {
handlers := Handlers{"a": discardProto}
net, peer, mess := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
// test write errors
if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v")
}
// test succcessful write
read, readerr := make(chan Msg), make(chan error)
go func() {
if msg, err := readMsg(bufr); err != nil {
readerr <- err
} else {
read <- msg
}
}()
if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
case msg := <-read:
if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
case err := <-readerr:
t.Errorf("read error: %v", err)
}
}
func TestPulse(t *testing.T) {
net, peer, _ := testMessenger(nil)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
before := time.Now()
msg, err := readMsg(bufr)
if err != nil {
t.Fatalf("read error: %v", err)
}
after := time.Now()
if msg.Code != pingMsg {
t.Errorf("expected ping message, got %d", msg.Code)
}
if d := after.Sub(before); d < pingTimeout {
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
}
}

View File

@ -3,6 +3,7 @@ package p2p
import ( import (
"fmt" "fmt"
"net" "net"
"time"
natpmp "github.com/jackpal/go-nat-pmp" natpmp "github.com/jackpal/go-nat-pmp"
) )
@ -13,38 +14,37 @@ import (
// + Register for changes to the external address. // + Register for changes to the external address.
// + Re-register port mapping when router reboots. // + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered. // + A mechanism for keeping a port mapping registered.
// + Discover gateway address automatically.
type natPMPClient struct { type natPMPClient struct {
client *natpmp.Client client *natpmp.Client
} }
func NewNatPMP(gateway net.IP) (nat NAT) { // PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
// address should be the IP of your router.
func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)} return &natPMPClient{natpmp.NewClient(gateway)}
} }
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { func (*natPMPClient) String() string {
response, err := n.client.GetExternalAddress() return "NAT-PMP"
if err != nil {
return
}
ip := response.ExternalIPAddress
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
return
} }
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
description string, timeout int) (mappedExternalPort int, err error) { response, err := n.client.GetExternalAddress()
if timeout <= 0 { if err != nil {
err = fmt.Errorf("timeout must not be <= 0") return nil, err
return }
return response.ExternalIPAddress[:], nil
}
func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
if lifetime <= 0 {
return fmt.Errorf("lifetime must not be <= 0")
} }
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
if err != nil { return err
return
}
mappedExternalPort = int(response.MappedExternalPort)
return
} }
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {

View File

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -15,28 +16,46 @@ import (
"time" "time"
) )
const (
upnpDiscoverAttempts = 3
upnpDiscoverTimeout = 5 * time.Second
)
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPNP() NAT {
return &upnpNAT{}
}
type upnpNAT struct { type upnpNAT struct {
serviceURL string serviceURL string
ourIP string ourIP string
} }
func upnpDiscover(attempts int) (nat NAT, err error) { func (n *upnpNAT) String() string {
return "UPNP"
}
func (n *upnpNAT) discover() error {
if n.serviceURL != "" {
// already discovered
return nil
}
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil { if err != nil {
return return err
} }
// TODO: try on all network interfaces simultaneously.
// Broadcasting on 0.0.0.0 could select a random interface
// to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0") conn, err := net.ListenPacket("udp4", ":0")
if err != nil { if err != nil {
return return err
}
socket := conn.(*net.UDPConn)
defer socket.Close()
err = socket.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
return
} }
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString( buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" + "M-SEARCH * HTTP/1.1\r\n" +
@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
"MX: 2\r\n\r\n") "MX: 2\r\n\r\n")
message := buf.Bytes() message := buf.Bytes()
answerBytes := make([]byte, 1024) answerBytes := make([]byte, 1024)
for i := 0; i < attempts; i++ { for i := 0; i < upnpDiscoverAttempts; i++ {
_, err = socket.WriteToUDP(message, ssdp) _, err = conn.WriteTo(message, ssdp)
if err != nil { if err != nil {
return return err
} }
var n int nn, _, err := conn.ReadFrom(answerBytes)
n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil { if err != nil {
continue continue
// socket.Close()
// return
} }
answer := string(answerBytes[0:n]) answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 { if strings.Index(answer, "\r\n"+st) < 0 {
continue continue
} }
@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
var serviceURL string var serviceURL string
serviceURL, err = getServiceURL(locURL) serviceURL, err = getServiceURL(locURL)
if err != nil { if err != nil {
return return err
} }
var ourIP string var ourIP string
ourIP, err = getOurIP() ourIP, err = getOurIP()
if err != nil { if err != nil {
return return err
} }
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} n.serviceURL = serviceURL
n.ourIP = ourIP
return nil
}
return errors.New("UPnP port discovery failed.")
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
if err := n.discover(); err != nil {
return nil, err
}
info, err := n.getStatusInfo()
return net.ParseIP(info.externalIpAddress), err
}
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
if err := n.discover(); err != nil {
return err
}
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
"</NewLeaseDuration></u:AddPortMapping>"
// TODO: check response to see if the port was forwarded
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
return err
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
if err := n.discover(); err != nil {
return err
}
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
// TODO: check response to see if the port was deleted
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
return err
}
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return return
} }
err = errors.New("UPnP port discovery failed.")
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return return
} }
@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
} }
return return
} }
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return
}
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
info, err := n.getStatusInfo()
if err != nil {
return
}
addr = net.ParseIP(info.externalIpAddress)
return
}
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was forwarded
// log.Println(message, response)
mappedExternalPort = externalPort
_ = response
return
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was deleted
// log.Println(message, response)
_ = response
return
}

View File

@ -1,196 +0,0 @@
package p2p
import (
"fmt"
"math/rand"
"net"
"strconv"
"time"
)
const (
DialerTimeout = 180 //seconds
KeepAlivePeriod = 60 //minutes
portMappingUpdateInterval = 900 // seconds = 15 mins
upnpDiscoverAttempts = 3
)
// Dialer is not an interface in net, so we define one
// *net.Dialer conforms to this
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
type Network interface {
Start() error
Listener(net.Addr) (net.Listener, error)
Dialer(net.Addr) (Dialer, error)
NewAddr(string, int) (addr net.Addr, err error)
ParseAddr(string) (addr net.Addr, err error)
}
type NAT interface {
GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
}
type TCPNetwork struct {
nat NAT
natType NATType
quit chan chan bool
ports chan string
}
type NATType int
const (
NONE = iota
UPNP
PMP
)
const (
portMappingTimeout = 1200 // 20 mins
)
func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
return &TCPNetwork{
natType: natType,
ports: make(chan string),
}
}
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
return &net.Dialer{
Timeout: DialerTimeout * time.Second,
// KeepAlive: KeepAlivePeriod * time.Minute,
LocalAddr: addr,
}, nil
}
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
if self.natType == UPNP {
_, port, _ := net.SplitHostPort(addr.String())
if self.quit == nil {
self.quit = make(chan chan bool)
go self.updatePortMappings()
}
self.ports <- port
}
return net.Listen(addr.Network(), addr.String())
}
func (self *TCPNetwork) Start() (err error) {
switch self.natType {
case NONE:
case UPNP:
nat, uerr := upnpDiscover(upnpDiscoverAttempts)
if uerr != nil {
err = fmt.Errorf("UPNP failed: ", uerr)
} else {
self.nat = nat
}
case PMP:
err = fmt.Errorf("PMP not implemented")
default:
err = fmt.Errorf("Invalid NAT type: %v", self.natType)
}
return
}
func (self *TCPNetwork) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
if err != nil {
logger.Errorf("unable to add port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully added port mapping on %v", lport)
}
return
}
func (self *TCPNetwork) updatePortMappings() {
timer := time.NewTimer(portMappingUpdateInterval * time.Second)
lports := []int{}
out:
for {
select {
case port := <-self.ports:
int64lport, _ := strconv.ParseInt(port, 10, 16)
lport := int(int64lport)
if err := self.addPortMapping(lport); err != nil {
lports = append(lports, lport)
}
case <-timer.C:
for lport := range lports {
if err := self.addPortMapping(lport); err != nil {
}
}
case errc := <-self.quit:
errc <- true
break out
}
}
timer.Stop()
for lport := range lports {
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully removed port mapping on %v", lport)
}
}
}
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
ip, err := self.lookupIP(host)
if err == nil {
return &net.TCPAddr{
IP: ip,
Port: port,
}, nil
}
return nil, err
}
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
host, port, err := net.SplitHostPort(address)
if err == nil {
iport, _ := strconv.Atoi(port)
addr, e := self.NewAddr(host, iport)
return addr, e
}
return nil, err
}
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
if ip = net.ParseIP(host); ip != nil {
return
}
var ips []net.IP
ips, err = net.LookupIP(host)
if err != nil {
logger.Warnln(err)
return
}
if len(ips) == 0 {
err = fmt.Errorf("No IP addresses available for %v", host)
logger.Warnln(err)
return
}
if len(ips) > 1 {
// Pick a random IP address, simulating round-robin DNS.
rand.Seed(time.Now().UTC().UnixNano())
ip = ips[rand.Intn(len(ips))]
} else {
ip = ips[0]
}
return
}

View File

@ -1,66 +1,454 @@
package p2p package p2p
import ( import (
"bufio"
"bytes"
"fmt" "fmt"
"io"
"io/ioutil"
"net" "net"
"strconv" "sort"
"sync"
"time"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger"
) )
type Peer struct { // peerAddr is the structure of a peer list element.
Inbound bool // inbound (via listener) or outbound (via dialout) // It is also a valid net.Addr.
Address net.Addr type peerAddr struct {
Host []byte IP net.IP
Port uint16 Port uint64
Pubkey []byte Pubkey []byte // optional
Id string
Caps []string
peerErrorChan chan error
messenger *messenger
peerErrorHandler *PeerErrorHandler
server *Server
} }
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
peerErrorChan := NewPeerErrorChannel() n := addr.Network()
host, port, _ := net.SplitHostPort(address.String()) if n != "tcp" && n != "tcp4" && n != "tcp6" {
intport, _ := strconv.Atoi(port) // for testing with non-TCP
peer := &Peer{ return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
Inbound: inbound,
Address: address,
Port: uint16(intport),
Host: net.ParseIP(host),
peerErrorChan: peerErrorChan,
server: server,
} }
peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers()) ta := addr.(*net.TCPAddr)
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan) return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
}
func (d peerAddr) Network() string {
if d.IP.To4() != nil {
return "tcp4"
} else {
return "tcp6"
}
}
func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port)
}
func (d peerAddr) RlpData() interface{} {
return []interface{}{d.IP, d.Port, d.Pubkey}
}
// Peer represents a remote peer.
type Peer struct {
// Peers have all the log methods.
// Use them to display messages related to the peer.
*logger.Logger
infolock sync.Mutex
identity ClientIdentity
caps []Cap
listenAddr *peerAddr // what remote peer is listening on
dialAddr *peerAddr // non-nil if dialing
// The mutex protects the connection
// so only one protocol can write at a time.
writeMu sync.Mutex
conn net.Conn
bufconn *bufio.ReadWriter
// These fields maintain the running protocols.
protocols []Protocol
runBaseProtocol bool // for testing
runlock sync.RWMutex // protects running
running map[string]*proto
protoWG sync.WaitGroup
protoErr chan error
closed chan struct{}
disc chan DiscReason
activity event.TypeMux // for activity events
slot int // index into Server peer list
// These fields are kept so base protocol can access them.
// TODO: this should be one or more interfaces
ourID ClientIdentity // client id of the Server
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
newPeerAddr chan<- *peerAddr // tell server about received peers
otherPeers func() []*Peer // should return the list of all peers
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
}
// NewPeer returns a peer for testing purposes.
func NewPeer(id ClientIdentity, caps []Cap) *Peer {
conn, _ := net.Pipe()
peer := newPeer(conn, nil, nil)
peer.setHandshakeInfo(id, nil, caps)
return peer return peer
} }
func (self *Peer) String() string { func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
var kind string p := newPeer(conn, server.Protocols, dialAddr)
if self.Inbound { p.ourID = server.Identity
kind = "inbound" p.newPeerAddr = server.peerConnect
} else { p.otherPeers = server.Peers
p.pubkeyHook = server.verifyPeer
p.runBaseProtocol = true
// laddr can be updated concurrently by NAT traversal.
// newServerPeer must be called with the server lock held.
if server.laddr != nil {
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
}
return p
}
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
p := &Peer{
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
conn: conn,
dialAddr: dialAddr,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
return p
}
// Identity returns the client identity of the remote peer. The
// identity can be nil if the peer has not yet completed the
// handshake.
func (p *Peer) Identity() ClientIdentity {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.identity
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.caps
}
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
p.infolock.Lock()
p.identity = id
p.listenAddr = laddr
p.caps = caps
p.infolock.Unlock()
}
// RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr()
}
// LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// Disconnect terminates the peer connection with the given reason.
// It returns immediately and does not wait until the connection is closed.
func (p *Peer) Disconnect(reason DiscReason) {
select {
case p.disc <- reason:
case <-p.closed:
}
}
// String implements fmt.Stringer.
func (p *Peer) String() string {
kind := "inbound"
p.infolock.Lock()
if p.dialAddr != nil {
kind = "outbound" kind = "outbound"
} }
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) p.infolock.Unlock()
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
} }
func (self *Peer) Write(protocol string, msg Msg) error { const (
return self.messenger.writeProtoMsg(protocol, msg) // maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// maximum amount of time allowed for writing a message
msgWriteTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
)
var (
inactivityTimeout = 2 * time.Second
disconnectGracePeriod = 2 * time.Second
)
func (p *Peer) loop() (reason DiscReason, err error) {
defer p.activity.Stop()
defer p.closeProtocols()
defer close(p.closed)
defer p.conn.Close()
// read loop
readMsg := make(chan Msg)
readErr := make(chan error)
readNext := make(chan bool, 1)
protoDone := make(chan struct{}, 1)
go p.readLoop(readMsg, readErr, readNext)
readNext <- true
if p.runBaseProtocol {
p.startBaseProtocol()
}
loop:
for {
select {
case msg := <-readMsg:
// a new message has arrived.
var wait bool
if wait, err = p.dispatch(msg, protoDone); err != nil {
p.Errorf("msg dispatch error: %v\n", err)
reason = discReasonForError(err)
break loop
}
if !wait {
// Msg has already been read completely, continue with next message.
readNext <- true
}
p.activity.Post(time.Now())
case <-protoDone:
// protocol has consumed the message payload,
// we can continue reading from the socket.
readNext <- true
case err := <-readErr:
// read failed. there is no need to run the
// polite disconnect sequence because the connection
// is probably dead anyway.
// TODO: handle write errors as well
return DiscNetworkError, err
case err = <-p.protoErr:
reason = discReasonForError(err)
break loop
case reason = <-p.disc:
break loop
}
}
// wait for read loop to return.
close(readNext)
<-readErr
// tell the remote end to disconnect
done := make(chan struct{})
go func() {
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
io.Copy(ioutil.Discard, p.conn)
close(done)
}()
select {
case <-done:
case <-time.After(disconnectGracePeriod):
}
return reason, err
} }
func (self *Peer) Start() { func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
self.peerErrorHandler.Start() for _ = range unblock {
self.messenger.Start() p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
if msg, err := readMsg(p.bufconn); err != nil {
errc <- err
} else {
msgc <- msg
}
}
close(errc)
} }
func (self *Peer) Stop() { func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
self.peerErrorHandler.Stop() proto, err := p.getProto(msg.Code)
self.messenger.Stop() if err != nil {
return false, err
}
if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return false, err
}
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
wait = true
pr := &eofSignal{msg.Payload, protoDone}
msg.Payload = pr
proto.in <- msg
}
return wait, nil
} }
func (p *Peer) Encode() []interface{} { func (p *Peer) startBaseProtocol() {
return []interface{}{p.Host, p.Port, p.Pubkey} p.runlock.Lock()
defer p.runlock.Unlock()
p.running[""] = p.startProto(0, Protocol{
Length: baseProtocolLength,
Run: runBaseProtocol,
})
}
// startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) {
sort.Sort(capsByName(caps))
p.runlock.Lock()
defer p.runlock.Unlock()
offset := baseProtocolLength
outer:
for _, cap := range caps {
for _, proto := range p.protocols {
if proto.Name == cap.Name &&
proto.Version == cap.Version &&
p.running[cap.Name] == nil {
p.running[cap.Name] = p.startProto(offset, proto)
offset += proto.Length
continue outer
}
}
}
}
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
rw := &proto{
in: make(chan Msg),
offset: offset,
maxcode: impl.Length,
peer: p,
}
p.protoWG.Add(1)
go func() {
err := impl.Run(p, rw)
if err == nil {
p.Infof("protocol %q returned", impl.Name)
err = newPeerError(errMisc, "protocol returned")
} else {
p.Errorf("protocol %q error: %v\n", impl.Name, err)
}
select {
case p.protoErr <- err:
case <-p.closed:
}
p.protoWG.Done()
}()
return rw
}
// getProto finds the protocol responsible for handling
// the given message code.
func (p *Peer) getProto(code uint64) (*proto, error) {
p.runlock.RLock()
defer p.runlock.RUnlock()
for _, proto := range p.running {
if code >= proto.offset && code < proto.offset+proto.maxcode {
return proto, nil
}
}
return nil, newPeerError(errInvalidMsgCode, "%d", code)
}
func (p *Peer) closeProtocols() {
p.runlock.RLock()
for _, p := range p.running {
close(p.in)
}
p.runlock.RUnlock()
p.protoWG.Wait()
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
p.runlock.RLock()
proto, ok := p.running[protoName]
p.runlock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
if msg.Code >= proto.maxcode {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return p.writeMsg(msg, msgWriteTimeout)
}
// writeMsg writes a message to the connection.
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
p.writeMu.Lock()
defer p.writeMu.Unlock()
p.conn.SetWriteDeadline(time.Now().Add(timeout))
if err := writeMsg(p.bufconn, msg); err != nil {
return newPeerError(errWrite, "%v", err)
}
return p.bufconn.Flush()
}
type proto struct {
name string
in chan Msg
maxcode, offset uint64
peer *Peer
}
func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.peer.writeMsg(msg, msgWriteTimeout)
}
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data))
}
func (rw *proto) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF
}
msg.Code -= rw.offset
return msg, nil
}
// eofSignal wraps a reader with eof signaling.
// the eof channel is closed when the wrapped reader
// reaches EOF.
type eofSignal struct {
wrapped io.Reader
eof chan<- struct{}
}
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
if err != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
}
return n, err
} }

View File

@ -4,71 +4,121 @@ import (
"fmt" "fmt"
) )
type ErrorCode int
const errorChanCapacity = 10
const ( const (
PacketTooLong = iota errMagicTokenMismatch = iota
PayloadTooShort errRead
MagicTokenMismatch errWrite
ReadError errMisc
WriteError errInvalidMsgCode
MiscError errInvalidMsg
InvalidMsgCode errP2PVersionMismatch
InvalidMsg errPubkeyMissing
P2PVersionMismatch errPubkeyInvalid
PubkeyMissing errPubkeyForbidden
PubkeyInvalid errProtocolBreach
PubkeyForbidden errPingTimeout
ProtocolBreach errInvalidNetworkId
PortMismatch errInvalidProtocolVersion
PingTimeout
InvalidGenesis
InvalidNetworkId
InvalidProtocolVersion
) )
var errorToString = map[ErrorCode]string{ var errorToString = map[int]string{
PacketTooLong: "Packet too long", errMagicTokenMismatch: "Magic token mismatch",
PayloadTooShort: "Payload too short", errRead: "Read error",
MagicTokenMismatch: "Magic token mismatch", errWrite: "Write error",
ReadError: "Read error", errMisc: "Misc error",
WriteError: "Write error", errInvalidMsgCode: "Invalid message code",
MiscError: "Misc error", errInvalidMsg: "Invalid message",
InvalidMsgCode: "Invalid message code", errP2PVersionMismatch: "P2P Version Mismatch",
InvalidMsg: "Invalid message", errPubkeyMissing: "Public key missing",
P2PVersionMismatch: "P2P Version Mismatch", errPubkeyInvalid: "Public key invalid",
PubkeyMissing: "Public key missing", errPubkeyForbidden: "Public key forbidden",
PubkeyInvalid: "Public key invalid", errProtocolBreach: "Protocol Breach",
PubkeyForbidden: "Public key forbidden", errPingTimeout: "Ping timeout",
ProtocolBreach: "Protocol Breach", errInvalidNetworkId: "Invalid network id",
PortMismatch: "Port mismatch", errInvalidProtocolVersion: "Invalid protocol version",
PingTimeout: "Ping timeout",
InvalidGenesis: "Invalid genesis block",
InvalidNetworkId: "Invalid network id",
InvalidProtocolVersion: "Invalid protocol version",
} }
type PeerError struct { type peerError struct {
Code ErrorCode Code int
message string message string
} }
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { func newPeerError(code int, format string, v ...interface{}) *peerError {
desc, ok := errorToString[code] desc, ok := errorToString[code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
format = desc + ": " + format err := &peerError{code, desc}
message := fmt.Sprintf(format, v...) if format != "" {
return &PeerError{code, message} err.message += ": " + fmt.Sprintf(format, v...)
}
return err
} }
func (self *PeerError) Error() string { func (self *peerError) Error() string {
return self.message return self.message
} }
func NewPeerErrorChannel() chan error { type DiscReason byte
return make(chan error, errorChanCapacity)
const (
DiscRequested DiscReason = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
}
func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return fmt.Sprintf("Unknown Reason(%d)", d)
}
return discReasonToString[d]
}
func discReasonForError(err error) DiscReason {
peerError, ok := err.(*peerError)
if !ok {
return DiscSubprotocolError
}
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite, errMisc:
return DiscNetworkError
default:
return DiscSubprotocolError
}
} }

View File

@ -1,98 +0,0 @@
package p2p
import (
"net"
)
const (
severityThreshold = 10
)
type DisconnectRequest struct {
addr net.Addr
reason DiscReason
}
type PeerErrorHandler struct {
quit chan chan bool
address net.Addr
peerDisconnect chan DisconnectRequest
severity int
errc chan error
}
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
return &PeerErrorHandler{
quit: make(chan chan bool),
address: address,
peerDisconnect: peerDisconnect,
errc: errc,
}
}
func (self *PeerErrorHandler) Start() {
go self.listen()
}
func (self *PeerErrorHandler) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *PeerErrorHandler) listen() {
for {
select {
case err, ok := <-self.errc:
if ok {
logger.Debugf("error %v\n", err)
go self.handle(err)
} else {
return
}
case q := <-self.quit:
q <- true
return
}
}
}
func (self *PeerErrorHandler) handle(err error) {
reason := DiscReason(' ')
peerError, ok := err.(*PeerError)
if !ok {
peerError = NewPeerError(MiscError, " %v", err)
}
switch peerError.Code {
case P2PVersionMismatch:
reason = DiscIncompatibleVersion
case PubkeyMissing, PubkeyInvalid:
reason = DiscInvalidIdentity
case PubkeyForbidden:
reason = DiscUselessPeer
case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
reason = DiscProtocolError
case PingTimeout:
reason = DiscReadTimeout
case ReadError, WriteError, MiscError:
reason = DiscNetworkError
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
reason = DiscSubprotocolError
default:
self.severity += self.getSeverity(peerError)
}
if self.severity >= severityThreshold {
reason = DiscSubprotocolError
}
if reason != DiscReason(' ') {
self.peerDisconnect <- DisconnectRequest{
addr: self.address,
reason: reason,
}
}
}
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
return 1
}

View File

@ -1,34 +0,0 @@
package p2p
import (
// "fmt"
"net"
"testing"
"time"
)
func TestPeerErrorHandler(t *testing.T) {
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
peerDisconnect := make(chan DisconnectRequest)
peerErrorChan := NewPeerErrorChannel()
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
peh.Start()
defer peh.Stop()
for i := 0; i < 11; i++ {
select {
case <-peerDisconnect:
t.Errorf("expected no disconnect request")
default:
}
peerErrorChan <- NewPeerError(MiscError, "")
}
time.Sleep(1 * time.Millisecond)
select {
case request := <-peerDisconnect:
if request.addr.String() != address.String() {
t.Errorf("incorrect address %v != %v", request.addr, address)
}
default:
t.Errorf("expected disconnect request")
}
}

View File

@ -1,90 +1,222 @@
package p2p package p2p
// "net" import (
"bufio"
"net"
"reflect"
"testing"
"time"
)
// func TestPeer(t *testing.T) { var discard = Protocol{
// handlers := make(Handlers) Name: "discard",
// testProtocol := &TestProtocol{recv: make(chan testMsg)} Length: 1,
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } Run: func(p *Peer, rw MsgReadWriter) error {
// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } for {
// addr := &TestAddr{"test:30"} msg, err := rw.ReadMsg()
// conn := NewTestNetworkConnection(addr) if err != nil {
// _, server := SetupTestServer(handlers) return err
// server.Handshake() }
// peer := NewPeer(conn, addr, true, server) if err = msg.Discard(); err != nil {
// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) return err
// peer.Start() }
// defer peer.Stop() }
// time.Sleep(2 * time.Millisecond) },
// if len(conn.Out) != 1 { }
// t.Errorf("handshake not sent")
// } else {
// out := conn.Out[0]
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
// if bytes.Compare(out, packet) != 0 {
// t.Errorf("incorrect handshake packet %v != %v", out, packet)
// }
// }
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
// conn.In(0, packet) conn1, conn2 := net.Pipe()
// time.Sleep(10 * time.Millisecond) id := NewSimpleClientIdentity("test", "0", "0", "public key")
peer := newPeer(conn1, protos, nil)
peer.ourID = id
peer.pubkeyHook = func(*peerAddr) error { return nil }
errc := make(chan error, 1)
go func() {
_, err := peer.loop()
errc <- err
}()
return conn2, peer, errc
}
// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) func TestPeerProtoReadMsg(t *testing.T) {
// if pro.state != handshakeReceived { defer testlog(t).detach()
// t.Errorf("handshake not received")
// }
// if peer.Port != 30 {
// t.Errorf("port incorrectly set")
// }
// if peer.Id != "peer" {
// t.Errorf("id incorrectly set")
// }
// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
// t.Errorf("pubkey incorrectly set")
// }
// fmt.Println(peer.Caps)
// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
// t.Errorf("protocols incorrectly set")
// }
// msg := NewMsg(3) done := make(chan struct{})
// err := peer.Write("aaa", msg) proto := Protocol{
// if err != nil { Name: "a",
// t.Errorf("expect no error for known protocol: %v", err) Length: 5,
// } else { Run: func(peer *Peer, rw MsgReadWriter) error {
// time.Sleep(1 * time.Millisecond) msg, err := rw.ReadMsg()
// if len(conn.Out) != 2 { if err != nil {
// t.Errorf("msg not written") t.Errorf("read error: %v", err)
// } else { }
// out := conn.Out[1] if msg.Code != 2 {
// packet := Packet(16, 3) t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
// if bytes.Compare(out, packet) != 0 { }
// t.Errorf("incorrect packet %v != %v", out, packet) data, err := msg.Data()
// } if err != nil {
// } t.Errorf("data decoding error: %v", err)
// } }
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", data.Slice())
}
close(done)
return nil
},
}
// msg = NewMsg(2) net, peer, errc := testPeer([]Protocol{proto})
// err = peer.Write("ccc", msg) defer net.Close()
// if err != nil { peer.startSubprotocols([]Cap{proto.cap()})
// t.Errorf("expect no error for known protocol: %v", err)
// } else {
// time.Sleep(1 * time.Millisecond)
// if len(conn.Out) != 3 {
// t.Errorf("msg not written")
// } else {
// out := conn.Out[2]
// packet := Packet(21, 2)
// if bytes.Compare(out, packet) != 0 {
// t.Errorf("incorrect packet %v != %v", out, packet)
// }
// }
// }
// err = peer.Write("bbb", msg) writeMsg(net, NewMsg(18, 1, "000"))
// time.Sleep(1 * time.Millisecond) select {
// if err == nil { case <-done:
// t.Errorf("expect error for unknown protocol") case err := <-errc:
// } t.Errorf("peer returned: %v", err)
// } case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestPeerProtoReadLargeMsg(t *testing.T) {
defer testlog(t).detach()
msgsize := uint32(10 * 1024 * 1024)
done := make(chan struct{})
proto := Protocol{
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Size != msgsize+4 {
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
}
msg.Discard()
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
select {
case <-done:
case err := <-errc:
t.Errorf("peer returned: %v", err)
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestPeerProtoEncodeMsg(t *testing.T) {
defer testlog(t).detach()
proto := Protocol{
Name: "a",
Length: 2,
Run: func(peer *Peer, rw MsgReadWriter) error {
if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
}
if err := rw.EncodeMsg(1); err != nil {
t.Errorf("write error: %v", err)
}
return nil
},
}
net, peer, _ := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
func TestPeerWrite(t *testing.T) {
defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
}
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
close(read)
}()
// test succcessful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
case <-read:
case err := <-peerErr:
t.Fatalf("peer stopped: %v", err)
}
}
func TestPeerActivity(t *testing.T) {
// shorten inactivityTimeout while this test is running
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
sub := peer.activity.Subscribe(time.Time{})
defer sub.Unsubscribe()
for i := 0; i < 6; i++ {
writeMsg(net, NewMsg(16))
select {
case <-sub.Chan():
case <-time.After(inactivityTimeout / 2):
t.Fatal("no event within ", inactivityTimeout/2)
case err := <-peerErr:
t.Fatal("peer error", err)
}
}
select {
case <-time.After(inactivityTimeout * 2):
case <-sub.Chan():
t.Fatal("got activity event while connection was inactive")
case err := <-peerErr:
t.Fatal("peer error", err)
}
}

View File

@ -3,249 +3,185 @@ package p2p
import ( import (
"bytes" "bytes"
"net" "net"
"sort"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
) )
// Protocol is implemented by P2P subprotocols. // Protocol represents a P2P subprotocol implementation.
type Protocol interface { type Protocol struct {
// Start is called when the protocol becomes active. // Name should contain the official protocol name,
// It should read and write messages from rw. // often a three-letter word.
// Messages must be fully consumed. Name string
// Version should contain the version number of the protocol.
Version uint
// Length should contain the number of message codes used
// by the protocol.
Length uint64
// Run is called in a new groutine when the protocol has been
// negotiated with a peer. It should read and write messages from
// rw. The Payload for each message must be fully consumed.
// //
// The connection is closed when Start returns. It should return // The peer connection is closed when Start returns. It should return
// any protocol-level error (such as an I/O error) that is // any protocol-level error (such as an I/O error) that is
// encountered. // encountered.
Start(peer *Peer, rw MsgReadWriter) error Run func(peer *Peer, rw MsgReadWriter) error
// Offset should return the number of message codes
// used by the protocol.
Offset() MsgCode
} }
type MsgReader interface { func (p Protocol) cap() Cap {
ReadMsg() (Msg, error) return Cap{p.Name, p.Version}
} }
type MsgWriter interface { const (
WriteMsg(Msg) error baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the structure of a handshake list.
type handshake struct {
Version uint64
ID string
Caps []Cap
ListenPort uint64
NodeID []byte
} }
// MsgReadWriter is passed to protocols. Protocol implementations can func (h *handshake) String() string {
// use it to write messages back to a connected peer. return h.ID
type MsgReadWriter interface { }
MsgReader func (h *handshake) Pubkey() []byte {
MsgWriter return h.NodeID
} }
type MsgHandler func(code MsgCode, data *ethutil.Value) error // Cap is the structure of a peer capability.
type Cap struct {
// MsgLoop reads messages off the given reader and Name string
// calls the handler function for each decoded message until Version uint
// it returns an error or the peer connection is closed.
//
// If a message is larger than the given maximum size, RunProtocol
// returns an appropriate error.n
func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error {
for {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxsize {
return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
}
value, err := msg.Data()
if err != nil {
return err
}
if err := handler(msg.Code, value); err != nil {
return err
}
}
} }
// the ÐΞVp2p base protocol func (cap Cap) RlpData() interface{} {
return []interface{}{cap.Name, cap.Version}
}
type capsByName []Cap
func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
type baseProtocol struct { type baseProtocol struct {
rw MsgReadWriter rw MsgReadWriter
peer *Peer peer *Peer
} }
type bpMsg struct { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
code MsgCode bp := &baseProtocol{rw, peer}
data *ethutil.Value
}
const ( // do handshake
p2pVersion = 0 if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
pingTimeout = 2 * time.Second return err
pingGracePeriod = 2 * time.Second
)
const (
// message codes
handshakeMsg = iota
discMsg
pingMsg
pongMsg
getPeersMsg
peersMsg
)
const (
baseProtocolOffset MsgCode = 16
baseProtocolMaxMsgSize = 500 * 1024
)
type DiscReason byte
const (
// Values are given explicitly instead of by iota because these values are
// defined by the wire protocol spec; it is easier for humans to ensure
// correctness when values are explicit.
DiscRequested = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
}
func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return "Unknown"
} }
return discReasonToString[d]
}
func (bp *baseProtocol) Offset() MsgCode {
return baseProtocolOffset
}
func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error {
bp.peer, bp.rw = peer, rw
// Do the handshake.
// TODO: disconnect is valid before handshake, too.
rw.WriteMsg(bp.peer.server.handshakeMsg())
msg, err := rw.ReadMsg() msg, err := rw.ReadMsg()
if err != nil { if err != nil {
return err return err
} }
if msg.Code != handshakeMsg { if msg.Code != handshakeMsg {
return NewPeerError(ProtocolBreach, " first message must be handshake") return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
} }
data, err := msg.Data() data, err := msg.Data()
if err != nil { if err != nil {
return NewPeerError(InvalidMsg, "%v", err) return newPeerError(errInvalidMsg, "%v", err)
} }
if err := bp.handleHandshake(data); err != nil { if err := bp.handleHandshake(data); err != nil {
return err return err
} }
msgin := make(chan bpMsg) // run main loop
done := make(chan error, 1) quit := make(chan error, 1)
go func() { go func() {
done <- MsgLoop(rw, baseProtocolMaxMsgSize, quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle)
func(code MsgCode, data *ethutil.Value) error {
msgin <- bpMsg{code, data}
return nil
})
}() }()
return bp.loop(msgin, done) return bp.loop(quit)
} }
func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error { var pingTimeout = 2 * time.Second
logger.Debugf("pingpong keepalive started at %v\n", time.Now())
messenger := bp.rw.(*proto).messenger
pingTimer := time.NewTimer(pingTimeout)
pinged := true
for { func (bp *baseProtocol) loop(quit <-chan error) error {
ping := time.NewTimer(pingTimeout)
activity := bp.peer.activity.Subscribe(time.Time{})
lastActive := time.Time{}
defer ping.Stop()
defer activity.Unsubscribe()
getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := bp.rw.EncodeMsg(getPeersMsg)
for err == nil {
select { select {
case msg := <-msgin: case err = <-quit:
if err := bp.handle(msg.code, msg.data); err != nil {
return err
}
case err := <-quit:
return err return err
case <-messenger.pulse: case <-getPeersTick.C:
pingTimer.Reset(pingTimeout) err = bp.rw.EncodeMsg(getPeersMsg)
pinged = false case event := <-activity.Chan():
case <-pingTimer.C: ping.Reset(pingTimeout)
if pinged { lastActive = event.(time.Time)
return NewPeerError(PingTimeout, "") case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = bp.rw.EncodeMsg(pingMsg)
} }
logger.Debugf("pinging at %v\n", time.Now())
if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil {
return NewPeerError(WriteError, "%v", err)
}
pinged = true
pingTimer.Reset(pingTimeout)
} }
} }
return err
} }
func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
switch code { switch code {
case handshakeMsg: case handshakeMsg:
return NewPeerError(ProtocolBreach, " extra handshake received") return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg: case discMsg:
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint())) bp.peer.Disconnect(DiscReason(data.Get(0).Uint()))
bp.peer.server.PeerDisconnect() <- DisconnectRequest{ return nil
addr: bp.peer.Address,
reason: DiscRequested,
}
case pingMsg: case pingMsg:
return bp.rw.WriteMsg(NewMsg(pongMsg)) return bp.rw.EncodeMsg(pongMsg)
case pongMsg: case pongMsg:
// reply for ping
case getPeersMsg: case getPeersMsg:
// Peer asked for list of connected peers. peers := bp.peerList()
peersRLP := bp.peer.server.encodedPeerList() // this is dangerous. the spec says that we should _delay_
if peersRLP != nil { // sending the response if no new information is available.
msg := Msg{ // this means that would need to send a response later when
Code: peersMsg, // new peers become available.
Size: uint32(len(peersRLP)), //
Payload: bytes.NewReader(peersRLP), // TODO: add event mechanism to notify baseProtocol for new peers
} if len(peers) > 0 {
return bp.rw.WriteMsg(msg) return bp.rw.EncodeMsg(peersMsg, peers)
} }
case peersMsg: case peersMsg:
bp.handlePeers(data) bp.handlePeers(data)
default: default:
return NewPeerError(InvalidMsgCode, "unknown message code %v", code) return newPeerError(errInvalidMsgCode, "unknown message code %v", code)
} }
return nil return nil
} }
@ -253,62 +189,102 @@ func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
func (bp *baseProtocol) handlePeers(data *ethutil.Value) { func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
it := data.NewIterator() it := data.NewIterator()
for it.Next() { for it.Next() {
ip := net.IP(it.Value().Get(0).Bytes()) addr := &peerAddr{
port := it.Value().Get(1).Uint() IP: net.IP(it.Value().Get(0).Bytes()),
address := &net.TCPAddr{IP: ip, Port: int(port)} Port: it.Value().Get(1).Uint(),
go bp.peer.server.PeerConnect(address) Pubkey: it.Value().Get(2).Bytes(),
}
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
} }
} }
func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
var ( hs := handshake{
remoteVersion = c.Get(0).Uint() Version: c.Get(0).Uint(),
id = c.Get(1).Str() ID: c.Get(1).Str(),
caps = c.Get(2) Caps: nil, // decoded below
port = c.Get(3).Uint() ListenPort: c.Get(3).Uint(),
pubkey = c.Get(4).Bytes() NodeID: c.Get(4).Bytes(),
)
// Check correctness of p2p protocol version
if remoteVersion != p2pVersion {
return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion)
} }
if hs.Version != baseProtocolVersion {
// Handle the pub key (validation, uniqueness) return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
if len(pubkey) == 0 { baseProtocolVersion, hs.Version)
return NewPeerError(PubkeyMissing, "not supplied in handshake.")
} }
if len(hs.NodeID) == 0 {
if len(pubkey) != 64 { return newPeerError(errPubkeyMissing, "")
return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
} }
if len(hs.NodeID) != 64 {
// self connect detection return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
return NewPeerError(PubkeyForbidden, "not allowed to connect to self")
} }
if da := bp.peer.dialAddr; da != nil {
// register pubkey on server. this also sets the pubkey on the peer (need lock) // verify that the peer we wanted to connect to
if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil { // actually holds the target public key.
return NewPeerError(PubkeyForbidden, err.Error()) if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
} }
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
// check port if err := bp.peer.pubkeyHook(pa); err != nil {
if bp.peer.Inbound { return newPeerError(errPubkeyForbidden, "%v", err)
uint16port := uint16(port) }
if bp.peer.Port > 0 && bp.peer.Port != uint16port { capsIt := c.Get(2).NewIterator()
return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port) for capsIt.Next() {
} else { cap := capsIt.Value()
bp.peer.Port = uint16port name := cap.Get(0).Str()
if name != "" {
hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
} }
} }
capsIt := caps.NewIterator() var addr *peerAddr
for capsIt.Next() { if hs.ListenPort != 0 {
cap := capsIt.Value().Str() addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
bp.peer.Caps = append(bp.peer.Caps, cap) addr.Port = hs.ListenPort
} }
sort.Strings(bp.peer.Caps) bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps) bp.peer.startSubprotocols(hs.Caps)
bp.peer.Id = id
return nil return nil
} }
func (bp *baseProtocol) handshakeMsg() Msg {
var (
port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
}
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
peers := bp.peer.otherPeers()
ds := make([]ethutil.RlpEncodable, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

View File

@ -2,21 +2,420 @@ package p2p
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net" "net"
"sort"
"strconv"
"sync" "sync"
"time" "time"
logpkg "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
) )
const ( const (
outboundAddressPoolSize = 10 outboundAddressPoolSize = 500
disconnectGracePeriod = 2 defaultDialTimeout = 10 * time.Second
portMappingUpdateInterval = 15 * time.Minute
portMappingTimeout = 20 * time.Minute
) )
var srvlog = logger.NewLogger("P2P Server")
// Server manages all peer connections.
//
// The fields of Server are used as configuration parameters.
// You should set them before starting the Server. Fields may not be
// modified while the server is running.
type Server struct {
// This field must be set to a valid client identity.
Identity ClientIdentity
// MaxPeers is the maximum number of peers that can be
// connected. It must be greater than zero.
MaxPeers int
// Protocols should contain the protocols supported
// by the server. Matching protocols are launched for
// each peer.
Protocols []Protocol
// If Blacklist is set to a non-nil value, the given Blacklist
// is used to verify peer connections.
Blacklist Blacklist
// If ListenAddr is set to a non-nil address, the server
// will listen for incoming connections.
//
// If the port is zero, the operating system will pick a port. The
// ListenAddr field will be updated with the actual address when
// the server is started.
ListenAddr string
// If set to a non-nil value, the given NAT port mapper
// is used to make the listening port available to the
// Internet.
NAT NAT
// If Dialer is set to a non-nil value, the given Dialer
// is used to dial outbound peer connections.
Dialer *net.Dialer
// If NoDial is true, the server will not dial any peers.
NoDial bool
// Hook for testing. This is useful because we can inhibit
// the whole protocol stack.
newPeerFunc peerFunc
lock sync.RWMutex
running bool
listener net.Listener
laddr *net.TCPAddr // real listen addr
peers []*Peer
peerSlots chan int
peerCount int
quit chan struct{}
wg sync.WaitGroup
peerConnect chan *peerAddr
peerDisconnect chan *Peer
}
// NAT is implemented by NAT traversal methods.
type NAT interface {
GetExternalAddress() (net.IP, error)
AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
DeletePortMapping(protocol string, extport, intport int) error
// Should return name of the method.
String() string
}
type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
// Peers returns all connected peers.
func (srv *Server) Peers() (peers []*Peer) {
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
peers = append(peers, peer)
}
}
return
}
// PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int {
srv.lock.RLock()
defer srv.lock.RUnlock()
return srv.peerCount
}
// SuggestPeer injects an address into the outbound address pool.
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
select {
case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}:
default: // don't block
}
}
// Broadcast sends an RLP-encoded message to all connected peers.
// This method is deprecated and will be removed later.
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
var payload []byte
if data != nil {
payload = encodePayload(data...)
}
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
var msg = Msg{Code: code}
if data != nil {
msg.Payload = bytes.NewReader(payload)
msg.Size = uint32(len(payload))
}
peer.writeProtoMsg(protocol, msg)
}
}
}
// Start starts running the server.
// Servers can be re-used and started again after stopping.
func (srv *Server) Start() (err error) {
srv.lock.Lock()
defer srv.lock.Unlock()
if srv.running {
return errors.New("server already running")
}
srvlog.Infoln("Starting Server")
// initialize fields
if srv.Identity == nil {
return fmt.Errorf("Server.Identity must be set to a non-nil identity")
}
if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0")
}
srv.quit = make(chan struct{})
srv.peers = make([]*Peer, srv.MaxPeers)
srv.peerSlots = make(chan int, srv.MaxPeers)
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
srv.peerDisconnect = make(chan *Peer)
if srv.newPeerFunc == nil {
srv.newPeerFunc = newServerPeer
}
if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist()
}
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil {
return err
}
}
if !srv.NoDial {
srv.wg.Add(1)
go srv.dialLoop()
}
if srv.NoDial && srv.ListenAddr == "" {
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
}
// make all slots available
for i := range srv.peers {
srv.peerSlots <- i
}
// note: discLoop is not part of WaitGroup
go srv.discLoop()
srv.running = true
return nil
}
func (srv *Server) startListening() error {
listener, err := net.Listen("tcp", srv.ListenAddr)
if err != nil {
return err
}
srv.ListenAddr = listener.Addr().String()
srv.laddr = listener.Addr().(*net.TCPAddr)
srv.listener = listener
srv.wg.Add(1)
go srv.listenLoop()
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
srv.wg.Add(1)
go srv.natLoop(srv.laddr.Port)
}
return nil
}
// Stop terminates the server and all active peer connections.
// It blocks until all active connections have been closed.
func (srv *Server) Stop() {
srv.lock.Lock()
if !srv.running {
srv.lock.Unlock()
return
}
srv.running = false
srv.lock.Unlock()
srvlog.Infoln("Stopping server")
if srv.listener != nil {
// this unblocks listener Accept
srv.listener.Close()
}
close(srv.quit)
for _, peer := range srv.Peers() {
peer.Disconnect(DiscQuitting)
}
srv.wg.Wait()
// wait till they actually disconnect
// this is checked by claiming all peerSlots.
// slots become available as the peers disconnect.
for i := 0; i < cap(srv.peerSlots); i++ {
<-srv.peerSlots
}
// terminate discLoop
close(srv.peerDisconnect)
}
func (srv *Server) discLoop() {
for peer := range srv.peerDisconnect {
// peer has just disconnected. free up its slot.
srvlog.Infof("%v is gone", peer)
srv.peerSlots <- peer.slot
srv.lock.Lock()
srv.peers[peer.slot] = nil
srv.lock.Unlock()
}
}
// main loop for adding connections via listening
func (srv *Server) listenLoop() {
defer srv.wg.Done()
srvlog.Infoln("Listening on", srv.listener.Addr())
for {
select {
case slot := <-srv.peerSlots:
conn, err := srv.listener.Accept()
if err != nil {
srv.peerSlots <- slot
return
}
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
srv.addPeer(conn, nil, slot)
case <-srv.quit:
return
}
}
}
func (srv *Server) natLoop(port int) {
defer srv.wg.Done()
for {
srv.updatePortMapping(port)
select {
case <-time.After(portMappingUpdateInterval):
// one more round
case <-srv.quit:
srv.removePortMapping(port)
return
}
}
}
func (srv *Server) updatePortMapping(port int) {
srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
if err != nil {
srvlog.Errorln("Port mapping error:", err)
return
}
extip, err := srv.NAT.GetExternalAddress()
if err != nil {
srvlog.Errorln("Error getting external IP:", err)
return
}
srv.lock.Lock()
extaddr := *(srv.listener.Addr().(*net.TCPAddr))
extaddr.IP = extip
srvlog.Infoln("Mapped port, external addr is", &extaddr)
srv.laddr = &extaddr
srv.lock.Unlock()
}
func (srv *Server) removePortMapping(port int) {
srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
srv.NAT.DeletePortMapping("tcp", port, port)
}
func (srv *Server) dialLoop() {
defer srv.wg.Done()
var (
suggest chan *peerAddr
slot *int
slots = srv.peerSlots
)
for {
select {
case i := <-slots:
// we need a peer in slot i, slot reserved
slot = &i
// now we can watch for candidate peers in the next loop
suggest = srv.peerConnect
// do not consume more until candidate peer is found
slots = nil
case desc := <-suggest:
// candidate peer found, will dial out asyncronously
// if connection fails slot will be released
go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop
slots = srv.peerSlots
// until then we dont care about candidate peers
suggest = nil
case <-srv.quit:
// give back the currently reserved slot
if slot != nil {
srv.peerSlots <- *slot
}
return
}
}
}
// connect to peer via dial out
func (srv *Server) dialPeer(desc *peerAddr, slot int) {
srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
if err != nil {
srvlog.Errorf("Dial error: %v", err)
srv.peerSlots <- slot
return
}
go srv.addPeer(conn, desc, slot)
}
// creates the new peer object and inserts it into its slot
func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
srv.lock.Lock()
defer srv.lock.Unlock()
if !srv.running {
conn.Close()
srv.peerSlots <- slot // release slot
return nil
}
peer := srv.newPeerFunc(srv, conn, desc)
peer.slot = slot
srv.peers[slot] = peer
srv.peerCount++
go func() { peer.loop(); srv.peerDisconnect <- peer }()
return peer
}
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
func (srv *Server) removePeer(peer *Peer) {
srv.lock.Lock()
defer srv.lock.Unlock()
srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot)
if srv.peers[peer.slot] != peer {
srvlog.Warnln("Invalid peer to remove:", peer)
return
}
// remove from list and index
srv.peerCount--
srv.peers[peer.slot] = nil
// release slot to signal need for a new peer, last!
srv.peerSlots <- peer.slot
}
func (srv *Server) verifyPeer(addr *peerAddr) error {
if srv.Blacklist.Exists(addr.Pubkey) {
return errors.New("blacklisted")
}
if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
}
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
id := peer.Identity()
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
return errors.New("already connected")
}
}
}
return nil
}
type Blacklist interface { type Blacklist interface {
Get([]byte) (bool, error) Get([]byte) (bool, error)
Put([]byte) error Put([]byte) error
@ -66,423 +465,3 @@ func (self *BlacklistMap) Delete(pubkey []byte) error {
delete(self.blacklist, string(pubkey)) delete(self.blacklist, string(pubkey))
return nil return nil
} }
type Server struct {
network Network
listening bool //needed?
dialing bool //needed?
closed bool
identity ClientIdentity
addr net.Addr
port uint16
protocols []string
quit chan chan bool
peersLock sync.RWMutex
maxPeers int
peers []*Peer
peerSlots chan int
peersTable map[string]int
peerCount int
cachedEncodedPeers []byte
peerConnect chan net.Addr
peerDisconnect chan DisconnectRequest
blacklist Blacklist
handlers Handlers
}
var logger = logpkg.NewLogger("P2P")
func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server {
// get alphabetical list of protocol names from handlers map
protocols := []string{}
for protocol := range handlers {
protocols = append(protocols, protocol)
}
sort.Strings(protocols)
_, port, _ := net.SplitHostPort(addr.String())
intport, _ := strconv.Atoi(port)
self := &Server{
// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
network: network,
identity: identity,
addr: addr,
port: uint16(intport),
protocols: protocols,
quit: make(chan chan bool),
maxPeers: maxPeers,
peers: make([]*Peer, maxPeers),
peerSlots: make(chan int, maxPeers),
peersTable: make(map[string]int),
peerConnect: make(chan net.Addr, outboundAddressPoolSize),
peerDisconnect: make(chan DisconnectRequest),
blacklist: blacklist,
handlers: handlers,
}
for i := 0; i < maxPeers; i++ {
self.peerSlots <- i // fill up with indexes
}
return self
}
func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) {
addr, err = self.network.NewAddr(host, port)
return
}
func (self *Server) ParseAddr(address string) (addr net.Addr, err error) {
addr, err = self.network.ParseAddr(address)
return
}
func (self *Server) ClientIdentity() ClientIdentity {
return self.identity
}
func (self *Server) Peers() (peers []*Peer) {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
for _, peer := range self.peers {
if peer != nil {
peers = append(peers, peer)
}
}
return
}
func (self *Server) PeerCount() int {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
return self.peerCount
}
func (self *Server) PeerConnect(addr net.Addr) {
// TODO: should buffer, filter and uniq
// send GetPeersMsg if not blocking
select {
case self.peerConnect <- addr: // not enough peers
self.Broadcast("", getPeersMsg)
default: // we dont care
}
}
func (self *Server) PeerDisconnect() chan DisconnectRequest {
return self.peerDisconnect
}
func (self *Server) Blacklist() Blacklist {
return self.blacklist
}
func (self *Server) Handlers() Handlers {
return self.handlers
}
func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) {
var payload []byte
if data != nil {
payload = encodePayload(data...)
}
self.peersLock.RLock()
defer self.peersLock.RUnlock()
for _, peer := range self.peers {
if peer != nil {
var msg = Msg{Code: code}
if data != nil {
msg.Payload = bytes.NewReader(payload)
msg.Size = uint32(len(payload))
}
peer.messenger.writeProtoMsg(protocol, msg)
}
}
}
// Start the server
func (self *Server) Start(listen bool, dial bool) {
self.network.Start()
if listen {
listener, err := self.network.Listener(self.addr)
if err != nil {
logger.Warnf("Error initializing listener: %v", err)
logger.Warnf("Connection listening disabled")
self.listening = false
} else {
self.listening = true
logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr())
go self.inboundPeerHandler(listener)
}
}
if dial {
dialer, err := self.network.Dialer(self.addr)
if err != nil {
logger.Warnf("Error initializing dialer: %v", err)
logger.Warnf("Connection dialout disabled")
self.dialing = false
} else {
self.dialing = true
logger.Infoln("Dial peers watching outbound address pool")
go self.outboundPeerHandler(dialer)
}
}
logger.Infoln("server started")
}
func (self *Server) Stop() {
logger.Infoln("server stopping...")
// // quit one loop if dialing
if self.dialing {
logger.Infoln("stop dialout...")
dialq := make(chan bool)
self.quit <- dialq
<-dialq
fmt.Println("quit another")
}
// quit the other loop if listening
if self.listening {
logger.Infoln("stop listening...")
listenq := make(chan bool)
self.quit <- listenq
<-listenq
fmt.Println("quit one")
}
fmt.Println("quit waited")
logger.Infoln("stopping peers...")
peers := []net.Addr{}
self.peersLock.RLock()
self.closed = true
for _, peer := range self.peers {
if peer != nil {
peers = append(peers, peer.Address)
}
}
self.peersLock.RUnlock()
for _, address := range peers {
go self.removePeer(DisconnectRequest{
addr: address,
reason: DiscQuitting,
})
}
// wait till they actually disconnect
// this is checked by draining the peerSlots (slots are released back if a peer is removed)
i := 0
fmt.Println("draining peers")
FOR:
for {
select {
case slot := <-self.peerSlots:
i++
fmt.Printf("%v: found slot %v\n", i, slot)
if i == self.maxPeers {
break FOR
}
}
}
logger.Infoln("server stopped")
}
// main loop for adding connections via listening
func (self *Server) inboundPeerHandler(listener net.Listener) {
for {
select {
case slot := <-self.peerSlots:
go self.connectInboundPeer(listener, slot)
case errc := <-self.quit:
listener.Close()
fmt.Println("quit listenloop")
errc <- true
return
}
}
}
// main loop for adding outbound peers based on peerConnect address pool
// this same loop handles peer disconnect requests as well
func (self *Server) outboundPeerHandler(dialer Dialer) {
// addressChan initially set to nil (only watches peerConnect if we need more peers)
var addressChan chan net.Addr
slots := self.peerSlots
var slot *int
for {
select {
case i := <-slots:
// we need a peer in slot i, slot reserved
slot = &i
// now we can watch for candidate peers in the next loop
addressChan = self.peerConnect
// do not consume more until candidate peer is found
slots = nil
case address := <-addressChan:
// candidate peer found, will dial out asyncronously
// if connection fails slot will be released
go self.connectOutboundPeer(dialer, address, *slot)
// we can watch if more peers needed in the next loop
slots = self.peerSlots
// until then we dont care about candidate peers
addressChan = nil
case request := <-self.peerDisconnect:
go self.removePeer(request)
case errc := <-self.quit:
if addressChan != nil && slot != nil {
self.peerSlots <- *slot
}
fmt.Println("quit dialloop")
errc <- true
return
}
}
}
// check if peer address already connected
func (self *Server) isConnected(address net.Addr) bool {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
_, found := self.peersTable[address.String()]
return found
}
// connect to peer via listener.Accept()
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
var address net.Addr
conn, err := listener.Accept()
if err != nil {
logger.Debugln(err)
self.peerSlots <- slot
return
}
address = conn.RemoteAddr()
// XXX: this won't work because the remote socket
// address does not identify the peer. we should
// probably get rid of this check and rely on public
// key detection in the base protocol.
if self.isConnected(address) {
conn.Close()
self.peerSlots <- slot
return
}
fmt.Printf("adding %v\n", address)
go self.addPeer(conn, address, true, slot)
}
// connect to peer via dial out
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
if self.isConnected(address) {
return
}
conn, err := dialer.Dial(address.Network(), address.String())
if err != nil {
self.peerSlots <- slot
return
}
go self.addPeer(conn, address, false, slot)
}
// creates the new peer object and inserts it into its slot
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer {
self.peersLock.Lock()
defer self.peersLock.Unlock()
if self.closed {
fmt.Println("oopsy, not no longer need peer")
conn.Close() //oopsy our bad
self.peerSlots <- slot // release slot
return nil
}
logger.Infoln("adding new peer", address)
peer := NewPeer(conn, address, inbound, self)
self.peers[slot] = peer
self.peersTable[address.String()] = slot
self.peerCount++
self.cachedEncodedPeers = nil
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
peer.Start()
return peer
}
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
func (self *Server) removePeer(request DisconnectRequest) {
self.peersLock.Lock()
address := request.addr
slot := self.peersTable[address.String()]
peer := self.peers[slot]
fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
if peer == nil {
logger.Debugf("already removed peer on %v", address)
self.peersLock.Unlock()
return
}
// remove from list and index
self.peerCount--
self.peers[slot] = nil
delete(self.peersTable, address.String())
self.cachedEncodedPeers = nil
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
self.peersLock.Unlock()
// sending disconnect message
disconnectMsg := NewMsg(discMsg, request.reason)
peer.Write("", disconnectMsg)
// be nice and wait
time.Sleep(disconnectGracePeriod * time.Second)
// switch off peer and close connections etc.
fmt.Println("stopping peer")
peer.Stop()
fmt.Println("stopped peer")
// release slot to signal need for a new peer, last!
self.peerSlots <- slot
}
// encodedPeerList returns an RLP-encoded list of peers.
// the returned slice will be nil if there are no peers.
func (self *Server) encodedPeerList() []byte {
// TODO: memoize and reset when peers change
self.peersLock.RLock()
defer self.peersLock.RUnlock()
if self.cachedEncodedPeers == nil && self.peerCount > 0 {
var peerData []interface{}
for _, i := range self.peersTable {
peer := self.peers[i]
peerData = append(peerData, peer.Encode())
}
self.cachedEncodedPeers = encodePayload(peerData)
}
return self.cachedEncodedPeers
}
// fix handshake message to push to peers
func (self *Server) handshakeMsg() Msg {
return NewMsg(handshakeMsg,
p2pVersion,
[]byte(self.identity.String()),
[]interface{}{self.protocols},
self.port,
self.identity.Pubkey()[1:],
)
}
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
// Check for blacklisting
if self.blacklist.Exists(pubkey) {
return fmt.Errorf("blacklisted")
}
self.peersLock.RLock()
defer self.peersLock.RUnlock()
for _, peer := range self.peers {
if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 {
return fmt.Errorf("already connected")
}
}
candidate.Pubkey = pubkey
return nil
}

View File

@ -1,289 +1,161 @@
package p2p package p2p
import ( import (
"fmt" "bytes"
"io" "io"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
) )
type TestNetwork struct { func startTestServer(t *testing.T, pf peerFunc) *Server {
connections map[string]*TestNetworkConnection server := &Server{
dialer Dialer Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
maxinbound int MaxPeers: 10,
ListenAddr: "127.0.0.1:0",
newPeerFunc: pf,
}
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
}
return server
} }
func NewTestNetwork(maxinbound int) *TestNetwork { func TestServerListen(t *testing.T) {
connections := make(map[string]*TestNetworkConnection) defer testlog(t).detach()
return &TestNetwork{
connections: connections, // start the test server
dialer: &TestDialer{connections}, connected := make(chan *Peer)
maxinbound: maxinbound, srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
if conn == nil {
t.Error("peer func called with nil conn")
}
if dialAddr != nil {
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
// dial the test server
conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
defer conn.Close()
select {
case peer := <-connected:
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr())
}
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
} }
} }
func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { func TestServerDial(t *testing.T) {
return self.dialer, nil defer testlog(t).detach()
}
func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { // run a fake TCP server to handle the connection.
return &TestListener{ listener, err := net.Listen("tcp", "127.0.0.1:0")
connections: self.connections, if err != nil {
addr: addr, t.Fatalf("could not setup listener: %v")
max: self.maxinbound,
close: make(chan struct{}),
}, nil
}
func (self *TestNetwork) Start() error {
return nil
}
func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
return
}
func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
return
}
type TestAddr struct {
name string
}
func (self *TestAddr) String() string {
return self.name
}
func (*TestAddr) Network() string {
return "test"
}
type TestDialer struct {
connections map[string]*TestNetworkConnection
}
func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
address := &TestAddr{addr}
tconn := NewTestNetworkConnection(address)
self.connections[addr] = tconn
conn = net.Conn(tconn)
return
}
type TestListener struct {
connections map[string]*TestNetworkConnection
addr net.Addr
max int
i int
close chan struct{}
}
func (self *TestListener) Accept() (net.Conn, error) {
self.i++
if self.i > self.max {
<-self.close
return nil, io.EOF
} }
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} defer listener.Close()
tconn := NewTestNetworkConnection(addr) accepted := make(chan net.Conn)
key := tconn.RemoteAddr().String() go func() {
self.connections[key] = tconn conn, err := listener.Accept()
fmt.Printf("accepted connection from: %v \n", addr) if err != nil {
return tconn, nil t.Error("acccept error:", err)
} }
conn.Close()
accepted <- conn
}()
func (self *TestListener) Close() error { // start the test server
close(self.close) connected := make(chan *Peer)
return nil srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
} if conn == nil {
t.Error("peer func called with nil conn")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
func (self *TestListener) Addr() net.Addr { // tell the server to connect.
return self.addr connAddr := newPeerAddr(listener.Addr(), nil)
} srv.peerConnect <- connAddr
type TestNetworkConnection struct { select {
in chan []byte case conn := <-accepted:
close chan struct{}
current []byte
Out [][]byte
addr net.Addr
}
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
return &TestNetworkConnection{
in: make(chan []byte),
close: make(chan struct{}),
current: []byte{},
Out: [][]byte{},
addr: addr,
}
}
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
time.Sleep(latency)
for _, s := range packets {
self.in <- s
}
}
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
if len(self.current) == 0 {
var ok bool
select { select {
case self.current, ok = <-self.in: case peer := <-connected:
if !ok { if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
return 0, io.EOF t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr())
} }
case <-self.close: if peer.dialAddr != connAddr {
return 0, io.EOF t.Errorf("peer started with wrong dialAddr: got %v, want %v",
peer.dialAddr, connAddr)
}
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
} }
}
length := len(self.current) case <-time.After(1 * time.Second):
if length > len(buff) { t.Error("server did not connect within one second")
copy(buff[:], self.current[:len(buff)])
self.current = self.current[len(buff):]
return len(buff), nil
} else {
copy(buff[:length], self.current[:])
self.current = []byte{}
return length, io.EOF
} }
} }
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { func TestServerBroadcast(t *testing.T) {
self.Out = append(self.Out, buff) defer testlog(t).detach()
fmt.Printf("net write(%d): %x\n", len(self.Out), buff) var connected sync.WaitGroup
return len(buff), nil srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
} peer := newPeer(c, []Protocol{discard}, dialAddr)
peer.startSubprotocols([]Cap{discard.cap()})
connected.Done()
return peer
})
defer srv.Stop()
func (self *TestNetworkConnection) Close() error { // dial a bunch of conns
close(self.close) var conns = make([]net.Conn, 8)
return nil connected.Add(len(conns))
} deadline := time.Now().Add(3 * time.Second)
dialer := &net.Dialer{Deadline: deadline}
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { for i := range conns {
return conn, err := dialer.Dial("tcp", srv.ListenAddr)
} if err != nil {
t.Fatalf("conn %d: dial error: %v", i, err)
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { }
return self.addr defer conn.Close()
} conn.SetDeadline(deadline)
conns[i] = conn
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
return
}
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
network = NewTestNetwork(1)
addr := &TestAddr{"test:30303"}
identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
maxPeers := 2
if handlers == nil {
handlers = make(Handlers)
} }
blackist := NewBlacklist() connected.Wait()
server = New(network, addr, identity, handlers, maxPeers, blackist)
fmt.Println(server.identity.Pubkey())
return
}
func TestServerListener(t *testing.T) { // broadcast one message
t.SkipNow() srv.Broadcast("discard", 0, "foo")
goldbuf := new(bytes.Buffer)
writeMsg(goldbuf, NewMsg(16, "foo"))
golden := goldbuf.Bytes()
network, server := SetupTestServer(nil) // check that the message has been written everywhere
server.Start(true, false) for i, conn := range conns {
time.Sleep(10 * time.Millisecond) buf := make([]byte, len(golden))
server.Stop() if _, err := io.ReadFull(conn, buf); err != nil {
peer1, ok := network.connections["inboundpeer-1"] t.Errorf("conn %d: read error: %v", i, err)
if !ok { } else if !bytes.Equal(buf, golden) {
t.Error("not found inbound peer 1") t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
} else {
if len(peer1.Out) != 2 {
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
} }
} }
} }
func TestServerDialer(t *testing.T) {
network, server := SetupTestServer(nil)
server.Start(false, true)
server.peerConnect <- &TestAddr{"outboundpeer-1"}
time.Sleep(10 * time.Millisecond)
server.Stop()
peer1, ok := network.connections["outboundpeer-1"]
if !ok {
t.Error("not found outbound peer 1")
} else {
if len(peer1.Out) != 2 {
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
}
}
}
// func TestServerBroadcast(t *testing.T) {
// handlers := make(Handlers)
// testProtocol := &TestProtocol{Msgs: []*Msg{}}
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
// network, server := SetupTestServer(handlers)
// server.Start(true, true)
// server.peerConnect <- &TestAddr{"outboundpeer-1"}
// time.Sleep(10 * time.Millisecond)
// msg := NewMsg(0)
// server.Broadcast("", msg)
// packet := Packet(0, 0)
// time.Sleep(10 * time.Millisecond)
// server.Stop()
// peer1, ok := network.connections["outboundpeer-1"]
// if !ok {
// t.Error("not found outbound peer 1")
// } else {
// fmt.Printf("out: %v\n", peer1.Out)
// if len(peer1.Out) != 3 {
// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
// } else {
// if bytes.Compare(peer1.Out[1], packet) != 0 {
// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
// }
// }
// }
// peer2, ok := network.connections["inboundpeer-1"]
// if !ok {
// t.Error("not found inbound peer 2")
// } else {
// fmt.Printf("out: %v\n", peer2.Out)
// if len(peer1.Out) != 3 {
// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
// } else {
// if bytes.Compare(peer2.Out[1], packet) != 0 {
// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
// }
// }
// }
// }
func TestServerPeersMessage(t *testing.T) {
t.SkipNow()
_, server := SetupTestServer(nil)
server.Start(true, true)
defer server.Stop()
server.peerConnect <- &TestAddr{"outboundpeer-1"}
time.Sleep(2000 * time.Millisecond)
pl := server.encodedPeerList()
if pl == nil {
t.Errorf("expect non-nil peer list")
}
if c := server.PeerCount(); c != 2 {
t.Errorf("expect 2 peers, got %v", c)
}
}

28
p2p/testlog_test.go Normal file
View File

@ -0,0 +1,28 @@
package p2p
import (
"testing"
"github.com/ethereum/go-ethereum/logger"
)
type testLogger struct{ t *testing.T }
func testlog(t *testing.T) testLogger {
logger.Reset()
l := testLogger{t}
logger.AddLogSystem(l)
return l
}
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
l.t.Logf("%s", msg)
}
func (testLogger) detach() {
logger.Flush()
logger.Reset()
}

40
p2p/testpoc7.go Normal file
View File

@ -0,0 +1,40 @@
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/obscuren/secp256k1-go"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}