forked from cerc-io/plugeth
p2p: fix issues found during review
This commit is contained in:
parent
f38052c499
commit
7149191dd9
@ -98,7 +98,7 @@ type byteReader interface {
|
|||||||
io.ByteReader
|
io.ByteReader
|
||||||
}
|
}
|
||||||
|
|
||||||
// readMsg reads a message header.
|
// readMsg reads a message header from r.
|
||||||
func readMsg(r byteReader) (msg Msg, err error) {
|
func readMsg(r byteReader) (msg Msg, err error) {
|
||||||
// read magic and payload size
|
// read magic and payload size
|
||||||
start := make([]byte, 8)
|
start := make([]byte, 8)
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handlers map[string]func() Protocol
|
type Handlers map[string]Protocol
|
||||||
|
|
||||||
type proto struct {
|
type proto struct {
|
||||||
in chan Msg
|
in chan Msg
|
||||||
@ -23,6 +23,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
|
|||||||
if msg.Code >= rw.maxcode {
|
if msg.Code >= rw.maxcode {
|
||||||
return NewPeerError(InvalidMsgCode, "not handled")
|
return NewPeerError(InvalidMsgCode, "not handled")
|
||||||
}
|
}
|
||||||
|
msg.Code += rw.offset
|
||||||
return rw.messenger.writeMsg(msg)
|
return rw.messenger.writeMsg(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,12 +32,13 @@ func (rw *proto) ReadMsg() (Msg, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return msg, io.EOF
|
return msg, io.EOF
|
||||||
}
|
}
|
||||||
|
msg.Code -= rw.offset
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// eofSignal is used to 'lend' the network connection
|
// eofSignal wraps a reader with eof signaling.
|
||||||
// to a protocol. when the protocol's read loop has read the
|
// the eof channel is closed when the wrapped reader
|
||||||
// whole payload, the done channel is closed.
|
// reaches EOF.
|
||||||
type eofSignal struct {
|
type eofSignal struct {
|
||||||
wrapped io.Reader
|
wrapped io.Reader
|
||||||
eof chan struct{}
|
eof chan struct{}
|
||||||
@ -119,7 +121,6 @@ func (m *messenger) readLoop() {
|
|||||||
m.err <- err
|
m.err <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg.Code -= proto.offset
|
|
||||||
if msg.Size <= wholePayloadSize {
|
if msg.Size <= wholePayloadSize {
|
||||||
// optimization: msg is small enough, read all
|
// optimization: msg is small enough, read all
|
||||||
// of it and move on to the next message
|
// of it and move on to the next message
|
||||||
@ -185,11 +186,10 @@ func (m *messenger) setRemoteProtocols(protocols []string) {
|
|||||||
defer m.protocolLock.Unlock()
|
defer m.protocolLock.Unlock()
|
||||||
offset := baseProtocolOffset
|
offset := baseProtocolOffset
|
||||||
for _, name := range protocols {
|
for _, name := range protocols {
|
||||||
protocolFunc, ok := m.handlers[name]
|
inst, ok := m.handlers[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue // not handled
|
continue // not handled
|
||||||
}
|
}
|
||||||
inst := protocolFunc()
|
|
||||||
m.protocols[name] = m.startProto(offset, name, inst)
|
m.protocols[name] = m.startProto(offset, name, inst)
|
||||||
offset += inst.Offset()
|
offset += inst.Offset()
|
||||||
}
|
}
|
||||||
|
@ -11,14 +11,14 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
logpkg "github.com/ethereum/go-ethereum/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel))
|
logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
|
func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
|
||||||
conn1, conn2 := net.Pipe()
|
conn1, conn2 := net.Pipe()
|
||||||
id := NewSimpleClientIdentity("test", "0", "0", "public key")
|
id := NewSimpleClientIdentity("test", "0", "0", "public key")
|
||||||
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
|
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
|
||||||
@ -33,7 +33,7 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error {
|
|||||||
return fmt.Errorf("read error: %v", err)
|
return fmt.Errorf("read error: %v", err)
|
||||||
}
|
}
|
||||||
if msg.Code != handshakeMsg {
|
if msg.Code != handshakeMsg {
|
||||||
return fmt.Errorf("first message should be handshake, got %x", msg.Code)
|
return fmt.Errorf("first message should be handshake, got %d", msg.Code)
|
||||||
}
|
}
|
||||||
if err := msg.Discard(); err != nil {
|
if err := msg.Discard(); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -44,56 +44,102 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error {
|
|||||||
return writeMsg(w, msg)
|
return writeMsg(w, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testMsg struct {
|
type testProtocol struct {
|
||||||
code MsgCode
|
offset MsgCode
|
||||||
data *ethutil.Value
|
f func(MsgReadWriter)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testProto struct {
|
func (p *testProtocol) Offset() MsgCode {
|
||||||
recv chan testMsg
|
return p.offset
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*testProto) Offset() MsgCode { return 5 }
|
func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
|
||||||
|
p.f(rw)
|
||||||
func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error {
|
|
||||||
return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error {
|
|
||||||
logger.Debugf("testprotocol got msg: %d\n", code)
|
|
||||||
tp.recv <- testMsg{code, data}
|
|
||||||
return nil
|
return nil
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRead(t *testing.T) {
|
func TestRead(t *testing.T) {
|
||||||
testProtocol := &testProto{make(chan testMsg)}
|
done := make(chan struct{})
|
||||||
handlers := Handlers{"a": func() Protocol { return testProtocol }}
|
handlers := Handlers{
|
||||||
net, peer, mess := setupMessenger(handlers)
|
"a": &testProtocol{5, func(rw MsgReadWriter) {
|
||||||
bufr := bufio.NewReader(net)
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read error: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Code != 2 {
|
||||||
|
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
|
||||||
|
}
|
||||||
|
data, err := msg.Data()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("data decoding error: %v", err)
|
||||||
|
}
|
||||||
|
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
|
||||||
|
if !reflect.DeepEqual(data.Slice(), expdata) {
|
||||||
|
t.Errorf("incorrect msg data %#v", data.Slice())
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
net, peer, m := testMessenger(handlers)
|
||||||
defer peer.Stop()
|
defer peer.Stop()
|
||||||
|
bufr := bufio.NewReader(net)
|
||||||
if err := performTestHandshake(bufr, net); err != nil {
|
if err := performTestHandshake(bufr, net); err != nil {
|
||||||
t.Fatalf("handshake failed: %v", err)
|
t.Fatalf("handshake failed: %v", err)
|
||||||
}
|
}
|
||||||
|
m.setRemoteProtocols([]string{"a"})
|
||||||
|
|
||||||
mess.setRemoteProtocols([]string{"a"})
|
writeMsg(net, NewMsg(18, 1, "000"))
|
||||||
writeMsg(net, NewMsg(17, uint32(1), "000"))
|
|
||||||
select {
|
select {
|
||||||
case msg := <-testProtocol.recv:
|
case <-done:
|
||||||
if msg.code != 1 {
|
|
||||||
t.Errorf("incorrect msg code %d relayed to protocol", msg.code)
|
|
||||||
}
|
|
||||||
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
|
|
||||||
if !reflect.DeepEqual(msg.data.Slice(), expdata) {
|
|
||||||
t.Errorf("incorrect msg data %#v", msg.data.Slice())
|
|
||||||
}
|
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Errorf("receive timeout")
|
t.Errorf("receive timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteProtoMsg(t *testing.T) {
|
func TestWriteFromProto(t *testing.T) {
|
||||||
handlers := make(Handlers)
|
handlers := Handlers{
|
||||||
testProtocol := &testProto{recv: make(chan testMsg, 1)}
|
"a": &testProtocol{2, func(rw MsgReadWriter) {
|
||||||
handlers["a"] = func() Protocol { return testProtocol }
|
if err := rw.WriteMsg(NewMsg(2)); err == nil {
|
||||||
net, peer, mess := setupMessenger(handlers)
|
t.Error("expected error for out-of-range msg code, got nil")
|
||||||
|
}
|
||||||
|
if err := rw.WriteMsg(NewMsg(1)); err != nil {
|
||||||
|
t.Errorf("write error: %v", err)
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
net, peer, mess := testMessenger(handlers)
|
||||||
|
defer peer.Stop()
|
||||||
|
bufr := bufio.NewReader(net)
|
||||||
|
if err := performTestHandshake(bufr, net); err != nil {
|
||||||
|
t.Fatalf("handshake failed: %v", err)
|
||||||
|
}
|
||||||
|
mess.setRemoteProtocols([]string{"a"})
|
||||||
|
|
||||||
|
msg, err := readMsg(bufr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read error: %v")
|
||||||
|
}
|
||||||
|
if msg.Code != 17 {
|
||||||
|
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
|
||||||
|
for {
|
||||||
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = msg.Discard(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
|
||||||
|
func TestMessengerWriteProtoMsg(t *testing.T) {
|
||||||
|
handlers := Handlers{"a": discardProto}
|
||||||
|
net, peer, mess := testMessenger(handlers)
|
||||||
defer peer.Stop()
|
defer peer.Stop()
|
||||||
bufr := bufio.NewReader(net)
|
bufr := bufio.NewReader(net)
|
||||||
if err := performTestHandshake(bufr, net); err != nil {
|
if err := performTestHandshake(bufr, net); err != nil {
|
||||||
@ -120,13 +166,13 @@ func TestWriteProtoMsg(t *testing.T) {
|
|||||||
read <- msg
|
read <- msg
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil {
|
if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case msg := <-read:
|
case msg := <-read:
|
||||||
if msg.Code != 19 {
|
if msg.Code != 16 {
|
||||||
t.Errorf("wrong code, got %d, expected %d", msg.Code, 19)
|
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
|
||||||
}
|
}
|
||||||
msg.Discard()
|
msg.Discard()
|
||||||
case err := <-readerr:
|
case err := <-readerr:
|
||||||
@ -135,7 +181,7 @@ func TestWriteProtoMsg(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPulse(t *testing.T) {
|
func TestPulse(t *testing.T) {
|
||||||
net, peer, _ := setupMessenger(nil)
|
net, peer, _ := testMessenger(nil)
|
||||||
defer peer.Stop()
|
defer peer.Stop()
|
||||||
bufr := bufio.NewReader(net)
|
bufr := bufio.NewReader(net)
|
||||||
if err := performTestHandshake(bufr, net); err != nil {
|
if err := performTestHandshake(bufr, net); err != nil {
|
||||||
@ -149,7 +195,7 @@ func TestPulse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
after := time.Now()
|
after := time.Now()
|
||||||
if msg.Code != pingMsg {
|
if msg.Code != pingMsg {
|
||||||
t.Errorf("expected ping message, got %x", msg.Code)
|
t.Errorf("expected ping message, got %d", msg.Code)
|
||||||
}
|
}
|
||||||
if d := after.Sub(before); d < pingTimeout {
|
if d := after.Sub(before); d < pingTimeout {
|
||||||
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
|
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
|
||||||
|
@ -143,9 +143,6 @@ func (d DiscReason) String() string {
|
|||||||
return discReasonToString[d]
|
return discReasonToString[d]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bp *baseProtocol) Ping() {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *baseProtocol) Offset() MsgCode {
|
func (bp *baseProtocol) Offset() MsgCode {
|
||||||
return baseProtocolOffset
|
return baseProtocolOffset
|
||||||
}
|
}
|
||||||
@ -287,7 +284,7 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
|
|||||||
|
|
||||||
// self connect detection
|
// self connect detection
|
||||||
if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
|
if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
|
||||||
return NewPeerError(PubkeyForbidden, "not allowed to connect to bp")
|
return NewPeerError(PubkeyForbidden, "not allowed to connect to self")
|
||||||
}
|
}
|
||||||
|
|
||||||
// register pubkey on server. this also sets the pubkey on the peer (need lock)
|
// register pubkey on server. this also sets the pubkey on the peer (need lock)
|
||||||
|
Loading…
Reference in New Issue
Block a user