222 lines
5.0 KiB
Go
222 lines
5.0 KiB
Go
package p2p
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Handlers maps a protocol name to the Protocol implementation that is
// started when the remote peer announces support for that name.
type Handlers map[string]Protocol
|
|
|
|
type proto struct {
|
|
in chan Msg
|
|
maxcode, offset MsgCode
|
|
messenger *messenger
|
|
}
|
|
|
|
func (rw *proto) WriteMsg(msg Msg) error {
|
|
if msg.Code >= rw.maxcode {
|
|
return NewPeerError(InvalidMsgCode, "not handled")
|
|
}
|
|
msg.Code += rw.offset
|
|
return rw.messenger.writeMsg(msg)
|
|
}
|
|
|
|
func (rw *proto) ReadMsg() (Msg, error) {
|
|
msg, ok := <-rw.in
|
|
if !ok {
|
|
return msg, io.EOF
|
|
}
|
|
msg.Code -= rw.offset
|
|
return msg, nil
|
|
}
|
|
|
|
// eofSignal wraps a reader with eof signaling.
// The eof channel is closed the first time the wrapped reader
// returns an error (io.EOF or any other error).
type eofSignal struct {
	wrapped io.Reader
	eof     chan struct{}
}

// Read forwards to the wrapped reader and returns its result unchanged.
// On the first non-nil error it closes r.eof; later failing reads leave the
// already-closed channel alone, so calling Read again after EOF does not
// panic with "close of closed channel".
//
// Read must not be called from multiple goroutines at once: the
// closed-check and the close are not atomic with respect to each other.
func (r *eofSignal) Read(buf []byte) (int, error) {
	n, err := r.wrapped.Read(buf)
	if err != nil {
		select {
		case <-r.eof:
			// already signalled by an earlier failed Read
		default:
			close(r.eof) // tell messenger that msg has been consumed
		}
	}
	return n, err
}
|
|
|
|
// messenger represents a message-oriented peer connection.
|
|
// It keeps track of the set of protocols understood
|
|
// by the remote peer.
|
|
type messenger struct {
|
|
peer *Peer
|
|
handlers Handlers
|
|
|
|
// the mutex protects the connection
|
|
// so only one protocol can write at a time.
|
|
writeMu sync.Mutex
|
|
conn net.Conn
|
|
bufconn *bufio.ReadWriter
|
|
|
|
protocolLock sync.RWMutex
|
|
protocols map[string]*proto
|
|
offsets map[MsgCode]*proto
|
|
protoWG sync.WaitGroup
|
|
|
|
err chan error
|
|
pulse chan bool
|
|
}
|
|
|
|
func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
|
|
return &messenger{
|
|
conn: conn,
|
|
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
|
|
peer: peer,
|
|
handlers: handlers,
|
|
protocols: make(map[string]*proto),
|
|
err: errchan,
|
|
pulse: make(chan bool, 1),
|
|
}
|
|
}
|
|
|
|
func (m *messenger) Start() {
|
|
m.protocols[""] = m.startProto(0, "", &baseProtocol{})
|
|
go m.readLoop()
|
|
}
|
|
|
|
func (m *messenger) Stop() {
|
|
m.conn.Close()
|
|
m.protoWG.Wait()
|
|
}
|
|
|
|
const (
	// msgReadTimeout is the read deadline applied before
	// each message is read from the connection.
	msgReadTimeout = 5 * time.Second

	// wholePayloadSize is the size threshold (in bytes, inclusive) up to
	// which a message payload is read into memory in full before the
	// message is handed to its protocol.
	wholePayloadSize = 64 << 10
)
|
|
|
|
func (m *messenger) readLoop() {
|
|
defer m.closeProtocols()
|
|
for {
|
|
m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
|
|
msg, err := readMsg(m.bufconn)
|
|
if err != nil {
|
|
m.err <- err
|
|
return
|
|
}
|
|
// send ping to heartbeat channel signalling time of last message
|
|
m.pulse <- true
|
|
proto, err := m.getProto(msg.Code)
|
|
if err != nil {
|
|
m.err <- err
|
|
return
|
|
}
|
|
if msg.Size <= wholePayloadSize {
|
|
// optimization: msg is small enough, read all
|
|
// of it and move on to the next message
|
|
buf, err := ioutil.ReadAll(msg.Payload)
|
|
if err != nil {
|
|
m.err <- err
|
|
return
|
|
}
|
|
msg.Payload = bytes.NewReader(buf)
|
|
proto.in <- msg
|
|
} else {
|
|
pr := &eofSignal{msg.Payload, make(chan struct{})}
|
|
msg.Payload = pr
|
|
proto.in <- msg
|
|
<-pr.eof
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *messenger) closeProtocols() {
|
|
m.protocolLock.RLock()
|
|
for _, p := range m.protocols {
|
|
close(p.in)
|
|
}
|
|
m.protocolLock.RUnlock()
|
|
}
|
|
|
|
func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
|
|
proto := &proto{
|
|
in: make(chan Msg),
|
|
offset: offset,
|
|
maxcode: impl.Offset(),
|
|
messenger: m,
|
|
}
|
|
m.protoWG.Add(1)
|
|
go func() {
|
|
if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
|
|
logger.Errorf("protocol %q error: %v\n", name, err)
|
|
m.err <- err
|
|
}
|
|
m.protoWG.Done()
|
|
}()
|
|
return proto
|
|
}
|
|
|
|
// getProto finds the protocol responsible for handling
|
|
// the given message code.
|
|
func (m *messenger) getProto(code MsgCode) (*proto, error) {
|
|
m.protocolLock.RLock()
|
|
defer m.protocolLock.RUnlock()
|
|
for _, proto := range m.protocols {
|
|
if code >= proto.offset && code < proto.offset+proto.maxcode {
|
|
return proto, nil
|
|
}
|
|
}
|
|
return nil, NewPeerError(InvalidMsgCode, "%d", code)
|
|
}
|
|
|
|
// setProtocols starts all subprotocols shared with the
|
|
// remote peer. the protocols must be sorted alphabetically.
|
|
func (m *messenger) setRemoteProtocols(protocols []string) {
|
|
m.protocolLock.Lock()
|
|
defer m.protocolLock.Unlock()
|
|
offset := baseProtocolOffset
|
|
for _, name := range protocols {
|
|
inst, ok := m.handlers[name]
|
|
if !ok {
|
|
continue // not handled
|
|
}
|
|
m.protocols[name] = m.startProto(offset, name, inst)
|
|
offset += inst.Offset()
|
|
}
|
|
}
|
|
|
|
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
|
func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
|
|
m.protocolLock.RLock()
|
|
proto, ok := m.protocols[protoName]
|
|
m.protocolLock.RUnlock()
|
|
if !ok {
|
|
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
|
}
|
|
if msg.Code >= proto.maxcode {
|
|
return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
|
}
|
|
msg.Code += proto.offset
|
|
return m.writeMsg(msg)
|
|
}
|
|
|
|
// writeMsg writes a message to the connection.
|
|
func (m *messenger) writeMsg(msg Msg) error {
|
|
m.writeMu.Lock()
|
|
defer m.writeMu.Unlock()
|
|
if err := writeMsg(m.bufconn, msg); err != nil {
|
|
return err
|
|
}
|
|
return m.bufconn.Flush()
|
|
}
|