8dcbdcad0a
As of this commit, we no longer rely on the protocol handler to report write errors in a timely fashion. When a write fails, shutdown is initiated immediately and no new writes can start. This will also prevent new writes from starting after Server.Stop has been called.
338 lines
7.8 KiB
Go
338 lines
7.8 KiB
Go
package p2p
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ethereum/go-ethereum/logger"
|
|
"github.com/ethereum/go-ethereum/logger/glog"
|
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
"github.com/ethereum/go-ethereum/rlp"
|
|
)
|
|
|
|
// Parameters of the base (devp2p) protocol.
const (
	// baseProtocolVersion is the version number of the base protocol.
	baseProtocolVersion = 4
	// baseProtocolLength is the number of message codes reserved for the
	// base protocol. Subprotocol message codes are assigned starting at
	// this offset.
	baseProtocolLength uint64 = 16
	// baseProtocolMaxMsgSize is a size limit for base protocol messages.
	baseProtocolMaxMsgSize = 2 * 1024

	// pingInterval is the period of the ticker that drives pingLoop.
	pingInterval = 15 * time.Second
)
|
|
|
|
// Message codes of the base (devp2p) protocol.
const (
	handshakeMsg = 0x00 // protocol handshake
	discMsg      = 0x01 // disconnect, carries the reason
	pingMsg      = 0x02 // answered with pongMsg
	pongMsg      = 0x03 // reply to pingMsg
	getPeersMsg  = 0x04
	peersMsg     = 0x05
)
|
|
|
|
// protoHandshake is the RLP structure of the protocol handshake.
|
|
type protoHandshake struct {
|
|
Version uint64
|
|
Name string
|
|
Caps []Cap
|
|
ListenPort uint64
|
|
ID discover.NodeID
|
|
}
|
|
|
|
// Peer represents a connected remote node.
|
|
type Peer struct {
|
|
rw *conn
|
|
running map[string]*protoRW
|
|
|
|
wg sync.WaitGroup
|
|
protoErr chan error
|
|
closed chan struct{}
|
|
disc chan DiscReason
|
|
}
|
|
|
|
// NewPeer returns a peer for testing purposes.
|
|
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
|
pipe, _ := net.Pipe()
|
|
conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
|
|
peer := newPeer(conn, nil)
|
|
close(peer.closed) // ensures Disconnect doesn't block
|
|
return peer
|
|
}
|
|
|
|
// ID returns the node's public key.
|
|
func (p *Peer) ID() discover.NodeID {
|
|
return p.rw.id
|
|
}
|
|
|
|
// Name returns the node name that the remote node advertised.
|
|
func (p *Peer) Name() string {
|
|
return p.rw.name
|
|
}
|
|
|
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
|
func (p *Peer) Caps() []Cap {
|
|
// TODO: maybe return copy
|
|
return p.rw.caps
|
|
}
|
|
|
|
// RemoteAddr returns the remote address of the network connection.
|
|
func (p *Peer) RemoteAddr() net.Addr {
|
|
return p.rw.fd.RemoteAddr()
|
|
}
|
|
|
|
// LocalAddr returns the local address of the network connection.
|
|
func (p *Peer) LocalAddr() net.Addr {
|
|
return p.rw.fd.LocalAddr()
|
|
}
|
|
|
|
// Disconnect terminates the peer connection with the given reason.
|
|
// It returns immediately and does not wait until the connection is closed.
|
|
func (p *Peer) Disconnect(reason DiscReason) {
|
|
select {
|
|
case p.disc <- reason:
|
|
case <-p.closed:
|
|
}
|
|
}
|
|
|
|
// String implements fmt.Stringer.
|
|
func (p *Peer) String() string {
|
|
return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr())
|
|
}
|
|
|
|
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
|
protomap := matchProtocols(protocols, conn.caps, conn)
|
|
p := &Peer{
|
|
rw: conn,
|
|
running: protomap,
|
|
disc: make(chan DiscReason),
|
|
protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop
|
|
closed: make(chan struct{}),
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (p *Peer) run() DiscReason {
|
|
var (
|
|
writeStart = make(chan struct{}, 1)
|
|
writeErr = make(chan error, 1)
|
|
readErr = make(chan error, 1)
|
|
reason DiscReason
|
|
requested bool
|
|
)
|
|
p.wg.Add(2)
|
|
go p.readLoop(readErr)
|
|
go p.pingLoop()
|
|
|
|
// Start all protocol handlers.
|
|
writeStart <- struct{}{}
|
|
p.startProtocols(writeStart, writeErr)
|
|
|
|
// Wait for an error or disconnect.
|
|
loop:
|
|
for {
|
|
select {
|
|
case err := <-writeErr:
|
|
// A write finished. Allow the next write to start if
|
|
// there was no error.
|
|
if err != nil {
|
|
glog.V(logger.Detail).Infof("%v: Write error: %v\n", p, err)
|
|
reason = DiscNetworkError
|
|
break loop
|
|
}
|
|
writeStart <- struct{}{}
|
|
case err := <-readErr:
|
|
if r, ok := err.(DiscReason); ok {
|
|
reason = r
|
|
} else {
|
|
glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
|
|
reason = DiscNetworkError
|
|
}
|
|
break loop
|
|
case err := <-p.protoErr:
|
|
reason = discReasonForError(err)
|
|
break loop
|
|
case reason = <-p.disc:
|
|
requested = true
|
|
break loop
|
|
}
|
|
}
|
|
|
|
close(p.closed)
|
|
p.rw.close(reason)
|
|
p.wg.Wait()
|
|
if requested {
|
|
reason = DiscRequested
|
|
}
|
|
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
|
|
return reason
|
|
}
|
|
|
|
func (p *Peer) pingLoop() {
|
|
ping := time.NewTicker(pingInterval)
|
|
defer p.wg.Done()
|
|
defer ping.Stop()
|
|
for {
|
|
select {
|
|
case <-ping.C:
|
|
if err := SendItems(p.rw, pingMsg); err != nil {
|
|
p.protoErr <- err
|
|
return
|
|
}
|
|
case <-p.closed:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Peer) readLoop(errc chan<- error) {
|
|
defer p.wg.Done()
|
|
for {
|
|
msg, err := p.rw.ReadMsg()
|
|
if err != nil {
|
|
errc <- err
|
|
return
|
|
}
|
|
msg.ReceivedAt = time.Now()
|
|
if err = p.handle(msg); err != nil {
|
|
errc <- err
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Peer) handle(msg Msg) error {
|
|
switch {
|
|
case msg.Code == pingMsg:
|
|
msg.Discard()
|
|
go SendItems(p.rw, pongMsg)
|
|
case msg.Code == discMsg:
|
|
var reason [1]DiscReason
|
|
// This is the last message. We don't need to discard or
|
|
// check errors because, the connection will be closed after it.
|
|
rlp.Decode(msg.Payload, &reason)
|
|
glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0])
|
|
return reason[0]
|
|
case msg.Code < baseProtocolLength:
|
|
// ignore other base protocol messages
|
|
return msg.Discard()
|
|
default:
|
|
// it's a subprotocol message
|
|
proto, err := p.getProto(msg.Code)
|
|
if err != nil {
|
|
return fmt.Errorf("msg code out of range: %v", msg.Code)
|
|
}
|
|
select {
|
|
case proto.in <- msg:
|
|
return nil
|
|
case <-p.closed:
|
|
return io.EOF
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
|
|
n := 0
|
|
for _, cap := range caps {
|
|
for _, proto := range protocols {
|
|
if proto.Name == cap.Name && proto.Version == cap.Version {
|
|
n++
|
|
}
|
|
}
|
|
}
|
|
return n
|
|
}
|
|
|
|
// matchProtocols creates structures for matching named subprotocols.
|
|
func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
|
|
sort.Sort(capsByName(caps))
|
|
offset := baseProtocolLength
|
|
result := make(map[string]*protoRW)
|
|
outer:
|
|
for _, cap := range caps {
|
|
for _, proto := range protocols {
|
|
if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
|
|
result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
|
|
offset += proto.Length
|
|
continue outer
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
|
|
p.wg.Add(len(p.running))
|
|
for _, proto := range p.running {
|
|
proto := proto
|
|
proto.closed = p.closed
|
|
proto.wstart = writeStart
|
|
proto.werr = writeErr
|
|
glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
|
|
go func() {
|
|
err := proto.Run(p, proto)
|
|
if err == nil {
|
|
glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
|
|
err = errors.New("protocol returned")
|
|
} else if err != io.EOF {
|
|
glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err)
|
|
}
|
|
p.protoErr <- err
|
|
p.wg.Done()
|
|
}()
|
|
}
|
|
}
|
|
|
|
// getProto finds the protocol responsible for handling
|
|
// the given message code.
|
|
func (p *Peer) getProto(code uint64) (*protoRW, error) {
|
|
for _, proto := range p.running {
|
|
if code >= proto.offset && code < proto.offset+proto.Length {
|
|
return proto, nil
|
|
}
|
|
}
|
|
return nil, newPeerError(errInvalidMsgCode, "%d", code)
|
|
}
|
|
|
|
type protoRW struct {
|
|
Protocol
|
|
in chan Msg // receices read messages
|
|
closed <-chan struct{} // receives when peer is shutting down
|
|
wstart <-chan struct{} // receives when write may start
|
|
werr chan<- error // for write results
|
|
offset uint64
|
|
w MsgWriter
|
|
}
|
|
|
|
func (rw *protoRW) WriteMsg(msg Msg) (err error) {
|
|
if msg.Code >= rw.Length {
|
|
return newPeerError(errInvalidMsgCode, "not handled")
|
|
}
|
|
msg.Code += rw.offset
|
|
select {
|
|
case <-rw.wstart:
|
|
err = rw.w.WriteMsg(msg)
|
|
// Report write status back to Peer.run. It will initiate
|
|
// shutdown if the error is non-nil and unblock the next write
|
|
// otherwise. The calling protocol code should exit for errors
|
|
// as well but we don't want to rely on that.
|
|
rw.werr <- err
|
|
case <-rw.closed:
|
|
err = fmt.Errorf("shutting down")
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (rw *protoRW) ReadMsg() (Msg, error) {
|
|
select {
|
|
case msg := <-rw.in:
|
|
msg.Code -= rw.offset
|
|
return msg, nil
|
|
case <-rw.closed:
|
|
return Msg{}, io.EOF
|
|
}
|
|
}
|