Merge pull request #1261 from fjl/p2p-no-writes-at-shutdown

p2p: prevent writes at shutdown time
This commit is contained in:
Jeffrey Wilcke 2015-06-15 08:21:30 -07:00
commit f475a01326
2 changed files with 61 additions and 28 deletions

View File

@ -115,41 +115,60 @@ func newPeer(conn *conn, protocols []Protocol) *Peer {
} }
func (p *Peer) run() DiscReason { func (p *Peer) run() DiscReason {
readErr := make(chan error, 1) var (
writeStart = make(chan struct{}, 1)
writeErr = make(chan error, 1)
readErr = make(chan error, 1)
reason DiscReason
requested bool
)
p.wg.Add(2) p.wg.Add(2)
go p.readLoop(readErr) go p.readLoop(readErr)
go p.pingLoop() go p.pingLoop()
p.startProtocols() // Start all protocol handlers.
writeStart <- struct{}{}
p.startProtocols(writeStart, writeErr)
// Wait for an error or disconnect. // Wait for an error or disconnect.
var ( loop:
reason DiscReason for {
requested bool
)
select { select {
case err := <-writeErr:
// A write finished. Allow the next write to start if
// there was no error.
if err != nil {
glog.V(logger.Detail).Infof("%v: write error: %v\n", p, err)
reason = DiscNetworkError
break loop
}
writeStart <- struct{}{}
case err := <-readErr: case err := <-readErr:
if r, ok := err.(DiscReason); ok { if r, ok := err.(DiscReason); ok {
glog.V(logger.Debug).Infof("%v: remote requested disconnect: %v\n", p, r)
requested = true
reason = r reason = r
} else { } else {
// Note: We rely on protocols to abort if there is a write glog.V(logger.Detail).Infof("%v: read error: %v\n", p, err)
// error. It might be more robust to handle them here as well.
glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
reason = DiscNetworkError reason = DiscNetworkError
} }
break loop
case err := <-p.protoErr: case err := <-p.protoErr:
reason = discReasonForError(err) reason = discReasonForError(err)
glog.V(logger.Debug).Infof("%v: protocol error: %v (%v)\n", p, err, reason)
break loop
case reason = <-p.disc: case reason = <-p.disc:
requested = true glog.V(logger.Debug).Infof("%v: locally requested disconnect: %v\n", p, reason)
break loop
} }
}
close(p.closed) close(p.closed)
p.rw.close(reason) p.rw.close(reason)
p.wg.Wait() p.wg.Wait()
if requested { if requested {
reason = DiscRequested reason = DiscRequested
} }
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
return reason return reason
} }
@ -196,7 +215,6 @@ func (p *Peer) handle(msg Msg) error {
// This is the last message. We don't need to discard or // This is the last message. We don't need to discard or
// check errors because, the connection will be closed after it. // check errors because, the connection will be closed after it.
rlp.Decode(msg.Payload, &reason) rlp.Decode(msg.Payload, &reason)
glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0])
return reason[0] return reason[0]
case msg.Code < baseProtocolLength: case msg.Code < baseProtocolLength:
// ignore other base protocol messages // ignore other base protocol messages
@ -247,11 +265,13 @@ outer:
return result return result
} }
func (p *Peer) startProtocols() { func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
p.wg.Add(len(p.running)) p.wg.Add(len(p.running))
for _, proto := range p.running { for _, proto := range p.running {
proto := proto proto := proto
proto.closed = p.closed proto.closed = p.closed
proto.wstart = writeStart
proto.werr = writeErr
glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version) glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
go func() { go func() {
err := proto.Run(p, proto) err := proto.Run(p, proto)
@ -280,18 +300,31 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
type protoRW struct { type protoRW struct {
Protocol Protocol
in chan Msg in chan Msg // receices read messages
closed <-chan struct{} closed <-chan struct{} // receives when peer is shutting down
wstart <-chan struct{} // receives when write may start
werr chan<- error // for write results
offset uint64 offset uint64
w MsgWriter w MsgWriter
} }
func (rw *protoRW) WriteMsg(msg Msg) error { func (rw *protoRW) WriteMsg(msg Msg) (err error) {
if msg.Code >= rw.Length { if msg.Code >= rw.Length {
return newPeerError(errInvalidMsgCode, "not handled") return newPeerError(errInvalidMsgCode, "not handled")
} }
msg.Code += rw.offset msg.Code += rw.offset
return rw.w.WriteMsg(msg) select {
case <-rw.wstart:
err = rw.w.WriteMsg(msg)
// Report write status back to Peer.run. It will initiate
// shutdown if the error is non-nil and unblock the next write
// otherwise. The calling protocol code should exit for errors
// as well but we don't want to rely on that.
rw.werr <- err
case <-rw.closed:
err = fmt.Errorf("shutting down")
}
return err
} }
func (rw *protoRW) ReadMsg() (Msg, error) { func (rw *protoRW) ReadMsg() (Msg, error) {

View File

@ -121,7 +121,7 @@ func TestPeerDisconnect(t *testing.T) {
} }
select { select {
case reason := <-disc: case reason := <-disc:
if reason != DiscQuitting { if reason != DiscRequested {
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
} }
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):