diff --git a/p2p/client_identity.go b/p2p/client_identity.go index 236b23106..bc865b63b 100644 --- a/p2p/client_identity.go +++ b/p2p/client_identity.go @@ -5,10 +5,10 @@ import ( "runtime" ) -// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. +// ClientIdentity represents the identity of a peer. type ClientIdentity interface { - String() string - Pubkey() []byte + String() string // human readable identity + Pubkey() []byte // 512-bit public key } type SimpleClientIdentity struct { diff --git a/p2p/message.go b/p2p/message.go index 97d440a27..89ad189d7 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -11,8 +11,6 @@ import ( "github.com/ethereum/go-ethereum/ethutil" ) -type MsgCode uint64 - // Msg defines the structure of a p2p message. // // Note that a Msg can only be sent once since the Payload reader is @@ -21,13 +19,13 @@ type MsgCode uint64 // structure, encode the payload into a byte array and create a // separate Msg with a bytes.Reader as Payload for each send. type Msg struct { - Code MsgCode + Code uint64 Size uint32 // size of the paylod Payload io.Reader } // NewMsg creates an RLP-encoded message with the given code. -func NewMsg(code MsgCode, params ...interface{}) Msg { +func NewMsg(code uint64, params ...interface{}) Msg { buf := new(bytes.Buffer) for _, p := range params { buf.Write(ethutil.Encode(p)) @@ -63,6 +61,52 @@ func (msg Msg) Discard() error { return err } +type MsgReader interface { + ReadMsg() (Msg, error) +} + +type MsgWriter interface { + // WriteMsg sends an existing message. + // The Payload reader of the message is consumed. + // Note that messages can be sent only once. + WriteMsg(Msg) error + + // EncodeMsg writes an RLP-encoded message with the given + // code and data elements. + EncodeMsg(code uint64, data ...interface{}) error +} + +// MsgReadWriter provides reading and writing of encoded messages. +type MsgReadWriter interface { + MsgReader + MsgWriter +} + +// MsgLoop reads messages off the given reader and +// calls the handler function for each decoded message until +// it returns an error or the peer connection is closed. +// +// If a message is larger than the given maximum size, +// MsgLoop returns an appropriate error. +func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error { + for { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if msg.Size > maxsize { + return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) + } + value, err := msg.Data() + if err != nil { + return err + } + if err := f(msg.Code, value); err != nil { + return err + } + } +} + var magicToken = []byte{34, 64, 8, 145} func writeMsg(w io.Writer, msg Msg) error { @@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) { // read magic and payload size start := make([]byte, 8) if _, err = io.ReadFull(r, start); err != nil { - return msg, NewPeerError(ReadError, "%v", err) + return msg, newPeerError(errRead, "%v", err) } if !bytes.HasPrefix(start, magicToken) { - return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken) + return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) } size := binary.BigEndian.Uint32(start[4:]) @@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) { } // readUint reads an RLP-encoded unsigned integer from r. -func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { +func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) { b, err := r.ReadByte() if err != nil { return 0, 0, err } if b < 0x80 { - return MsgCode(b), 1, nil + return uint64(b), 1, nil } else if b < 0x89 { // max length for uint64 is 8 bytes codelen = uint32(b - 0x80) if codelen == 0 { @@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { return 0, 0, err } - return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil + return binary.BigEndian.Uint64(buf), codelen, nil } return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) } diff --git a/p2p/messenger.go b/p2p/messenger.go deleted file mode 100644 index c7948a9ac..000000000 --- a/p2p/messenger.go +++ /dev/null @@ -1,221 +0,0 @@ -package p2p - -import ( - "bufio" - "bytes" - "fmt" - "io" - "io/ioutil" - "net" - "sync" - "time" -) - -type Handlers map[string]Protocol - -type proto struct { - in chan Msg - maxcode, offset MsgCode - messenger *messenger -} - -func (rw *proto) WriteMsg(msg Msg) error { - if msg.Code >= rw.maxcode { - return NewPeerError(InvalidMsgCode, "not handled") - } - msg.Code += rw.offset - return rw.messenger.writeMsg(msg) -} - -func (rw *proto) ReadMsg() (Msg, error) { - msg, ok := <-rw.in - if !ok { - return msg, io.EOF - } - msg.Code -= rw.offset - return msg, nil -} - -// eofSignal wraps a reader with eof signaling. -// the eof channel is closed when the wrapped reader -// reaches EOF. -type eofSignal struct { - wrapped io.Reader - eof chan struct{} -} - -func (r *eofSignal) Read(buf []byte) (int, error) { - n, err := r.wrapped.Read(buf) - if err != nil { - close(r.eof) // tell messenger that msg has been consumed - } - return n, err -} - -// messenger represents a message-oriented peer connection. -// It keeps track of the set of protocols understood -// by the remote peer. -type messenger struct { - peer *Peer - handlers Handlers - - // the mutex protects the connection - // so only one protocol can write at a time. - writeMu sync.Mutex - conn net.Conn - bufconn *bufio.ReadWriter - - protocolLock sync.RWMutex - protocols map[string]*proto - offsets map[MsgCode]*proto - protoWG sync.WaitGroup - - err chan error - pulse chan bool -} - -func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger { - return &messenger{ - conn: conn, - bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), - peer: peer, - handlers: handlers, - protocols: make(map[string]*proto), - err: errchan, - pulse: make(chan bool, 1), - } -} - -func (m *messenger) Start() { - m.protocols[""] = m.startProto(0, "", &baseProtocol{}) - go m.readLoop() -} - -func (m *messenger) Stop() { - m.conn.Close() - m.protoWG.Wait() -} - -const ( - // maximum amount of time allowed for reading a message - msgReadTimeout = 5 * time.Second - - // messages smaller than this many bytes will be read at - // once before passing them to a protocol. - wholePayloadSize = 64 * 1024 -) - -func (m *messenger) readLoop() { - defer m.closeProtocols() - for { - m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) - msg, err := readMsg(m.bufconn) - if err != nil { - m.err <- err - return - } - // send ping to heartbeat channel signalling time of last message - m.pulse <- true - proto, err := m.getProto(msg.Code) - if err != nil { - m.err <- err - return - } - if msg.Size <= wholePayloadSize { - // optimization: msg is small enough, read all - // of it and move on to the next message - buf, err := ioutil.ReadAll(msg.Payload) - if err != nil { - m.err <- err - return - } - msg.Payload = bytes.NewReader(buf) - proto.in <- msg - } else { - pr := &eofSignal{msg.Payload, make(chan struct{})} - msg.Payload = pr - proto.in <- msg - <-pr.eof - } - } -} - -func (m *messenger) closeProtocols() { - m.protocolLock.RLock() - for _, p := range m.protocols { - close(p.in) - } - m.protocolLock.RUnlock() -} - -func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto { - proto := &proto{ - in: make(chan Msg), - offset: offset, - maxcode: impl.Offset(), - messenger: m, - } - m.protoWG.Add(1) - go func() { - if err := impl.Start(m.peer, proto); err != nil && err != io.EOF { - logger.Errorf("protocol %q error: %v\n", name, err) - m.err <- err - } - m.protoWG.Done() - }() - return proto -} - -// getProto finds the protocol responsible for handling -// the given message code. -func (m *messenger) getProto(code MsgCode) (*proto, error) { - m.protocolLock.RLock() - defer m.protocolLock.RUnlock() - for _, proto := range m.protocols { - if code >= proto.offset && code < proto.offset+proto.maxcode { - return proto, nil - } - } - return nil, NewPeerError(InvalidMsgCode, "%d", code) -} - -// setProtocols starts all subprotocols shared with the -// remote peer. the protocols must be sorted alphabetically. -func (m *messenger) setRemoteProtocols(protocols []string) { - m.protocolLock.Lock() - defer m.protocolLock.Unlock() - offset := baseProtocolOffset - for _, name := range protocols { - inst, ok := m.handlers[name] - if !ok { - continue // not handled - } - m.protocols[name] = m.startProto(offset, name, inst) - offset += inst.Offset() - } -} - -// writeProtoMsg sends the given message on behalf of the given named protocol. -func (m *messenger) writeProtoMsg(protoName string, msg Msg) error { - m.protocolLock.RLock() - proto, ok := m.protocols[protoName] - m.protocolLock.RUnlock() - if !ok { - return fmt.Errorf("protocol %s not handled by peer", protoName) - } - if msg.Code >= proto.maxcode { - return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) - } - msg.Code += proto.offset - return m.writeMsg(msg) -} - -// writeMsg writes a message to the connection. -func (m *messenger) writeMsg(msg Msg) error { - m.writeMu.Lock() - defer m.writeMu.Unlock() - if err := writeMsg(m.bufconn, msg); err != nil { - return err - } - return m.bufconn.Flush() -} diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go deleted file mode 100644 index 2264e10d3..000000000 --- a/p2p/messenger_test.go +++ /dev/null @@ -1,203 +0,0 @@ -package p2p - -import ( - "bufio" - "fmt" - "io" - "log" - "net" - "os" - "reflect" - "testing" - "time" - - logpkg "github.com/ethereum/go-ethereum/logger" -) - -func init() { - logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel)) -} - -func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { - conn1, conn2 := net.Pipe() - id := NewSimpleClientIdentity("test", "0", "0", "public key") - server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) - peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0) - return conn2, peer, peer.messenger -} - -func performTestHandshake(r *bufio.Reader, w io.Writer) error { - // read remote handshake - msg, err := readMsg(r) - if err != nil { - return fmt.Errorf("read error: %v", err) - } - if msg.Code != handshakeMsg { - return fmt.Errorf("first message should be handshake, got %d", msg.Code) - } - if err := msg.Discard(); err != nil { - return err - } - // send empty handshake - pubkey := make([]byte, 64) - msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey) - return writeMsg(w, msg) -} - -type testProtocol struct { - offset MsgCode - f func(MsgReadWriter) -} - -func (p *testProtocol) Offset() MsgCode { - return p.offset -} - -func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error { - p.f(rw) - return nil -} - -func TestRead(t *testing.T) { - done := make(chan struct{}) - handlers := Handlers{ - "a": &testProtocol{5, func(rw MsgReadWriter) { - msg, err := rw.ReadMsg() - if err != nil { - t.Errorf("read error: %v", err) - } - if msg.Code != 2 { - t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) - } - data, err := msg.Data() - if err != nil { - t.Errorf("data decoding error: %v", err) - } - expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} - if !reflect.DeepEqual(data.Slice(), expdata) { - t.Errorf("incorrect msg data %#v", data.Slice()) - } - close(done) - }}, - } - - net, peer, m := testMessenger(handlers) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - m.setRemoteProtocols([]string{"a"}) - - writeMsg(net, NewMsg(18, 1, "000")) - select { - case <-done: - case <-time.After(2 * time.Second): - t.Errorf("receive timeout") - } -} - -func TestWriteFromProto(t *testing.T) { - handlers := Handlers{ - "a": &testProtocol{2, func(rw MsgReadWriter) { - if err := rw.WriteMsg(NewMsg(2)); err == nil { - t.Error("expected error for out-of-range msg code, got nil") - } - if err := rw.WriteMsg(NewMsg(1)); err != nil { - t.Errorf("write error: %v", err) - } - }}, - } - net, peer, mess := testMessenger(handlers) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - mess.setRemoteProtocols([]string{"a"}) - - msg, err := readMsg(bufr) - if err != nil { - t.Errorf("read error: %v") - } - if msg.Code != 17 { - t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) - } -} - -var discardProto = &testProtocol{1, func(rw MsgReadWriter) { - for { - msg, err := rw.ReadMsg() - if err != nil { - return - } - if err = msg.Discard(); err != nil { - return - } - } -}} - -func TestMessengerWriteProtoMsg(t *testing.T) { - handlers := Handlers{"a": discardProto} - net, peer, mess := testMessenger(handlers) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - mess.setRemoteProtocols([]string{"a"}) - - // test write errors - if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil { - t.Errorf("expected error for unknown protocol, got nil") - } - if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil { - t.Errorf("expected error for out-of-range msg code, got nil") - } else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode { - t.Errorf("wrong error for out-of-range msg code, got %#v") - } - - // test succcessful write - read, readerr := make(chan Msg), make(chan error) - go func() { - if msg, err := readMsg(bufr); err != nil { - readerr <- err - } else { - read <- msg - } - }() - if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil { - t.Errorf("expect no error for known protocol: %v", err) - } - select { - case msg := <-read: - if msg.Code != 16 { - t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) - } - msg.Discard() - case err := <-readerr: - t.Errorf("read error: %v", err) - } -} - -func TestPulse(t *testing.T) { - net, peer, _ := testMessenger(nil) - defer peer.Stop() - bufr := bufio.NewReader(net) - if err := performTestHandshake(bufr, net); err != nil { - t.Fatalf("handshake failed: %v", err) - } - - before := time.Now() - msg, err := readMsg(bufr) - if err != nil { - t.Fatalf("read error: %v", err) - } - after := time.Now() - if msg.Code != pingMsg { - t.Errorf("expected ping message, got %d", msg.Code) - } - if d := after.Sub(before); d < pingTimeout { - t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) - } -} diff --git a/p2p/natpmp.go b/p2p/natpmp.go index ff966d070..6714678c4 100644 --- a/p2p/natpmp.go +++ b/p2p/natpmp.go @@ -3,6 +3,7 @@ package p2p import ( "fmt" "net" + "time" natpmp "github.com/jackpal/go-nat-pmp" ) @@ -13,38 +14,37 @@ import ( // + Register for changes to the external address. // + Re-register port mapping when router reboots. // + A mechanism for keeping a port mapping registered. +// + Discover gateway address automatically. type natPMPClient struct { client *natpmp.Client } -func NewNatPMP(gateway net.IP) (nat NAT) { +// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway +// address should be the IP of your router. +func PMP(gateway net.IP) (nat NAT) { return &natPMPClient{natpmp.NewClient(gateway)} } -func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { - response, err := n.client.GetExternalAddress() - if err != nil { - return - } - ip := response.ExternalIPAddress - addr = net.IPv4(ip[0], ip[1], ip[2], ip[3]) - return +func (*natPMPClient) String() string { + return "NAT-PMP" } -func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, - description string, timeout int) (mappedExternalPort int, err error) { - if timeout <= 0 { - err = fmt.Errorf("timeout must not be <= 0") - return +func (n *natPMPClient) GetExternalAddress() (net.IP, error) { + response, err := n.client.GetExternalAddress() + if err != nil { + return nil, err + } + return response.ExternalIPAddress[:], nil +} + +func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { + if lifetime <= 0 { + return fmt.Errorf("lifetime must not be <= 0") } // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. - response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) - if err != nil { - return - } - mappedExternalPort = int(response.MappedExternalPort) - return + _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second)) + return err } func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { diff --git a/p2p/natupnp.go b/p2p/natupnp.go index fa9798d4d..2e0d8ce8d 100644 --- a/p2p/natupnp.go +++ b/p2p/natupnp.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/xml" "errors" + "fmt" "net" "net/http" "os" @@ -15,28 +16,46 @@ import ( "time" ) +const ( + upnpDiscoverAttempts = 3 + upnpDiscoverTimeout = 5 * time.Second +) + +// UPNP returns a NAT port mapper that uses UPnP. It will attempt to +// discover the address of your router using UDP broadcasts. +func UPNP() NAT { + return &upnpNAT{} +} + type upnpNAT struct { serviceURL string ourIP string } -func upnpDiscover(attempts int) (nat NAT, err error) { +func (n *upnpNAT) String() string { + return "UPNP" +} + +func (n *upnpNAT) discover() error { + if n.serviceURL != "" { + // already discovered + return nil + } + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") if err != nil { - return + return err } + // TODO: try on all network interfaces simultaneously. + // Broadcasting on 0.0.0.0 could select a random interface + // to send on (platform specific). conn, err := net.ListenPacket("udp4", ":0") if err != nil { - return - } - socket := conn.(*net.UDPConn) - defer socket.Close() - - err = socket.SetDeadline(time.Now().Add(10 * time.Second)) - if err != nil { - return + return err } + defer conn.Close() + conn.SetDeadline(time.Now().Add(10 * time.Second)) st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" buf := bytes.NewBufferString( "M-SEARCH * HTTP/1.1\r\n" + @@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) { "MX: 2\r\n\r\n") message := buf.Bytes() answerBytes := make([]byte, 1024) - for i := 0; i < attempts; i++ { - _, err = socket.WriteToUDP(message, ssdp) + for i := 0; i < upnpDiscoverAttempts; i++ { + _, err = conn.WriteTo(message, ssdp) if err != nil { - return + return err } - var n int - n, _, err = socket.ReadFromUDP(answerBytes) + nn, _, err := conn.ReadFrom(answerBytes) if err != nil { continue - // socket.Close() - // return } - answer := string(answerBytes[0:n]) + answer := string(answerBytes[0:nn]) if strings.Index(answer, "\r\n"+st) < 0 { continue } @@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) { var serviceURL string serviceURL, err = getServiceURL(locURL) if err != nil { - return + return err } var ourIP string ourIP, err = getOurIP() if err != nil { - return + return err } - nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} + n.serviceURL = serviceURL + n.ourIP = ourIP + return nil + } + return errors.New("UPnP port discovery failed.") +} + +func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { + if err := n.discover(); err != nil { + return nil, err + } + info, err := n.getStatusInfo() + return net.ParseIP(info.externalIpAddress), err +} + +func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error { + if err := n.discover(); err != nil { + return err + } + + // A single concatenation would break ARM compilation. + message := "\r\n" + + "" + strconv.Itoa(extport) + message += "" + protocol + "" + message += "" + strconv.Itoa(extport) + "" + + "" + n.ourIP + "" + + "1" + message += description + + "" + fmt.Sprint(lifetime/time.Second) + + "" + + // TODO: check response to see if the port was forwarded + _, err := soapRequest(n.serviceURL, "AddPortMapping", message) + return err +} + +func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error { + if err := n.discover(); err != nil { + return err + } + + message := "\r\n" + + "" + strconv.Itoa(externalPort) + + "" + protocol + "" + + "" + + // TODO: check response to see if the port was deleted + _, err := soapRequest(n.serviceURL, "DeletePortMapping", message) + return err +} + +type statusInfo struct { + externalIpAddress string +} + +func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { + message := "\r\n" + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) + if err != nil { return } - err = errors.New("UPnP port discovery failed.") + + // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... + + response.Body.Close() return } @@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) { } return } - -type statusInfo struct { - externalIpAddress string -} - -func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { - - message := "\r\n" + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) - if err != nil { - return - } - - // TODO: Write a soap reply parser. It has to eat the Body and envelope tags... - - response.Body.Close() - return -} - -func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { - info, err := n.getStatusInfo() - if err != nil { - return - } - addr = net.ParseIP(info.externalIpAddress) - return -} - -func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { - // A single concatenation would break ARM compilation. - message := "\r\n" + - "" + strconv.Itoa(externalPort) - message += "" + protocol + "" - message += "" + strconv.Itoa(internalPort) + "" + - "" + n.ourIP + "" + - "1" - message += description + - "" + strconv.Itoa(timeout) + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "AddPortMapping", message) - if err != nil { - return - } - - // TODO: check response to see if the port was forwarded - // log.Println(message, response) - mappedExternalPort = externalPort - _ = response - return -} - -func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { - - message := "\r\n" + - "" + strconv.Itoa(externalPort) + - "" + protocol + "" + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "DeletePortMapping", message) - if err != nil { - return - } - - // TODO: check response to see if the port was deleted - // log.Println(message, response) - _ = response - return -} diff --git a/p2p/network.go b/p2p/network.go deleted file mode 100644 index 820cef1a9..000000000 --- a/p2p/network.go +++ /dev/null @@ -1,196 +0,0 @@ -package p2p - -import ( - "fmt" - "math/rand" - "net" - "strconv" - "time" -) - -const ( - DialerTimeout = 180 //seconds - KeepAlivePeriod = 60 //minutes - portMappingUpdateInterval = 900 // seconds = 15 mins - upnpDiscoverAttempts = 3 -) - -// Dialer is not an interface in net, so we define one -// *net.Dialer conforms to this -type Dialer interface { - Dial(network, address string) (net.Conn, error) -} - -type Network interface { - Start() error - Listener(net.Addr) (net.Listener, error) - Dialer(net.Addr) (Dialer, error) - NewAddr(string, int) (addr net.Addr, err error) - ParseAddr(string) (addr net.Addr, err error) -} - -type NAT interface { - GetExternalAddress() (addr net.IP, err error) - AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) - DeletePortMapping(protocol string, externalPort, internalPort int) (err error) -} - -type TCPNetwork struct { - nat NAT - natType NATType - quit chan chan bool - ports chan string -} - -type NATType int - -const ( - NONE = iota - UPNP - PMP -) - -const ( - portMappingTimeout = 1200 // 20 mins -) - -func NewTCPNetwork(natType NATType) (net *TCPNetwork) { - return &TCPNetwork{ - natType: natType, - ports: make(chan string), - } -} - -func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { - return &net.Dialer{ - Timeout: DialerTimeout * time.Second, - // KeepAlive: KeepAlivePeriod * time.Minute, - LocalAddr: addr, - }, nil -} - -func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { - if self.natType == UPNP { - _, port, _ := net.SplitHostPort(addr.String()) - if self.quit == nil { - self.quit = make(chan chan bool) - go self.updatePortMappings() - } - self.ports <- port - } - return net.Listen(addr.Network(), addr.String()) -} - -func (self *TCPNetwork) Start() (err error) { - switch self.natType { - case NONE: - case UPNP: - nat, uerr := upnpDiscover(upnpDiscoverAttempts) - if uerr != nil { - err = fmt.Errorf("UPNP failed: ", uerr) - } else { - self.nat = nat - } - case PMP: - err = fmt.Errorf("PMP not implemented") - default: - err = fmt.Errorf("Invalid NAT type: %v", self.natType) - } - return -} - -func (self *TCPNetwork) Stop() { - q := make(chan bool) - self.quit <- q - <-q -} - -func (self *TCPNetwork) addPortMapping(lport int) (err error) { - _, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) - if err != nil { - logger.Errorf("unable to add port mapping on %v: %v", lport, err) - } else { - logger.Debugf("succesfully added port mapping on %v", lport) - } - return -} - -func (self *TCPNetwork) updatePortMappings() { - timer := time.NewTimer(portMappingUpdateInterval * time.Second) - lports := []int{} -out: - for { - select { - case port := <-self.ports: - int64lport, _ := strconv.ParseInt(port, 10, 16) - lport := int(int64lport) - if err := self.addPortMapping(lport); err != nil { - lports = append(lports, lport) - } - case <-timer.C: - for lport := range lports { - if err := self.addPortMapping(lport); err != nil { - } - } - case errc := <-self.quit: - errc <- true - break out - } - } - - timer.Stop() - for lport := range lports { - if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { - logger.Debugf("unable to remove port mapping on %v: %v", lport, err) - } else { - logger.Debugf("succesfully removed port mapping on %v", lport) - } - } -} - -func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { - ip, err := self.lookupIP(host) - if err == nil { - return &net.TCPAddr{ - IP: ip, - Port: port, - }, nil - } - return nil, err -} - -func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { - host, port, err := net.SplitHostPort(address) - if err == nil { - iport, _ := strconv.Atoi(port) - addr, e := self.NewAddr(host, iport) - return addr, e - } - return nil, err -} - -func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { - if ip = net.ParseIP(host); ip != nil { - return - } - - var ips []net.IP - ips, err = net.LookupIP(host) - if err != nil { - logger.Warnln(err) - return - } - if len(ips) == 0 { - err = fmt.Errorf("No IP addresses available for %v", host) - logger.Warnln(err) - return - } - if len(ips) > 1 { - // Pick a random IP address, simulating round-robin DNS. - rand.Seed(time.Now().UTC().UnixNano()) - ip = ips[rand.Intn(len(ips))] - } else { - ip = ips[0] - } - return -} diff --git a/p2p/peer.go b/p2p/peer.go index 34b6152a3..238d3d9c9 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,66 +1,454 @@ package p2p import ( + "bufio" + "bytes" "fmt" + "io" + "io/ioutil" "net" - "strconv" + "sort" + "sync" + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/logger" ) -type Peer struct { - Inbound bool // inbound (via listener) or outbound (via dialout) - Address net.Addr - Host []byte - Port uint16 - Pubkey []byte - Id string - Caps []string - peerErrorChan chan error - messenger *messenger - peerErrorHandler *PeerErrorHandler - server *Server +// peerAddr is the structure of a peer list element. +// It is also a valid net.Addr. +type peerAddr struct { + IP net.IP + Port uint64 + Pubkey []byte // optional } -func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { - peerErrorChan := NewPeerErrorChannel() - host, port, _ := net.SplitHostPort(address.String()) - intport, _ := strconv.Atoi(port) - peer := &Peer{ - Inbound: inbound, - Address: address, - Port: uint16(intport), - Host: net.ParseIP(host), - peerErrorChan: peerErrorChan, - server: server, +func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { + n := addr.Network() + if n != "tcp" && n != "tcp4" && n != "tcp6" { + // for testing with non-TCP + return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} } - peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers()) - peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan) + ta := addr.(*net.TCPAddr) + return &peerAddr{ta.IP, uint64(ta.Port), pubkey} +} + +func (d peerAddr) Network() string { + if d.IP.To4() != nil { + return "tcp4" + } else { + return "tcp6" + } +} + +func (d peerAddr) String() string { + return fmt.Sprintf("%v:%d", d.IP, d.Port) +} + +func (d peerAddr) RlpData() interface{} { + return []interface{}{d.IP, d.Port, d.Pubkey} +} + +// Peer represents a remote peer. +type Peer struct { + // Peers have all the log methods. + // Use them to display messages related to the peer. + *logger.Logger + + infolock sync.Mutex + identity ClientIdentity + caps []Cap + listenAddr *peerAddr // what remote peer is listening on + dialAddr *peerAddr // non-nil if dialing + + // The mutex protects the connection + // so only one protocol can write at a time. + writeMu sync.Mutex + conn net.Conn + bufconn *bufio.ReadWriter + + // These fields maintain the running protocols. + protocols []Protocol + runBaseProtocol bool // for testing + + runlock sync.RWMutex // protects running + running map[string]*proto + + protoWG sync.WaitGroup + protoErr chan error + closed chan struct{} + disc chan DiscReason + + activity event.TypeMux // for activity events + + slot int // index into Server peer list + + // These fields are kept so base protocol can access them. + // TODO: this should be one or more interfaces + ourID ClientIdentity // client id of the Server + ourListenAddr *peerAddr // listen addr of Server, nil if not listening + newPeerAddr chan<- *peerAddr // tell server about received peers + otherPeers func() []*Peer // should return the list of all peers + pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey +} + +// NewPeer returns a peer for testing purposes. +func NewPeer(id ClientIdentity, caps []Cap) *Peer { + conn, _ := net.Pipe() + peer := newPeer(conn, nil, nil) + peer.setHandshakeInfo(id, nil, caps) return peer } -func (self *Peer) String() string { - var kind string - if self.Inbound { - kind = "inbound" - } else { +func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + p := newPeer(conn, server.Protocols, dialAddr) + p.ourID = server.Identity + p.newPeerAddr = server.peerConnect + p.otherPeers = server.Peers + p.pubkeyHook = server.verifyPeer + p.runBaseProtocol = true + + // laddr can be updated concurrently by NAT traversal. + // newServerPeer must be called with the server lock held. + if server.laddr != nil { + p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) + } + return p +} + +func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { + p := &Peer{ + Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), + conn: conn, + dialAddr: dialAddr, + bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + protocols: protocols, + running: make(map[string]*proto), + disc: make(chan DiscReason), + protoErr: make(chan error), + closed: make(chan struct{}), + } + return p +} + +// Identity returns the client identity of the remote peer. The +// identity can be nil if the peer has not yet completed the +// handshake. +func (p *Peer) Identity() ClientIdentity { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.identity +} + +// Caps returns the capabilities (supported subprotocols) of the remote peer. +func (p *Peer) Caps() []Cap { + p.infolock.Lock() + defer p.infolock.Unlock() + return p.caps +} + +func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { + p.infolock.Lock() + p.identity = id + p.listenAddr = laddr + p.caps = caps + p.infolock.Unlock() +} + +// RemoteAddr returns the remote address of the network connection. +func (p *Peer) RemoteAddr() net.Addr { + return p.conn.RemoteAddr() +} + +// LocalAddr returns the local address of the network connection. +func (p *Peer) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// Disconnect terminates the peer connection with the given reason. +// It returns immediately and does not wait until the connection is closed. +func (p *Peer) Disconnect(reason DiscReason) { + select { + case p.disc <- reason: + case <-p.closed: + } +} + +// String implements fmt.Stringer. +func (p *Peer) String() string { + kind := "inbound" + p.infolock.Lock() + if p.dialAddr != nil { kind = "outbound" } - return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) + p.infolock.Unlock() + return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) } -func (self *Peer) Write(protocol string, msg Msg) error { - return self.messenger.writeProtoMsg(protocol, msg) +const ( + // maximum amount of time allowed for reading a message + msgReadTimeout = 5 * time.Second + // maximum amount of time allowed for writing a message + msgWriteTimeout = 5 * time.Second + // messages smaller than this many bytes will be read at + // once before passing them to a protocol. + wholePayloadSize = 64 * 1024 +) + +var ( + inactivityTimeout = 2 * time.Second + disconnectGracePeriod = 2 * time.Second +) + +func (p *Peer) loop() (reason DiscReason, err error) { + defer p.activity.Stop() + defer p.closeProtocols() + defer close(p.closed) + defer p.conn.Close() + + // read loop + readMsg := make(chan Msg) + readErr := make(chan error) + readNext := make(chan bool, 1) + protoDone := make(chan struct{}, 1) + go p.readLoop(readMsg, readErr, readNext) + readNext <- true + + if p.runBaseProtocol { + p.startBaseProtocol() + } + +loop: + for { + select { + case msg := <-readMsg: + // a new message has arrived. + var wait bool + if wait, err = p.dispatch(msg, protoDone); err != nil { + p.Errorf("msg dispatch error: %v\n", err) + reason = discReasonForError(err) + break loop + } + if !wait { + // Msg has already been read completely, continue with next message. + readNext <- true + } + p.activity.Post(time.Now()) + case <-protoDone: + // protocol has consumed the message payload, + // we can continue reading from the socket. + readNext <- true + + case err := <-readErr: + // read failed. there is no need to run the + // polite disconnect sequence because the connection + // is probably dead anyway. + // TODO: handle write errors as well + return DiscNetworkError, err + case err = <-p.protoErr: + reason = discReasonForError(err) + break loop + case reason = <-p.disc: + break loop + } + } + + // wait for read loop to return. + close(readNext) + <-readErr + // tell the remote end to disconnect + done := make(chan struct{}) + go func() { + p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) + p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) + io.Copy(ioutil.Discard, p.conn) + close(done) + }() + select { + case <-done: + case <-time.After(disconnectGracePeriod): + } + return reason, err } -func (self *Peer) Start() { - self.peerErrorHandler.Start() - self.messenger.Start() +func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { + for _ = range unblock { + p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) + if msg, err := readMsg(p.bufconn); err != nil { + errc <- err + } else { + msgc <- msg + } + } + close(errc) } -func (self *Peer) Stop() { - self.peerErrorHandler.Stop() - self.messenger.Stop() +func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { + proto, err := p.getProto(msg.Code) + if err != nil { + return false, err + } + if msg.Size <= wholePayloadSize { + // optimization: msg is small enough, read all + // of it and move on to the next message + buf, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return false, err + } + msg.Payload = bytes.NewReader(buf) + proto.in <- msg + } else { + wait = true + pr := &eofSignal{msg.Payload, protoDone} + msg.Payload = pr + proto.in <- msg + } + return wait, nil } -func (p *Peer) Encode() []interface{} { - return []interface{}{p.Host, p.Port, p.Pubkey} +func (p *Peer) startBaseProtocol() { + p.runlock.Lock() + defer p.runlock.Unlock() + p.running[""] = p.startProto(0, Protocol{ + Length: baseProtocolLength, + Run: runBaseProtocol, + }) +} + +// startProtocols starts matching named subprotocols. +func (p *Peer) startSubprotocols(caps []Cap) { + sort.Sort(capsByName(caps)) + + p.runlock.Lock() + defer p.runlock.Unlock() + offset := baseProtocolLength +outer: + for _, cap := range caps { + for _, proto := range p.protocols { + if proto.Name == cap.Name && + proto.Version == cap.Version && + p.running[cap.Name] == nil { + p.running[cap.Name] = p.startProto(offset, proto) + offset += proto.Length + continue outer + } + } + } +} + +func (p *Peer) startProto(offset uint64, impl Protocol) *proto { + rw := &proto{ + in: make(chan Msg), + offset: offset, + maxcode: impl.Length, + peer: p, + } + p.protoWG.Add(1) + go func() { + err := impl.Run(p, rw) + if err == nil { + p.Infof("protocol %q returned", impl.Name) + err = newPeerError(errMisc, "protocol returned") + } else { + p.Errorf("protocol %q error: %v\n", impl.Name, err) + } + select { + case p.protoErr <- err: + case <-p.closed: + } + p.protoWG.Done() + }() + return rw +} + +// getProto finds the protocol responsible for handling +// the given message code. +func (p *Peer) getProto(code uint64) (*proto, error) { + p.runlock.RLock() + defer p.runlock.RUnlock() + for _, proto := range p.running { + if code >= proto.offset && code < proto.offset+proto.maxcode { + return proto, nil + } + } + return nil, newPeerError(errInvalidMsgCode, "%d", code) +} + +func (p *Peer) closeProtocols() { + p.runlock.RLock() + for _, p := range p.running { + close(p.in) + } + p.runlock.RUnlock() + p.protoWG.Wait() +} + +// writeProtoMsg sends the given message on behalf of the given named protocol. +func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { + p.runlock.RLock() + proto, ok := p.running[protoName] + p.runlock.RUnlock() + if !ok { + return fmt.Errorf("protocol %s not handled by peer", protoName) + } + if msg.Code >= proto.maxcode { + return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) + } + msg.Code += proto.offset + return p.writeMsg(msg, msgWriteTimeout) +} + +// writeMsg writes a message to the connection. +func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { + p.writeMu.Lock() + defer p.writeMu.Unlock() + p.conn.SetWriteDeadline(time.Now().Add(timeout)) + if err := writeMsg(p.bufconn, msg); err != nil { + return newPeerError(errWrite, "%v", err) + } + return p.bufconn.Flush() +} + +type proto struct { + name string + in chan Msg + maxcode, offset uint64 + peer *Peer +} + +func (rw *proto) WriteMsg(msg Msg) error { + if msg.Code >= rw.maxcode { + return newPeerError(errInvalidMsgCode, "not handled") + } + msg.Code += rw.offset + return rw.peer.writeMsg(msg, msgWriteTimeout) +} + +func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { + return rw.WriteMsg(NewMsg(code, data)) +} + +func (rw *proto) ReadMsg() (Msg, error) { + msg, ok := <-rw.in + if !ok { + return msg, io.EOF + } + msg.Code -= rw.offset + return msg, nil +} + +// eofSignal wraps a reader with eof signaling. +// the eof channel is closed when the wrapped reader +// reaches EOF. +type eofSignal struct { + wrapped io.Reader + eof chan<- struct{} +} + +func (r *eofSignal) Read(buf []byte) (int, error) { + n, err := r.wrapped.Read(buf) + if err != nil { + r.eof <- struct{}{} // tell Peer that msg has been consumed + } + return n, err } diff --git a/p2p/peer_error.go b/p2p/peer_error.go index f3ef98d98..88b870fbd 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -4,71 +4,121 @@ import ( "fmt" ) -type ErrorCode int - -const errorChanCapacity = 10 - const ( - PacketTooLong = iota - PayloadTooShort - MagicTokenMismatch - ReadError - WriteError - MiscError - InvalidMsgCode - InvalidMsg - P2PVersionMismatch - PubkeyMissing - PubkeyInvalid - PubkeyForbidden - ProtocolBreach - PortMismatch - PingTimeout - InvalidGenesis - InvalidNetworkId - InvalidProtocolVersion + errMagicTokenMismatch = iota + errRead + errWrite + errMisc + errInvalidMsgCode + errInvalidMsg + errP2PVersionMismatch + errPubkeyMissing + errPubkeyInvalid + errPubkeyForbidden + errProtocolBreach + errPingTimeout + errInvalidNetworkId + errInvalidProtocolVersion ) -var errorToString = map[ErrorCode]string{ - PacketTooLong: "Packet too long", - PayloadTooShort: "Payload too short", - MagicTokenMismatch: "Magic token mismatch", - ReadError: "Read error", - WriteError: "Write error", - MiscError: "Misc error", - InvalidMsgCode: "Invalid message code", - InvalidMsg: "Invalid message", - P2PVersionMismatch: "P2P Version Mismatch", - PubkeyMissing: "Public key missing", - PubkeyInvalid: "Public key invalid", - PubkeyForbidden: "Public key forbidden", - ProtocolBreach: "Protocol Breach", - PortMismatch: "Port mismatch", - PingTimeout: "Ping timeout", - InvalidGenesis: "Invalid genesis block", - InvalidNetworkId: "Invalid network id", - InvalidProtocolVersion: "Invalid protocol version", +var errorToString = map[int]string{ + errMagicTokenMismatch: "Magic token mismatch", + errRead: "Read error", + errWrite: "Write error", + errMisc: "Misc error", + errInvalidMsgCode: "Invalid message code", + errInvalidMsg: "Invalid message", + errP2PVersionMismatch: "P2P Version Mismatch", + errPubkeyMissing: "Public key missing", + errPubkeyInvalid: "Public key invalid", + errPubkeyForbidden: "Public key forbidden", + errProtocolBreach: "Protocol Breach", + errPingTimeout: "Ping timeout", + errInvalidNetworkId: "Invalid network id", + errInvalidProtocolVersion: "Invalid protocol version", } -type PeerError struct { - Code ErrorCode +type peerError struct { + Code int message string } -func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { +func newPeerError(code int, format string, v ...interface{}) *peerError { desc, ok := errorToString[code] if !ok { panic("invalid error code") } - format = desc + ": " + format - message := fmt.Sprintf(format, v...) - return &PeerError{code, message} + err := &peerError{code, desc} + if format != "" { + err.message += ": " + fmt.Sprintf(format, v...) + } + return err } -func (self *PeerError) Error() string { +func (self *peerError) Error() string { return self.message } -func NewPeerErrorChannel() chan error { - return make(chan error, errorChanCapacity) +type DiscReason byte + +const ( + DiscRequested DiscReason = 0x00 + DiscNetworkError = 0x01 + DiscProtocolError = 0x02 + DiscUselessPeer = 0x03 + DiscTooManyPeers = 0x04 + DiscAlreadyConnected = 0x05 + DiscIncompatibleVersion = 0x06 + DiscInvalidIdentity = 0x07 + DiscQuitting = 0x08 + DiscUnexpectedIdentity = 0x09 + DiscSelf = 0x0a + DiscReadTimeout = 0x0b + DiscSubprotocolError = 0x10 +) + +var discReasonToString = [DiscSubprotocolError + 1]string{ + DiscRequested: "Disconnect requested", + DiscNetworkError: "Network error", + DiscProtocolError: "Breach of protocol", + DiscUselessPeer: "Useless peer", + DiscTooManyPeers: "Too many peers", + DiscAlreadyConnected: "Already connected", + DiscIncompatibleVersion: "Incompatible P2P protocol version", + DiscInvalidIdentity: "Invalid node identity", + DiscQuitting: "Client quitting", + DiscUnexpectedIdentity: "Unexpected identity", + DiscSelf: "Connected to self", + DiscReadTimeout: "Read timeout", + DiscSubprotocolError: "Subprotocol error", +} + +func (d DiscReason) String() string { + if len(discReasonToString) < int(d) { + return fmt.Sprintf("Unknown Reason(%d)", d) + } + return discReasonToString[d] +} + +func discReasonForError(err error) DiscReason { + peerError, ok := err.(*peerError) + if !ok { + return DiscSubprotocolError + } + switch peerError.Code { + case errP2PVersionMismatch: + return DiscIncompatibleVersion + case errPubkeyMissing, errPubkeyInvalid: + return DiscInvalidIdentity + case errPubkeyForbidden: + return DiscUselessPeer + case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: + return DiscProtocolError + case errPingTimeout: + return DiscReadTimeout + case errRead, errWrite, errMisc: + return DiscNetworkError + default: + return DiscSubprotocolError + } } diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go deleted file mode 100644 index 47dcd14ff..000000000 --- a/p2p/peer_error_handler.go +++ /dev/null @@ -1,98 +0,0 @@ -package p2p - -import ( - "net" -) - -const ( - severityThreshold = 10 -) - -type DisconnectRequest struct { - addr net.Addr - reason DiscReason -} - -type PeerErrorHandler struct { - quit chan chan bool - address net.Addr - peerDisconnect chan DisconnectRequest - severity int - errc chan error -} - -func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler { - return &PeerErrorHandler{ - quit: make(chan chan bool), - address: address, - peerDisconnect: peerDisconnect, - errc: errc, - } -} - -func (self *PeerErrorHandler) Start() { - go self.listen() -} - -func (self *PeerErrorHandler) Stop() { - q := make(chan bool) - self.quit <- q - <-q -} - -func (self *PeerErrorHandler) listen() { - for { - select { - case err, ok := <-self.errc: - if ok { - logger.Debugf("error %v\n", err) - go self.handle(err) - } else { - return - } - case q := <-self.quit: - q <- true - return - } - } -} - -func (self *PeerErrorHandler) handle(err error) { - reason := DiscReason(' ') - peerError, ok := err.(*PeerError) - if !ok { - peerError = NewPeerError(MiscError, " %v", err) - } - switch peerError.Code { - case P2PVersionMismatch: - reason = DiscIncompatibleVersion - case PubkeyMissing, PubkeyInvalid: - reason = DiscInvalidIdentity - case PubkeyForbidden: - reason = DiscUselessPeer - case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach: - reason = DiscProtocolError - case PingTimeout: - reason = DiscReadTimeout - case ReadError, WriteError, MiscError: - reason = DiscNetworkError - case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: - reason = DiscSubprotocolError - default: - self.severity += self.getSeverity(peerError) - } - - if self.severity >= severityThreshold { - reason = DiscSubprotocolError - } - if reason != DiscReason(' ') { - self.peerDisconnect <- DisconnectRequest{ - addr: self.address, - reason: reason, - } - } -} - -func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { - return 1 -} diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go deleted file mode 100644 index b93252f6a..000000000 --- a/p2p/peer_error_handler_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package p2p - -import ( - // "fmt" - "net" - "testing" - "time" -) - -func TestPeerErrorHandler(t *testing.T) { - address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} - peerDisconnect := make(chan DisconnectRequest) - peerErrorChan := NewPeerErrorChannel() - peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan) - peh.Start() - defer peh.Stop() - for i := 0; i < 11; i++ { - select { - case <-peerDisconnect: - t.Errorf("expected no disconnect request") - default: - } - peerErrorChan <- NewPeerError(MiscError, "") - } - time.Sleep(1 * time.Millisecond) - select { - case request := <-peerDisconnect: - if request.addr.String() != address.String() { - t.Errorf("incorrect address %v != %v", request.addr, address) - } - default: - t.Errorf("expected disconnect request") - } -} diff --git a/p2p/peer_test.go b/p2p/peer_test.go index da62cc380..1afa0ab17 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -1,90 +1,222 @@ package p2p -// "net" +import ( + "bufio" + "net" + "reflect" + "testing" + "time" +) -// func TestPeer(t *testing.T) { -// handlers := make(Handlers) -// testProtocol := &TestProtocol{recv: make(chan testMsg)} -// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } -// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } -// addr := &TestAddr{"test:30"} -// conn := NewTestNetworkConnection(addr) -// _, server := SetupTestServer(handlers) -// server.Handshake() -// peer := NewPeer(conn, addr, true, server) -// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) -// peer.Start() -// defer peer.Stop() -// time.Sleep(2 * time.Millisecond) -// if len(conn.Out) != 1 { -// t.Errorf("handshake not sent") -// } else { -// out := conn.Out[0] -// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) -// if bytes.Compare(out, packet) != 0 { -// t.Errorf("incorrect handshake packet %v != %v", out, packet) -// } -// } +var discard = Protocol{ + Name: "discard", + Length: 1, + Run: func(p *Peer, rw MsgReadWriter) error { + for { + msg, err := rw.ReadMsg() + if err != nil { + return err + } + if err = msg.Discard(); err != nil { + return err + } + } + }, +} -// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) -// conn.In(0, packet) -// time.Sleep(10 * time.Millisecond) +func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { + conn1, conn2 := net.Pipe() + id := NewSimpleClientIdentity("test", "0", "0", "public key") + peer := newPeer(conn1, protos, nil) + peer.ourID = id + peer.pubkeyHook = func(*peerAddr) error { return nil } + errc := make(chan error, 1) + go func() { + _, err := peer.loop() + errc <- err + }() + return conn2, peer, errc +} -// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) -// if pro.state != handshakeReceived { -// t.Errorf("handshake not received") -// } -// if peer.Port != 30 { -// t.Errorf("port incorrectly set") -// } -// if peer.Id != "peer" { -// t.Errorf("id incorrectly set") -// } -// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { -// t.Errorf("pubkey incorrectly set") -// } -// fmt.Println(peer.Caps) -// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { -// t.Errorf("protocols incorrectly set") -// } +func TestPeerProtoReadMsg(t *testing.T) { + defer testlog(t).detach() -// msg := NewMsg(3) -// err := peer.Write("aaa", msg) -// if err != nil { -// t.Errorf("expect no error for known protocol: %v", err) -// } else { -// time.Sleep(1 * time.Millisecond) -// if len(conn.Out) != 2 { -// t.Errorf("msg not written") -// } else { -// out := conn.Out[1] -// packet := Packet(16, 3) -// if bytes.Compare(out, packet) != 0 { -// t.Errorf("incorrect packet %v != %v", out, packet) -// } -// } -// } + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 2 { + t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) + } + data, err := msg.Data() + if err != nil { + t.Errorf("data decoding error: %v", err) + } + expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} + if !reflect.DeepEqual(data.Slice(), expdata) { + t.Errorf("incorrect msg data %#v", data.Slice()) + } + close(done) + return nil + }, + } -// msg = NewMsg(2) -// err = peer.Write("ccc", msg) -// if err != nil { -// t.Errorf("expect no error for known protocol: %v", err) -// } else { -// time.Sleep(1 * time.Millisecond) -// if len(conn.Out) != 3 { -// t.Errorf("msg not written") -// } else { -// out := conn.Out[2] -// packet := Packet(21, 2) -// if bytes.Compare(out, packet) != 0 { -// t.Errorf("incorrect packet %v != %v", out, packet) -// } -// } -// } + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) -// err = peer.Write("bbb", msg) -// time.Sleep(1 * time.Millisecond) -// if err == nil { -// t.Errorf("expect error for unknown protocol") -// } -// } + writeMsg(net, NewMsg(18, 1, "000")) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") + } +} + +func TestPeerProtoReadLargeMsg(t *testing.T) { + defer testlog(t).detach() + + msgsize := uint32(10 * 1024 * 1024) + done := make(chan struct{}) + proto := Protocol{ + Name: "a", + Length: 5, + Run: func(peer *Peer, rw MsgReadWriter) error { + msg, err := rw.ReadMsg() + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Size != msgsize+4 { + t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) + } + msg.Discard() + close(done) + return nil + }, + } + + net, peer, errc := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + writeMsg(net, NewMsg(18, make([]byte, msgsize))) + select { + case <-done: + case err := <-errc: + t.Errorf("peer returned: %v", err) + case <-time.After(2 * time.Second): + t.Errorf("receive timeout") + } +} + +func TestPeerProtoEncodeMsg(t *testing.T) { + defer testlog(t).detach() + + proto := Protocol{ + Name: "a", + Length: 2, + Run: func(peer *Peer, rw MsgReadWriter) error { + if err := rw.EncodeMsg(2); err == nil { + t.Error("expected error for out-of-range msg code, got nil") + } + if err := rw.EncodeMsg(1); err != nil { + t.Errorf("write error: %v", err) + } + return nil + }, + } + net, peer, _ := testPeer([]Protocol{proto}) + defer net.Close() + peer.startSubprotocols([]Cap{proto.cap()}) + + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v", err) + } + if msg.Code != 17 { + t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) + } +} + +func TestPeerWrite(t *testing.T) { + defer testlog(t).detach() + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + // test write errors + if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { + t.Errorf("expected error for unknown protocol, got nil") + } + if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { + t.Errorf("expected error for out-of-range msg code, got nil") + } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { + t.Errorf("wrong error for out-of-range msg code, got %#v", err) + } + + // setup for reading the message on the other end + read := make(chan struct{}) + go func() { + bufr := bufio.NewReader(net) + msg, err := readMsg(bufr) + if err != nil { + t.Errorf("read error: %v", err) + } else if msg.Code != 16 { + t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) + } + msg.Discard() + close(read) + }() + + // test succcessful write + if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { + t.Errorf("expect no error for known protocol: %v", err) + } + select { + case <-read: + case err := <-peerErr: + t.Fatalf("peer stopped: %v", err) + } +} + +func TestPeerActivity(t *testing.T) { + // shorten inactivityTimeout while this test is running + oldT := inactivityTimeout + defer func() { inactivityTimeout = oldT }() + inactivityTimeout = 20 * time.Millisecond + + net, peer, peerErr := testPeer([]Protocol{discard}) + defer net.Close() + peer.startSubprotocols([]Cap{discard.cap()}) + + sub := peer.activity.Subscribe(time.Time{}) + defer sub.Unsubscribe() + + for i := 0; i < 6; i++ { + writeMsg(net, NewMsg(16)) + select { + case <-sub.Chan(): + case <-time.After(inactivityTimeout / 2): + t.Fatal("no event within ", inactivityTimeout/2) + case err := <-peerErr: + t.Fatal("peer error", err) + } + } + + select { + case <-time.After(inactivityTimeout * 2): + case <-sub.Chan(): + t.Fatal("got activity event while connection was inactive") + case err := <-peerErr: + t.Fatal("peer error", err) + } +} diff --git a/p2p/protocol.go b/p2p/protocol.go index d22ba70cb..169dcdb6e 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -3,249 +3,185 @@ package p2p import ( "bytes" "net" - "sort" "time" "github.com/ethereum/go-ethereum/ethutil" ) -// Protocol is implemented by P2P subprotocols. -type Protocol interface { - // Start is called when the protocol becomes active. - // It should read and write messages from rw. - // Messages must be fully consumed. +// Protocol represents a P2P subprotocol implementation. +type Protocol struct { + // Name should contain the official protocol name, + // often a three-letter word. + Name string + + // Version should contain the version number of the protocol. + Version uint + + // Length should contain the number of message codes used + // by the protocol. + Length uint64 + + // Run is called in a new groutine when the protocol has been + // negotiated with a peer. It should read and write messages from + // rw. The Payload for each message must be fully consumed. // - // The connection is closed when Start returns. It should return + // The peer connection is closed when Start returns. It should return // any protocol-level error (such as an I/O error) that is // encountered. - Start(peer *Peer, rw MsgReadWriter) error - - // Offset should return the number of message codes - // used by the protocol. - Offset() MsgCode + Run func(peer *Peer, rw MsgReadWriter) error } -type MsgReader interface { - ReadMsg() (Msg, error) +func (p Protocol) cap() Cap { + return Cap{p.Name, p.Version} } -type MsgWriter interface { - WriteMsg(Msg) error +const ( + baseProtocolVersion = 2 + baseProtocolLength = uint64(16) + baseProtocolMaxMsgSize = 10 * 1024 * 1024 +) + +const ( + // devp2p message codes + handshakeMsg = 0x00 + discMsg = 0x01 + pingMsg = 0x02 + pongMsg = 0x03 + getPeersMsg = 0x04 + peersMsg = 0x05 +) + +// handshake is the structure of a handshake list. +type handshake struct { + Version uint64 + ID string + Caps []Cap + ListenPort uint64 + NodeID []byte } -// MsgReadWriter is passed to protocols. Protocol implementations can -// use it to write messages back to a connected peer. -type MsgReadWriter interface { - MsgReader - MsgWriter +func (h *handshake) String() string { + return h.ID +} +func (h *handshake) Pubkey() []byte { + return h.NodeID } -type MsgHandler func(code MsgCode, data *ethutil.Value) error - -// MsgLoop reads messages off the given reader and -// calls the handler function for each decoded message until -// it returns an error or the peer connection is closed. -// -// If a message is larger than the given maximum size, RunProtocol -// returns an appropriate error.n -func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error { - for { - msg, err := r.ReadMsg() - if err != nil { - return err - } - if msg.Size > maxsize { - return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) - } - value, err := msg.Data() - if err != nil { - return err - } - if err := handler(msg.Code, value); err != nil { - return err - } - } +// Cap is the structure of a peer capability. +type Cap struct { + Name string + Version uint } -// the ÐΞVp2p base protocol +func (cap Cap) RlpData() interface{} { + return []interface{}{cap.Name, cap.Version} +} + +type capsByName []Cap + +func (cs capsByName) Len() int { return len(cs) } +func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } +func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } + type baseProtocol struct { rw MsgReadWriter peer *Peer } -type bpMsg struct { - code MsgCode - data *ethutil.Value -} +func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { + bp := &baseProtocol{rw, peer} -const ( - p2pVersion = 0 - pingTimeout = 2 * time.Second - pingGracePeriod = 2 * time.Second -) - -const ( - // message codes - handshakeMsg = iota - discMsg - pingMsg - pongMsg - getPeersMsg - peersMsg -) - -const ( - baseProtocolOffset MsgCode = 16 - baseProtocolMaxMsgSize = 500 * 1024 -) - -type DiscReason byte - -const ( - // Values are given explicitly instead of by iota because these values are - // defined by the wire protocol spec; it is easier for humans to ensure - // correctness when values are explicit. - DiscRequested = 0x00 - DiscNetworkError = 0x01 - DiscProtocolError = 0x02 - DiscUselessPeer = 0x03 - DiscTooManyPeers = 0x04 - DiscAlreadyConnected = 0x05 - DiscIncompatibleVersion = 0x06 - DiscInvalidIdentity = 0x07 - DiscQuitting = 0x08 - DiscUnexpectedIdentity = 0x09 - DiscSelf = 0x0a - DiscReadTimeout = 0x0b - DiscSubprotocolError = 0x10 -) - -var discReasonToString = [DiscSubprotocolError + 1]string{ - DiscRequested: "Disconnect requested", - DiscNetworkError: "Network error", - DiscProtocolError: "Breach of protocol", - DiscUselessPeer: "Useless peer", - DiscTooManyPeers: "Too many peers", - DiscAlreadyConnected: "Already connected", - DiscIncompatibleVersion: "Incompatible P2P protocol version", - DiscInvalidIdentity: "Invalid node identity", - DiscQuitting: "Client quitting", - DiscUnexpectedIdentity: "Unexpected identity", - DiscSelf: "Connected to self", - DiscReadTimeout: "Read timeout", - DiscSubprotocolError: "Subprotocol error", -} - -func (d DiscReason) String() string { - if len(discReasonToString) < int(d) { - return "Unknown" + // do handshake + if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { + return err } - return discReasonToString[d] -} - -func (bp *baseProtocol) Offset() MsgCode { - return baseProtocolOffset -} - -func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error { - bp.peer, bp.rw = peer, rw - - // Do the handshake. - // TODO: disconnect is valid before handshake, too. - rw.WriteMsg(bp.peer.server.handshakeMsg()) msg, err := rw.ReadMsg() if err != nil { return err } if msg.Code != handshakeMsg { - return NewPeerError(ProtocolBreach, " first message must be handshake") + return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) } data, err := msg.Data() if err != nil { - return NewPeerError(InvalidMsg, "%v", err) + return newPeerError(errInvalidMsg, "%v", err) } if err := bp.handleHandshake(data); err != nil { return err } - msgin := make(chan bpMsg) - done := make(chan error, 1) + // run main loop + quit := make(chan error, 1) go func() { - done <- MsgLoop(rw, baseProtocolMaxMsgSize, - func(code MsgCode, data *ethutil.Value) error { - msgin <- bpMsg{code, data} - return nil - }) + quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle) }() - return bp.loop(msgin, done) + return bp.loop(quit) } -func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error { - logger.Debugf("pingpong keepalive started at %v\n", time.Now()) - messenger := bp.rw.(*proto).messenger - pingTimer := time.NewTimer(pingTimeout) - pinged := true +var pingTimeout = 2 * time.Second - for { +func (bp *baseProtocol) loop(quit <-chan error) error { + ping := time.NewTimer(pingTimeout) + activity := bp.peer.activity.Subscribe(time.Time{}) + lastActive := time.Time{} + defer ping.Stop() + defer activity.Unsubscribe() + + getPeersTick := time.NewTicker(10 * time.Second) + defer getPeersTick.Stop() + err := bp.rw.EncodeMsg(getPeersMsg) + + for err == nil { select { - case msg := <-msgin: - if err := bp.handle(msg.code, msg.data); err != nil { - return err - } - case err := <-quit: + case err = <-quit: return err - case <-messenger.pulse: - pingTimer.Reset(pingTimeout) - pinged = false - case <-pingTimer.C: - if pinged { - return NewPeerError(PingTimeout, "") + case <-getPeersTick.C: + err = bp.rw.EncodeMsg(getPeersMsg) + case event := <-activity.Chan(): + ping.Reset(pingTimeout) + lastActive = event.(time.Time) + case t := <-ping.C: + if lastActive.Add(pingTimeout * 2).Before(t) { + err = newPeerError(errPingTimeout, "") + } else if lastActive.Add(pingTimeout).Before(t) { + err = bp.rw.EncodeMsg(pingMsg) } - logger.Debugf("pinging at %v\n", time.Now()) - if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil { - return NewPeerError(WriteError, "%v", err) - } - pinged = true - pingTimer.Reset(pingTimeout) } } + return err } -func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { +func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error { switch code { case handshakeMsg: - return NewPeerError(ProtocolBreach, " extra handshake received") + return newPeerError(errProtocolBreach, "extra handshake received") case discMsg: - logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint())) - bp.peer.server.PeerDisconnect() <- DisconnectRequest{ - addr: bp.peer.Address, - reason: DiscRequested, - } + bp.peer.Disconnect(DiscReason(data.Get(0).Uint())) + return nil case pingMsg: - return bp.rw.WriteMsg(NewMsg(pongMsg)) + return bp.rw.EncodeMsg(pongMsg) case pongMsg: - // reply for ping case getPeersMsg: - // Peer asked for list of connected peers. - peersRLP := bp.peer.server.encodedPeerList() - if peersRLP != nil { - msg := Msg{ - Code: peersMsg, - Size: uint32(len(peersRLP)), - Payload: bytes.NewReader(peersRLP), - } - return bp.rw.WriteMsg(msg) + peers := bp.peerList() + // this is dangerous. the spec says that we should _delay_ + // sending the response if no new information is available. + // this means that would need to send a response later when + // new peers become available. + // + // TODO: add event mechanism to notify baseProtocol for new peers + if len(peers) > 0 { + return bp.rw.EncodeMsg(peersMsg, peers) } case peersMsg: bp.handlePeers(data) default: - return NewPeerError(InvalidMsgCode, "unknown message code %v", code) + return newPeerError(errInvalidMsgCode, "unknown message code %v", code) } return nil } @@ -253,62 +189,102 @@ func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error { func (bp *baseProtocol) handlePeers(data *ethutil.Value) { it := data.NewIterator() for it.Next() { - ip := net.IP(it.Value().Get(0).Bytes()) - port := it.Value().Get(1).Uint() - address := &net.TCPAddr{IP: ip, Port: int(port)} - go bp.peer.server.PeerConnect(address) + addr := &peerAddr{ + IP: net.IP(it.Value().Get(0).Bytes()), + Port: it.Value().Get(1).Uint(), + Pubkey: it.Value().Get(2).Bytes(), + } + bp.peer.Debugf("received peer suggestion: %v", addr) + bp.peer.newPeerAddr <- addr } } func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { - var ( - remoteVersion = c.Get(0).Uint() - id = c.Get(1).Str() - caps = c.Get(2) - port = c.Get(3).Uint() - pubkey = c.Get(4).Bytes() - ) - // Check correctness of p2p protocol version - if remoteVersion != p2pVersion { - return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion) + hs := handshake{ + Version: c.Get(0).Uint(), + ID: c.Get(1).Str(), + Caps: nil, // decoded below + ListenPort: c.Get(3).Uint(), + NodeID: c.Get(4).Bytes(), } - - // Handle the pub key (validation, uniqueness) - if len(pubkey) == 0 { - return NewPeerError(PubkeyMissing, "not supplied in handshake.") + if hs.Version != baseProtocolVersion { + return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", + baseProtocolVersion, hs.Version) } - - if len(pubkey) != 64 { - return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) + if len(hs.NodeID) == 0 { + return newPeerError(errPubkeyMissing, "") } - - // self connect detection - if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { - return NewPeerError(PubkeyForbidden, "not allowed to connect to self") + if len(hs.NodeID) != 64 { + return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8) } - - // register pubkey on server. this also sets the pubkey on the peer (need lock) - if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil { - return NewPeerError(PubkeyForbidden, err.Error()) + if da := bp.peer.dialAddr; da != nil { + // verify that the peer we wanted to connect to + // actually holds the target public key. + if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) { + return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch") + } } - - // check port - if bp.peer.Inbound { - uint16port := uint16(port) - if bp.peer.Port > 0 && bp.peer.Port != uint16port { - return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port) - } else { - bp.peer.Port = uint16port + pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + if err := bp.peer.pubkeyHook(pa); err != nil { + return newPeerError(errPubkeyForbidden, "%v", err) + } + capsIt := c.Get(2).NewIterator() + for capsIt.Next() { + cap := capsIt.Value() + name := cap.Get(0).Str() + if name != "" { + hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())}) } } - capsIt := caps.NewIterator() - for capsIt.Next() { - cap := capsIt.Value().Str() - bp.peer.Caps = append(bp.peer.Caps, cap) + var addr *peerAddr + if hs.ListenPort != 0 { + addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) + addr.Port = hs.ListenPort } - sort.Strings(bp.peer.Caps) - bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps) - bp.peer.Id = id + bp.peer.setHandshakeInfo(&hs, addr, hs.Caps) + bp.peer.startSubprotocols(hs.Caps) return nil } + +func (bp *baseProtocol) handshakeMsg() Msg { + var ( + port uint64 + caps []interface{} + ) + if bp.peer.ourListenAddr != nil { + port = bp.peer.ourListenAddr.Port + } + for _, proto := range bp.peer.protocols { + caps = append(caps, proto.cap()) + } + return NewMsg(handshakeMsg, + baseProtocolVersion, + bp.peer.ourID.String(), + caps, + port, + bp.peer.ourID.Pubkey()[1:], + ) +} + +func (bp *baseProtocol) peerList() []ethutil.RlpEncodable { + peers := bp.peer.otherPeers() + ds := make([]ethutil.RlpEncodable, 0, len(peers)) + for _, p := range peers { + p.infolock.Lock() + addr := p.listenAddr + p.infolock.Unlock() + // filter out this peer and peers that are not listening or + // have not completed the handshake. + // TODO: track previously sent peers and exclude them as well. + if p == bp.peer || addr == nil { + continue + } + ds = append(ds, addr) + } + ourAddr := bp.peer.ourListenAddr + if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { + ds = append(ds, ourAddr) + } + return ds +} diff --git a/p2p/server.go b/p2p/server.go index 54d2cde30..8a6087566 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -2,21 +2,420 @@ package p2p import ( "bytes" + "errors" "fmt" "net" - "sort" - "strconv" "sync" "time" - logpkg "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger" ) const ( - outboundAddressPoolSize = 10 - disconnectGracePeriod = 2 + outboundAddressPoolSize = 500 + defaultDialTimeout = 10 * time.Second + portMappingUpdateInterval = 15 * time.Minute + portMappingTimeout = 20 * time.Minute ) +var srvlog = logger.NewLogger("P2P Server") + +// Server manages all peer connections. +// +// The fields of Server are used as configuration parameters. +// You should set them before starting the Server. Fields may not be +// modified while the server is running. +type Server struct { + // This field must be set to a valid client identity. + Identity ClientIdentity + + // MaxPeers is the maximum number of peers that can be + // connected. It must be greater than zero. + MaxPeers int + + // Protocols should contain the protocols supported + // by the server. Matching protocols are launched for + // each peer. + Protocols []Protocol + + // If Blacklist is set to a non-nil value, the given Blacklist + // is used to verify peer connections. + Blacklist Blacklist + + // If ListenAddr is set to a non-nil address, the server + // will listen for incoming connections. + // + // If the port is zero, the operating system will pick a port. The + // ListenAddr field will be updated with the actual address when + // the server is started. + ListenAddr string + + // If set to a non-nil value, the given NAT port mapper + // is used to make the listening port available to the + // Internet. + NAT NAT + + // If Dialer is set to a non-nil value, the given Dialer + // is used to dial outbound peer connections. + Dialer *net.Dialer + + // If NoDial is true, the server will not dial any peers. + NoDial bool + + // Hook for testing. This is useful because we can inhibit + // the whole protocol stack. + newPeerFunc peerFunc + + lock sync.RWMutex + running bool + listener net.Listener + laddr *net.TCPAddr // real listen addr + peers []*Peer + peerSlots chan int + peerCount int + + quit chan struct{} + wg sync.WaitGroup + peerConnect chan *peerAddr + peerDisconnect chan *Peer +} + +// NAT is implemented by NAT traversal methods. +type NAT interface { + GetExternalAddress() (net.IP, error) + AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error + DeletePortMapping(protocol string, extport, intport int) error + + // Should return name of the method. + String() string +} + +type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer + +// Peers returns all connected peers. +func (srv *Server) Peers() (peers []*Peer) { + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + peers = append(peers, peer) + } + } + return +} + +// PeerCount returns the number of connected peers. +func (srv *Server) PeerCount() int { + srv.lock.RLock() + defer srv.lock.RUnlock() + return srv.peerCount +} + +// SuggestPeer injects an address into the outbound address pool. +func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { + select { + case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: + default: // don't block + } +} + +// Broadcast sends an RLP-encoded message to all connected peers. +// This method is deprecated and will be removed later. +func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { + var payload []byte + if data != nil { + payload = encodePayload(data...) + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + var msg = Msg{Code: code} + if data != nil { + msg.Payload = bytes.NewReader(payload) + msg.Size = uint32(len(payload)) + } + peer.writeProtoMsg(protocol, msg) + } + } +} + +// Start starts running the server. +// Servers can be re-used and started again after stopping. +func (srv *Server) Start() (err error) { + srv.lock.Lock() + defer srv.lock.Unlock() + if srv.running { + return errors.New("server already running") + } + srvlog.Infoln("Starting Server") + + // initialize fields + if srv.Identity == nil { + return fmt.Errorf("Server.Identity must be set to a non-nil identity") + } + if srv.MaxPeers <= 0 { + return fmt.Errorf("Server.MaxPeers must be > 0") + } + srv.quit = make(chan struct{}) + srv.peers = make([]*Peer, srv.MaxPeers) + srv.peerSlots = make(chan int, srv.MaxPeers) + srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize) + srv.peerDisconnect = make(chan *Peer) + if srv.newPeerFunc == nil { + srv.newPeerFunc = newServerPeer + } + if srv.Blacklist == nil { + srv.Blacklist = NewBlacklist() + } + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} + } + + if srv.ListenAddr != "" { + if err := srv.startListening(); err != nil { + return err + } + } + if !srv.NoDial { + srv.wg.Add(1) + go srv.dialLoop() + } + if srv.NoDial && srv.ListenAddr == "" { + srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") + } + + // make all slots available + for i := range srv.peers { + srv.peerSlots <- i + } + // note: discLoop is not part of WaitGroup + go srv.discLoop() + srv.running = true + return nil +} + +func (srv *Server) startListening() error { + listener, err := net.Listen("tcp", srv.ListenAddr) + if err != nil { + return err + } + srv.ListenAddr = listener.Addr().String() + srv.laddr = listener.Addr().(*net.TCPAddr) + srv.listener = listener + srv.wg.Add(1) + go srv.listenLoop() + if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { + srv.wg.Add(1) + go srv.natLoop(srv.laddr.Port) + } + return nil +} + +// Stop terminates the server and all active peer connections. +// It blocks until all active connections have been closed. +func (srv *Server) Stop() { + srv.lock.Lock() + if !srv.running { + srv.lock.Unlock() + return + } + srv.running = false + srv.lock.Unlock() + + srvlog.Infoln("Stopping server") + if srv.listener != nil { + // this unblocks listener Accept + srv.listener.Close() + } + close(srv.quit) + for _, peer := range srv.Peers() { + peer.Disconnect(DiscQuitting) + } + srv.wg.Wait() + + // wait till they actually disconnect + // this is checked by claiming all peerSlots. + // slots become available as the peers disconnect. + for i := 0; i < cap(srv.peerSlots); i++ { + <-srv.peerSlots + } + // terminate discLoop + close(srv.peerDisconnect) +} + +func (srv *Server) discLoop() { + for peer := range srv.peerDisconnect { + // peer has just disconnected. free up its slot. + srvlog.Infof("%v is gone", peer) + srv.peerSlots <- peer.slot + srv.lock.Lock() + srv.peers[peer.slot] = nil + srv.lock.Unlock() + } +} + +// main loop for adding connections via listening +func (srv *Server) listenLoop() { + defer srv.wg.Done() + + srvlog.Infoln("Listening on", srv.listener.Addr()) + for { + select { + case slot := <-srv.peerSlots: + conn, err := srv.listener.Accept() + if err != nil { + srv.peerSlots <- slot + return + } + srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot) + srv.addPeer(conn, nil, slot) + case <-srv.quit: + return + } + } +} + +func (srv *Server) natLoop(port int) { + defer srv.wg.Done() + for { + srv.updatePortMapping(port) + select { + case <-time.After(portMappingUpdateInterval): + // one more round + case <-srv.quit: + srv.removePortMapping(port) + return + } + } +} + +func (srv *Server) updatePortMapping(port int) { + srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) + err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) + if err != nil { + srvlog.Errorln("Port mapping error:", err) + return + } + extip, err := srv.NAT.GetExternalAddress() + if err != nil { + srvlog.Errorln("Error getting external IP:", err) + return + } + srv.lock.Lock() + extaddr := *(srv.listener.Addr().(*net.TCPAddr)) + extaddr.IP = extip + srvlog.Infoln("Mapped port, external addr is", &extaddr) + srv.laddr = &extaddr + srv.lock.Unlock() +} + +func (srv *Server) removePortMapping(port int) { + srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) + srv.NAT.DeletePortMapping("tcp", port, port) +} + +func (srv *Server) dialLoop() { + defer srv.wg.Done() + var ( + suggest chan *peerAddr + slot *int + slots = srv.peerSlots + ) + for { + select { + case i := <-slots: + // we need a peer in slot i, slot reserved + slot = &i + // now we can watch for candidate peers in the next loop + suggest = srv.peerConnect + // do not consume more until candidate peer is found + slots = nil + + case desc := <-suggest: + // candidate peer found, will dial out asyncronously + // if connection fails slot will be released + go srv.dialPeer(desc, *slot) + // we can watch if more peers needed in the next loop + slots = srv.peerSlots + // until then we dont care about candidate peers + suggest = nil + + case <-srv.quit: + // give back the currently reserved slot + if slot != nil { + srv.peerSlots <- *slot + } + return + } + } +} + +// connect to peer via dial out +func (srv *Server) dialPeer(desc *peerAddr, slot int) { + srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) + conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) + if err != nil { + srvlog.Errorf("Dial error: %v", err) + srv.peerSlots <- slot + return + } + go srv.addPeer(conn, desc, slot) +} + +// creates the new peer object and inserts it into its slot +func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { + srv.lock.Lock() + defer srv.lock.Unlock() + if !srv.running { + conn.Close() + srv.peerSlots <- slot // release slot + return nil + } + peer := srv.newPeerFunc(srv, conn, desc) + peer.slot = slot + srv.peers[slot] = peer + srv.peerCount++ + go func() { peer.loop(); srv.peerDisconnect <- peer }() + return peer +} + +// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot +func (srv *Server) removePeer(peer *Peer) { + srv.lock.Lock() + defer srv.lock.Unlock() + srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot) + if srv.peers[peer.slot] != peer { + srvlog.Warnln("Invalid peer to remove:", peer) + return + } + // remove from list and index + srv.peerCount-- + srv.peers[peer.slot] = nil + // release slot to signal need for a new peer, last! + srv.peerSlots <- peer.slot +} + +func (srv *Server) verifyPeer(addr *peerAddr) error { + if srv.Blacklist.Exists(addr.Pubkey) { + return errors.New("blacklisted") + } + if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { + return newPeerError(errPubkeyForbidden, "not allowed to connect to srv") + } + srv.lock.RLock() + defer srv.lock.RUnlock() + for _, peer := range srv.peers { + if peer != nil { + id := peer.Identity() + if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) { + return errors.New("already connected") + } + } + } + return nil +} + type Blacklist interface { Get([]byte) (bool, error) Put([]byte) error @@ -66,423 +465,3 @@ func (self *BlacklistMap) Delete(pubkey []byte) error { delete(self.blacklist, string(pubkey)) return nil } - -type Server struct { - network Network - listening bool //needed? - dialing bool //needed? - closed bool - identity ClientIdentity - addr net.Addr - port uint16 - protocols []string - - quit chan chan bool - peersLock sync.RWMutex - - maxPeers int - peers []*Peer - peerSlots chan int - peersTable map[string]int - peerCount int - cachedEncodedPeers []byte - - peerConnect chan net.Addr - peerDisconnect chan DisconnectRequest - blacklist Blacklist - handlers Handlers -} - -var logger = logpkg.NewLogger("P2P") - -func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { - // get alphabetical list of protocol names from handlers map - protocols := []string{} - for protocol := range handlers { - protocols = append(protocols, protocol) - } - sort.Strings(protocols) - - _, port, _ := net.SplitHostPort(addr.String()) - intport, _ := strconv.Atoi(port) - - self := &Server{ - // NewSimpleClientIdentity(clientIdentifier, version, customIdentifier) - network: network, - identity: identity, - addr: addr, - port: uint16(intport), - protocols: protocols, - - quit: make(chan chan bool), - - maxPeers: maxPeers, - peers: make([]*Peer, maxPeers), - peerSlots: make(chan int, maxPeers), - peersTable: make(map[string]int), - - peerConnect: make(chan net.Addr, outboundAddressPoolSize), - peerDisconnect: make(chan DisconnectRequest), - blacklist: blacklist, - - handlers: handlers, - } - for i := 0; i < maxPeers; i++ { - self.peerSlots <- i // fill up with indexes - } - return self -} - -func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { - addr, err = self.network.NewAddr(host, port) - return -} - -func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { - addr, err = self.network.ParseAddr(address) - return -} - -func (self *Server) ClientIdentity() ClientIdentity { - return self.identity -} - -func (self *Server) Peers() (peers []*Peer) { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { - if peer != nil { - peers = append(peers, peer) - } - } - return -} - -func (self *Server) PeerCount() int { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - return self.peerCount -} - -func (self *Server) PeerConnect(addr net.Addr) { - // TODO: should buffer, filter and uniq - // send GetPeersMsg if not blocking - select { - case self.peerConnect <- addr: // not enough peers - self.Broadcast("", getPeersMsg) - default: // we dont care - } -} - -func (self *Server) PeerDisconnect() chan DisconnectRequest { - return self.peerDisconnect -} - -func (self *Server) Blacklist() Blacklist { - return self.blacklist -} - -func (self *Server) Handlers() Handlers { - return self.handlers -} - -func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) { - var payload []byte - if data != nil { - payload = encodePayload(data...) - } - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { - if peer != nil { - var msg = Msg{Code: code} - if data != nil { - msg.Payload = bytes.NewReader(payload) - msg.Size = uint32(len(payload)) - } - peer.messenger.writeProtoMsg(protocol, msg) - } - } -} - -// Start the server -func (self *Server) Start(listen bool, dial bool) { - self.network.Start() - if listen { - listener, err := self.network.Listener(self.addr) - if err != nil { - logger.Warnf("Error initializing listener: %v", err) - logger.Warnf("Connection listening disabled") - self.listening = false - } else { - self.listening = true - logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) - go self.inboundPeerHandler(listener) - } - } - if dial { - dialer, err := self.network.Dialer(self.addr) - if err != nil { - logger.Warnf("Error initializing dialer: %v", err) - logger.Warnf("Connection dialout disabled") - self.dialing = false - } else { - self.dialing = true - logger.Infoln("Dial peers watching outbound address pool") - go self.outboundPeerHandler(dialer) - } - } - logger.Infoln("server started") -} - -func (self *Server) Stop() { - logger.Infoln("server stopping...") - // // quit one loop if dialing - if self.dialing { - logger.Infoln("stop dialout...") - dialq := make(chan bool) - self.quit <- dialq - <-dialq - fmt.Println("quit another") - } - // quit the other loop if listening - if self.listening { - logger.Infoln("stop listening...") - listenq := make(chan bool) - self.quit <- listenq - <-listenq - fmt.Println("quit one") - } - - fmt.Println("quit waited") - - logger.Infoln("stopping peers...") - peers := []net.Addr{} - self.peersLock.RLock() - self.closed = true - for _, peer := range self.peers { - if peer != nil { - peers = append(peers, peer.Address) - } - } - self.peersLock.RUnlock() - for _, address := range peers { - go self.removePeer(DisconnectRequest{ - addr: address, - reason: DiscQuitting, - }) - } - // wait till they actually disconnect - // this is checked by draining the peerSlots (slots are released back if a peer is removed) - i := 0 - fmt.Println("draining peers") - -FOR: - for { - select { - case slot := <-self.peerSlots: - i++ - fmt.Printf("%v: found slot %v\n", i, slot) - if i == self.maxPeers { - break FOR - } - } - } - logger.Infoln("server stopped") -} - -// main loop for adding connections via listening -func (self *Server) inboundPeerHandler(listener net.Listener) { - for { - select { - case slot := <-self.peerSlots: - go self.connectInboundPeer(listener, slot) - case errc := <-self.quit: - listener.Close() - fmt.Println("quit listenloop") - errc <- true - return - } - } -} - -// main loop for adding outbound peers based on peerConnect address pool -// this same loop handles peer disconnect requests as well -func (self *Server) outboundPeerHandler(dialer Dialer) { - // addressChan initially set to nil (only watches peerConnect if we need more peers) - var addressChan chan net.Addr - slots := self.peerSlots - var slot *int - for { - select { - case i := <-slots: - // we need a peer in slot i, slot reserved - slot = &i - // now we can watch for candidate peers in the next loop - addressChan = self.peerConnect - // do not consume more until candidate peer is found - slots = nil - case address := <-addressChan: - // candidate peer found, will dial out asyncronously - // if connection fails slot will be released - go self.connectOutboundPeer(dialer, address, *slot) - // we can watch if more peers needed in the next loop - slots = self.peerSlots - // until then we dont care about candidate peers - addressChan = nil - case request := <-self.peerDisconnect: - go self.removePeer(request) - case errc := <-self.quit: - if addressChan != nil && slot != nil { - self.peerSlots <- *slot - } - fmt.Println("quit dialloop") - errc <- true - return - } - } -} - -// check if peer address already connected -func (self *Server) isConnected(address net.Addr) bool { - self.peersLock.RLock() - defer self.peersLock.RUnlock() - _, found := self.peersTable[address.String()] - return found -} - -// connect to peer via listener.Accept() -func (self *Server) connectInboundPeer(listener net.Listener, slot int) { - var address net.Addr - conn, err := listener.Accept() - if err != nil { - logger.Debugln(err) - self.peerSlots <- slot - return - } - address = conn.RemoteAddr() - // XXX: this won't work because the remote socket - // address does not identify the peer. we should - // probably get rid of this check and rely on public - // key detection in the base protocol. - if self.isConnected(address) { - conn.Close() - self.peerSlots <- slot - return - } - fmt.Printf("adding %v\n", address) - go self.addPeer(conn, address, true, slot) -} - -// connect to peer via dial out -func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { - if self.isConnected(address) { - return - } - conn, err := dialer.Dial(address.Network(), address.String()) - if err != nil { - self.peerSlots <- slot - return - } - go self.addPeer(conn, address, false, slot) -} - -// creates the new peer object and inserts it into its slot -func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer { - self.peersLock.Lock() - defer self.peersLock.Unlock() - if self.closed { - fmt.Println("oopsy, not no longer need peer") - conn.Close() //oopsy our bad - self.peerSlots <- slot // release slot - return nil - } - logger.Infoln("adding new peer", address) - peer := NewPeer(conn, address, inbound, self) - self.peers[slot] = peer - self.peersTable[address.String()] = slot - self.peerCount++ - self.cachedEncodedPeers = nil - fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) - peer.Start() - return peer -} - -// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot -func (self *Server) removePeer(request DisconnectRequest) { - self.peersLock.Lock() - - address := request.addr - slot := self.peersTable[address.String()] - peer := self.peers[slot] - fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot) - if peer == nil { - logger.Debugf("already removed peer on %v", address) - self.peersLock.Unlock() - return - } - // remove from list and index - self.peerCount-- - self.peers[slot] = nil - delete(self.peersTable, address.String()) - self.cachedEncodedPeers = nil - fmt.Printf("removed peer %v (slot %v)\n", peer, slot) - self.peersLock.Unlock() - - // sending disconnect message - disconnectMsg := NewMsg(discMsg, request.reason) - peer.Write("", disconnectMsg) - // be nice and wait - time.Sleep(disconnectGracePeriod * time.Second) - // switch off peer and close connections etc. - fmt.Println("stopping peer") - peer.Stop() - fmt.Println("stopped peer") - // release slot to signal need for a new peer, last! - self.peerSlots <- slot -} - -// encodedPeerList returns an RLP-encoded list of peers. -// the returned slice will be nil if there are no peers. -func (self *Server) encodedPeerList() []byte { - // TODO: memoize and reset when peers change - self.peersLock.RLock() - defer self.peersLock.RUnlock() - if self.cachedEncodedPeers == nil && self.peerCount > 0 { - var peerData []interface{} - for _, i := range self.peersTable { - peer := self.peers[i] - peerData = append(peerData, peer.Encode()) - } - self.cachedEncodedPeers = encodePayload(peerData) - } - return self.cachedEncodedPeers -} - -// fix handshake message to push to peers -func (self *Server) handshakeMsg() Msg { - return NewMsg(handshakeMsg, - p2pVersion, - []byte(self.identity.String()), - []interface{}{self.protocols}, - self.port, - self.identity.Pubkey()[1:], - ) -} - -func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { - // Check for blacklisting - if self.blacklist.Exists(pubkey) { - return fmt.Errorf("blacklisted") - } - - self.peersLock.RLock() - defer self.peersLock.RUnlock() - for _, peer := range self.peers { - if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { - return fmt.Errorf("already connected") - } - } - candidate.Pubkey = pubkey - return nil -} diff --git a/p2p/server_test.go b/p2p/server_test.go index a2594acba..5c0d08d39 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -1,289 +1,161 @@ package p2p import ( - "fmt" + "bytes" "io" "net" + "sync" "testing" "time" ) -type TestNetwork struct { - connections map[string]*TestNetworkConnection - dialer Dialer - maxinbound int +func startTestServer(t *testing.T, pf peerFunc) *Server { + server := &Server{ + Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + newPeerFunc: pf, + } + if err := server.Start(); err != nil { + t.Fatalf("Could not start server: %v", err) + } + return server } -func NewTestNetwork(maxinbound int) *TestNetwork { - connections := make(map[string]*TestNetworkConnection) - return &TestNetwork{ - connections: connections, - dialer: &TestDialer{connections}, - maxinbound: maxinbound, +func TestServerListen(t *testing.T) { + defer testlog(t).detach() + + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + if dialAddr != nil { + t.Error("peer func called with non-nil dialAddr") + } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() + + // dial the test server + conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + defer conn.Close() + + select { + case peer := <-connected: + if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.LocalAddr(), conn.RemoteAddr()) + } + case <-time.After(1 * time.Second): + t.Error("server did not accept within one second") } } -func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { - return self.dialer, nil -} +func TestServerDial(t *testing.T) { + defer testlog(t).detach() -func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { - return &TestListener{ - connections: self.connections, - addr: addr, - max: self.maxinbound, - close: make(chan struct{}), - }, nil -} - -func (self *TestNetwork) Start() error { - return nil -} - -func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) { - return -} - -func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) { - return -} - -type TestAddr struct { - name string -} - -func (self *TestAddr) String() string { - return self.name -} - -func (*TestAddr) Network() string { - return "test" -} - -type TestDialer struct { - connections map[string]*TestNetworkConnection -} - -func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) { - address := &TestAddr{addr} - tconn := NewTestNetworkConnection(address) - self.connections[addr] = tconn - conn = net.Conn(tconn) - return -} - -type TestListener struct { - connections map[string]*TestNetworkConnection - addr net.Addr - max int - i int - close chan struct{} -} - -func (self *TestListener) Accept() (net.Conn, error) { - self.i++ - if self.i > self.max { - <-self.close - return nil, io.EOF + // run a fake TCP server to handle the connection. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("could not setup listener: %v") } - addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} - tconn := NewTestNetworkConnection(addr) - key := tconn.RemoteAddr().String() - self.connections[key] = tconn - fmt.Printf("accepted connection from: %v \n", addr) - return tconn, nil -} + defer listener.Close() + accepted := make(chan net.Conn) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error("acccept error:", err) + } + conn.Close() + accepted <- conn + }() -func (self *TestListener) Close() error { - close(self.close) - return nil -} + // start the test server + connected := make(chan *Peer) + srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { + if conn == nil { + t.Error("peer func called with nil conn") + } + peer := newPeer(conn, nil, dialAddr) + connected <- peer + return peer + }) + defer close(connected) + defer srv.Stop() -func (self *TestListener) Addr() net.Addr { - return self.addr -} + // tell the server to connect. + connAddr := newPeerAddr(listener.Addr(), nil) + srv.peerConnect <- connAddr -type TestNetworkConnection struct { - in chan []byte - close chan struct{} - current []byte - Out [][]byte - addr net.Addr -} - -func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { - return &TestNetworkConnection{ - in: make(chan []byte), - close: make(chan struct{}), - current: []byte{}, - Out: [][]byte{}, - addr: addr, - } -} - -func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { - time.Sleep(latency) - for _, s := range packets { - self.in <- s - } -} - -func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { - if len(self.current) == 0 { - var ok bool + select { + case conn := <-accepted: select { - case self.current, ok = <-self.in: - if !ok { - return 0, io.EOF + case peer := <-connected: + if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { + t.Errorf("peer started with wrong conn: got %v, want %v", + peer.conn.RemoteAddr(), conn.LocalAddr()) } - case <-self.close: - return 0, io.EOF + if peer.dialAddr != connAddr { + t.Errorf("peer started with wrong dialAddr: got %v, want %v", + peer.dialAddr, connAddr) + } + case <-time.After(1 * time.Second): + t.Error("server did not launch peer within one second") } - } - length := len(self.current) - if length > len(buff) { - copy(buff[:], self.current[:len(buff)]) - self.current = self.current[len(buff):] - return len(buff), nil - } else { - copy(buff[:length], self.current[:]) - self.current = []byte{} - return length, io.EOF + + case <-time.After(1 * time.Second): + t.Error("server did not connect within one second") } } -func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { - self.Out = append(self.Out, buff) - fmt.Printf("net write(%d): %x\n", len(self.Out), buff) - return len(buff), nil -} +func TestServerBroadcast(t *testing.T) { + defer testlog(t).detach() + var connected sync.WaitGroup + srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { + peer := newPeer(c, []Protocol{discard}, dialAddr) + peer.startSubprotocols([]Cap{discard.cap()}) + connected.Done() + return peer + }) + defer srv.Stop() -func (self *TestNetworkConnection) Close() error { - close(self.close) - return nil -} - -func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { - return -} - -func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { - return self.addr -} - -func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { - return -} - -func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { - return -} - -func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { - network = NewTestNetwork(1) - addr := &TestAddr{"test:30303"} - identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey") - maxPeers := 2 - if handlers == nil { - handlers = make(Handlers) + // dial a bunch of conns + var conns = make([]net.Conn, 8) + connected.Add(len(conns)) + deadline := time.Now().Add(3 * time.Second) + dialer := &net.Dialer{Deadline: deadline} + for i := range conns { + conn, err := dialer.Dial("tcp", srv.ListenAddr) + if err != nil { + t.Fatalf("conn %d: dial error: %v", i, err) + } + defer conn.Close() + conn.SetDeadline(deadline) + conns[i] = conn } - blackist := NewBlacklist() - server = New(network, addr, identity, handlers, maxPeers, blackist) - fmt.Println(server.identity.Pubkey()) - return -} + connected.Wait() -func TestServerListener(t *testing.T) { - t.SkipNow() + // broadcast one message + srv.Broadcast("discard", 0, "foo") + goldbuf := new(bytes.Buffer) + writeMsg(goldbuf, NewMsg(16, "foo")) + golden := goldbuf.Bytes() - network, server := SetupTestServer(nil) - server.Start(true, false) - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["inboundpeer-1"] - if !ok { - t.Error("not found inbound peer 1") - } else { - if len(peer1.Out) != 2 { - t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) + // check that the message has been written everywhere + for i, conn := range conns { + buf := make([]byte, len(golden)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Errorf("conn %d: read error: %v", i, err) + } else if !bytes.Equal(buf, golden) { + t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden) } } } - -func TestServerDialer(t *testing.T) { - network, server := SetupTestServer(nil) - server.Start(false, true) - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(10 * time.Millisecond) - server.Stop() - peer1, ok := network.connections["outboundpeer-1"] - if !ok { - t.Error("not found outbound peer 1") - } else { - if len(peer1.Out) != 2 { - t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2) - } - } -} - -// func TestServerBroadcast(t *testing.T) { -// handlers := make(Handlers) -// testProtocol := &TestProtocol{Msgs: []*Msg{}} -// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } -// network, server := SetupTestServer(handlers) -// server.Start(true, true) -// server.peerConnect <- &TestAddr{"outboundpeer-1"} -// time.Sleep(10 * time.Millisecond) -// msg := NewMsg(0) -// server.Broadcast("", msg) -// packet := Packet(0, 0) -// time.Sleep(10 * time.Millisecond) -// server.Stop() -// peer1, ok := network.connections["outboundpeer-1"] -// if !ok { -// t.Error("not found outbound peer 1") -// } else { -// fmt.Printf("out: %v\n", peer1.Out) -// if len(peer1.Out) != 3 { -// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) -// } else { -// if bytes.Compare(peer1.Out[1], packet) != 0 { -// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) -// } -// } -// } -// peer2, ok := network.connections["inboundpeer-1"] -// if !ok { -// t.Error("not found inbound peer 2") -// } else { -// fmt.Printf("out: %v\n", peer2.Out) -// if len(peer1.Out) != 3 { -// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) -// } else { -// if bytes.Compare(peer2.Out[1], packet) != 0 { -// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) -// } -// } -// } -// } - -func TestServerPeersMessage(t *testing.T) { - t.SkipNow() - _, server := SetupTestServer(nil) - server.Start(true, true) - defer server.Stop() - server.peerConnect <- &TestAddr{"outboundpeer-1"} - time.Sleep(2000 * time.Millisecond) - - pl := server.encodedPeerList() - if pl == nil { - t.Errorf("expect non-nil peer list") - } - if c := server.PeerCount(); c != 2 { - t.Errorf("expect 2 peers, got %v", c) - } -} diff --git a/p2p/testlog_test.go b/p2p/testlog_test.go new file mode 100644 index 000000000..951d43243 --- /dev/null +++ b/p2p/testlog_test.go @@ -0,0 +1,28 @@ +package p2p + +import ( + "testing" + + "github.com/ethereum/go-ethereum/logger" +) + +type testLogger struct{ t *testing.T } + +func testlog(t *testing.T) testLogger { + logger.Reset() + l := testLogger{t} + logger.AddLogSystem(l) + return l +} + +func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } +func (testLogger) SetLogLevel(logger.LogLevel) {} + +func (l testLogger) LogPrint(level logger.LogLevel, msg string) { + l.t.Logf("%s", msg) +} + +func (testLogger) detach() { + logger.Flush() + logger.Reset() +} diff --git a/p2p/testpoc7.go b/p2p/testpoc7.go new file mode 100644 index 000000000..c0cc5c544 --- /dev/null +++ b/p2p/testpoc7.go @@ -0,0 +1,40 @@ +// +build none + +package main + +import ( + "fmt" + "log" + "net" + "os" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/p2p" + "github.com/obscuren/secp256k1-go" +) + +func main() { + logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) + + pub, _ := secp256k1.GenerateKeyPair() + srv := p2p.Server{ + MaxPeers: 10, + Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), + ListenAddr: ":30303", + NAT: p2p.PMP(net.ParseIP("10.0.0.1")), + } + if err := srv.Start(); err != nil { + fmt.Println("could not start server:", err) + os.Exit(1) + } + + // add seed peers + seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") + if err != nil { + fmt.Println("couldn't resolve:", err) + os.Exit(1) + } + srv.SuggestPeer(seed.IP, seed.Port, nil) + + select {} +}