p2p: improve and test eofSignal

This commit is contained in:
Felix Lange 2014-12-12 11:38:42 +01:00
parent 9423401d73
commit e28c60caf9
2 changed files with 68 additions and 5 deletions

View File

@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error)
proto.in <- msg proto.in <- msg
} else { } else {
wait = true wait = true
pr := &eofSignal{msg.Payload, protoDone} pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
msg.Payload = pr msg.Payload = pr
proto.in <- msg proto.in <- msg
} }
@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) {
return msg, nil return msg, nil
} }
// eofSignal wraps a reader with eof signaling. // eofSignal wraps a reader with eof signaling. the eof channel is
// the eof channel is closed when the wrapped reader // closed when the wrapped reader returns an error or when count bytes
// reaches EOF. // have been read.
//
type eofSignal struct { type eofSignal struct {
wrapped io.Reader wrapped io.Reader
count int64
eof chan<- struct{} eof chan<- struct{}
} }
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) { func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf) n, err := r.wrapped.Read(buf)
if err != nil { r.count -= int64(n)
if (err != nil || r.count <= 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
} }
return n, err return n, err
} }

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"io"
"io/ioutil" "io/ioutil"
"net" "net"
"reflect" "reflect"
@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) {
// Should not hang. // Should not hang.
p.Disconnect(DiscAlreadyConnected) p.Disconnect(DiscAlreadyConnected)
} }
func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)
// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 8 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
}
}