forked from cerc-io/plugeth
p2p: improve and test eofSignal
This commit is contained in:
parent
9423401d73
commit
e28c60caf9
17
p2p/peer.go
17
p2p/peer.go
@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error)
|
|||||||
proto.in <- msg
|
proto.in <- msg
|
||||||
} else {
|
} else {
|
||||||
wait = true
|
wait = true
|
||||||
pr := &eofSignal{msg.Payload, protoDone}
|
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
|
||||||
msg.Payload = pr
|
msg.Payload = pr
|
||||||
proto.in <- msg
|
proto.in <- msg
|
||||||
}
|
}
|
||||||
@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// eofSignal wraps a reader with eof signaling.
|
// eofSignal wraps a reader with eof signaling. the eof channel is
|
||||||
// the eof channel is closed when the wrapped reader
|
// closed when the wrapped reader returns an error or when count bytes
|
||||||
// reaches EOF.
|
// have been read.
|
||||||
|
//
|
||||||
type eofSignal struct {
|
type eofSignal struct {
|
||||||
wrapped io.Reader
|
wrapped io.Reader
|
||||||
|
count int64
|
||||||
eof chan<- struct{}
|
eof chan<- struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// note: when using eofSignal to detect whether a message payload
|
||||||
|
// has been read, Read might not be called for zero sized messages.
|
||||||
|
|
||||||
func (r *eofSignal) Read(buf []byte) (int, error) {
|
func (r *eofSignal) Read(buf []byte) (int, error) {
|
||||||
n, err := r.wrapped.Read(buf)
|
n, err := r.wrapped.Read(buf)
|
||||||
if err != nil {
|
r.count -= int64(n)
|
||||||
|
if (err != nil || r.count <= 0) && r.eof != nil {
|
||||||
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||||
|
r.eof = nil
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) {
|
|||||||
// Should not hang.
|
// Should not hang.
|
||||||
p.Disconnect(DiscAlreadyConnected)
|
p.Disconnect(DiscAlreadyConnected)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEOFSignal(t *testing.T) {
|
||||||
|
rb := make([]byte, 10)
|
||||||
|
|
||||||
|
// empty reader
|
||||||
|
eof := make(chan struct{}, 1)
|
||||||
|
sig := &eofSignal{new(bytes.Buffer), 0, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
default:
|
||||||
|
t.Error("EOF chan not signaled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// count before error
|
||||||
|
eof = make(chan struct{}, 1)
|
||||||
|
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 8 || err != nil {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
default:
|
||||||
|
t.Error("EOF chan not signaled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// error before count
|
||||||
|
eof = make(chan struct{}, 1)
|
||||||
|
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 4 || err != nil {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
default:
|
||||||
|
t.Error("EOF chan not signaled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// no signal if neither occurs
|
||||||
|
eof = make(chan struct{}, 1)
|
||||||
|
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 10 || err != nil {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
t.Error("unexpected EOF signal")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user