package dht

import (
	"bufio"
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/libp2p/go-libp2p-core/helpers"
	"github.com/libp2p/go-libp2p-core/network"
	"github.com/libp2p/go-libp2p-core/peer"

	"github.com/libp2p/go-libp2p-kad-dht/metrics"
	pb "github.com/libp2p/go-libp2p-kad-dht/pb"

	ggio "github.com/gogo/protobuf/io"

	"go.opencensus.io/stats"
	"go.opencensus.io/tag"
)

var dhtReadMessageTimeout = time.Minute
var dhtStreamIdleTimeout = 10 * time.Minute

// ErrReadTimeout is returned when a peer fails to respond to a request
// within dhtReadMessageTimeout.
var ErrReadTimeout = fmt.Errorf("timed out reading response")

// The Protobuf writer performs multiple small writes when writing a message.
// We need to buffer those writes, to make sure that we're not sending a new
// packet for every single write.
type bufferedDelimitedWriter struct {
	*bufio.Writer
	ggio.WriteCloser
}

var writerPool = sync.Pool{
	New: func() interface{} {
		w := bufio.NewWriter(nil)
		return &bufferedDelimitedWriter{
			Writer:      w,
			WriteCloser: ggio.NewDelimitedWriter(w),
		}
	},
}

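// writeMsg writes a single length-delimited protobuf message to w, using a
// pooled bufferedDelimitedWriter so the protobuf writer's many small writes
// are batched before hitting the stream.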
func writeMsg(w io.Writer, mes *pb.Message) error {
	bw := writerPool.Get().(*bufferedDelimitedWriter)
	bw.Reset(w)
	err := bw.WriteMsg(mes)
	if err == nil {
		err = bw.Flush()
	}
	bw.Reset(nil)
	writerPool.Put(bw)
	return err
}

func (w *bufferedDelimitedWriter) Flush() error {
	return w.Writer.Flush()
}

// handleNewStream implements the network.StreamHandler
func (dht *IpfsDHT) handleNewStream(s network.Stream) {
	defer s.Reset()
	if dht.handleNewMessage(s) {
		// Gracefully close the stream for writes.
		s.Close()
	}
}

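// handleNewMessage reads length-delimited protobuf requests from the stream in
// a loop, dispatches each one to the handler for its message type, and writes
// the handler's response back on the same stream.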
// Returns true on orderly completion of writes (so we can Close the stream).
func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool {
	ctx := dht.ctx
	r := ggio.NewDelimitedReader(s, network.MessageSizeMax)

	mPeer := s.Conn().RemotePeer()

	timer := time.AfterFunc(dhtStreamIdleTimeout, func() { s.Reset() })
	defer timer.Stop()

	for {
		var req pb.Message
		switch err := r.ReadMsg(&req); err {
		case io.EOF:
			return true
		default:
			// This string test is necessary because there isn't a single stream reset error
			// instance in use.
			if err.Error() != "stream reset" {
				logger.Debugf("error reading message: %#v", err)
			}
			stats.RecordWithTags(
				ctx,
				[]tag.Mutator{tag.Upsert(metrics.KeyMessageType, "UNKNOWN")},
				metrics.ReceivedMessageErrors.M(1),
			)
			return false
		case nil:
		}

		timer.Reset(dhtStreamIdleTimeout)

		startTime := time.Now()
		ctx, _ = tag.New(
			ctx,
			tag.Upsert(metrics.KeyMessageType, req.GetType().String()),
		)

		stats.Record(
			ctx,
			metrics.ReceivedMessages.M(1),
			metrics.ReceivedBytes.M(int64(req.Size())),
		)

		handler := dht.handlerForMsgType(req.GetType())
		if handler == nil {
			stats.Record(ctx, metrics.ReceivedMessageErrors.M(1))
			logger.Warningf("can't handle received message of type %v", req.GetType())
			return false
		}

		resp, err := handler(ctx, mPeer, &req)
		if err != nil {
			stats.Record(ctx, metrics.ReceivedMessageErrors.M(1))
			logger.Debugf("error handling message: %v", err)
			return false
		}

		dht.updateFromMessage(ctx, mPeer, &req)

		if resp == nil {
			continue
		}

		// send out response msg
		err = writeMsg(s, resp)
		if err != nil {
			stats.Record(ctx, metrics.ReceivedMessageErrors.M(1))
			logger.Debugf("error writing response: %v", err)
			return false
		}

		elapsedTime := time.Since(startTime)
		latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
		stats.Record(ctx, metrics.InboundRequestLatency.M(latencyMillis))
	}
}

// sendRequest sends out a request and also measures the round-trip time so it
// can be recorded as latency.
func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

	ms, err := dht.messageSenderForPeer(ctx, p)
	if err != nil {
		stats.Record(ctx, metrics.SentRequestErrors.M(1))
		return nil, err
	}

	start := time.Now()

	rpmes, err := ms.SendRequest(ctx, pmes)
	if err != nil {
		stats.Record(ctx, metrics.SentRequestErrors.M(1))
		return nil, err
	}

	// update the peer (on valid msgs only)
	dht.updateFromMessage(ctx, p, rpmes)

	stats.Record(
		ctx,
		metrics.SentRequests.M(1),
		metrics.SentBytes.M(int64(pmes.Size())),
		metrics.OutboundRequestLatency.M(
			float64(time.Since(start))/float64(time.Millisecond),
		),
	)
	dht.peerstore.RecordLatency(p, time.Since(start))
	logger.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes)
	return rpmes, nil
}

// sendMessage sends out a message without waiting for a response.
func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

	ms, err := dht.messageSenderForPeer(ctx, p)
	if err != nil {
		stats.Record(ctx, metrics.SentMessageErrors.M(1))
		return err
	}

	if err := ms.SendMessage(ctx, pmes); err != nil {
		stats.Record(ctx, metrics.SentMessageErrors.M(1))
		return err
	}

	stats.Record(
		ctx,
		metrics.SentMessages.M(1),
		metrics.SentBytes.M(int64(pmes.Size())),
	)
	logger.Event(ctx, "dhtSentMessage", dht.self, p, pmes)
	return nil
}

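// updateFromMessage adds p to the routing table if it advertises support for
// one of our DHT protocols, i.e. if it is a DHT server rather than a
// client-only node.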
func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error {
	// Make sure that the remote peer is actually a DHT server, not just a client.
	protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...)
	if err == nil && len(protos) > 0 {
		dht.Update(ctx, p)
	}
	return nil
}

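// messageSenderForPeer returns the cached messageSender for p, creating,
// initializing, and caching a new one in strmap if none exists yet.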
func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
	dht.smlk.Lock()
	ms, ok := dht.strmap[p]
	if ok {
		dht.smlk.Unlock()
		return ms, nil
	}
	ms = &messageSender{p: p, dht: dht}
	dht.strmap[p] = ms
	dht.smlk.Unlock()

	if err := ms.prepOrInvalidate(ctx); err != nil {
		dht.smlk.Lock()
		defer dht.smlk.Unlock()

		if msCur, ok := dht.strmap[p]; ok {
			// Changed. Use the new one, old one is invalid and
			// not in the map so we can just throw it away.
			if ms != msCur {
				return msCur, nil
			}
			// Not changed, remove the now invalid stream from the
			// map.
			delete(dht.strmap, p)
		}
		// Invalid but not in map. Must have been removed by a disconnect.
		return nil, err
	}
	// All ready to go.
	return ms, nil
}

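// messageSender wraps a single reusable stream to a peer. Sends are serialized
// by lk; invalid marks a sender that must not be reused, and singleMes counts
// retried sends — once it exceeds streamReuseTries the stream is no longer
// kept open between messages.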
type messageSender struct {
	s   network.Stream
	r   ggio.ReadCloser
	lk  sync.Mutex
	p   peer.ID
	dht *IpfsDHT

	invalid   bool
	singleMes int
}

// invalidate is called before this messageSender is removed from the strmap.
// It prevents the messageSender from being reused/reinitialized and then
// forgotten (leaving the stream open).
func (ms *messageSender) invalidate() {
	ms.invalid = true
	if ms.s != nil {
		ms.s.Reset()
		ms.s = nil
	}
}

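// prepOrInvalidate attempts to open the stream, invalidating the messageSender
// if it fails.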
func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
	ms.lk.Lock()
	defer ms.lk.Unlock()
	if err := ms.prep(ctx); err != nil {
		ms.invalidate()
		return err
	}
	return nil
}

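// prep opens a new stream (and delimited reader) to the peer if one is not
// already open. The caller must hold ms.lk.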
func (ms *messageSender) prep(ctx context.Context) error {
	if ms.invalid {
		return fmt.Errorf("message sender has been invalidated")
	}
	if ms.s != nil {
		return nil
	}

	nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
	if err != nil {
		return err
	}

	ms.r = ggio.NewDelimitedReader(nstr, network.MessageSizeMax)
	ms.s = nstr

	return nil
}

// streamReuseTries is the number of times we will try to reuse a stream to a
// given peer before giving up and reverting to the old one-message-per-stream
// behaviour.
const streamReuseTries = 3

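// SendMessage writes pmes to the peer without waiting for a response. If the
// write fails, it retries once on a fresh stream before giving up.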
func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
	ms.lk.Lock()
	defer ms.lk.Unlock()
	retry := false
	for {
		if err := ms.prep(ctx); err != nil {
			return err
		}

		if err := ms.writeMsg(pmes); err != nil {
			ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Info("error writing message, bailing: ", err)
				return err
			}
			logger.Info("error writing message, trying again: ", err)
			retry = true
			continue
		}

		logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)

		if ms.singleMes > streamReuseTries {
			go helpers.FullClose(ms.s)
			ms.s = nil
		} else if retry {
			ms.singleMes++
		}

		return nil
	}
}

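// SendRequest writes pmes to the peer and waits for a single response. If the
// write or the read fails, it retries once on a fresh stream before giving up.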
func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
	ms.lk.Lock()
	defer ms.lk.Unlock()
	retry := false
	for {
		if err := ms.prep(ctx); err != nil {
			return nil, err
		}

		if err := ms.writeMsg(pmes); err != nil {
			ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Info("error writing message, bailing: ", err)
				return nil, err
			}
			logger.Info("error writing message, trying again: ", err)
			retry = true
			continue
		}

		mes := new(pb.Message)
		if err := ms.ctxReadMsg(ctx, mes); err != nil {
			ms.s.Reset()
			ms.s = nil

			if retry {
				logger.Info("error reading message, bailing: ", err)
				return nil, err
			}
			logger.Info("error reading message, trying again: ", err)
			retry = true
			continue
		}

		logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)

		if ms.singleMes > streamReuseTries {
			go helpers.FullClose(ms.s)
			ms.s = nil
		} else if retry {
			ms.singleMes++
		}

		return mes, nil
	}
}

func (ms *messageSender) writeMsg(pmes *pb.Message) error {
	return writeMsg(ms.s, pmes)
}

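// ctxReadMsg reads a single message from the current stream, giving up when
// either ctx is cancelled or dhtReadMessageTimeout elapses.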
func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
	errc := make(chan error, 1)
	go func(r ggio.ReadCloser) {
		errc <- r.ReadMsg(mes)
	}(ms.r)

	t := time.NewTimer(dhtReadMessageTimeout)
	defer t.Stop()

	select {
	case err := <-errc:
		return err
	case <-ctx.Done():
		return ctx.Err()
	case <-t.C:
		return ErrReadTimeout
	}
}