p2p: fix issues found during review

This commit is contained in:
Felix Lange 2014-11-05 01:28:46 +01:00
parent f38052c499
commit 7149191dd9
4 changed files with 96 additions and 53 deletions

View File

@ -98,7 +98,7 @@ type byteReader interface {
io.ByteReader io.ByteReader
} }
// readMsg reads a message header. // readMsg reads a message header from r.
func readMsg(r byteReader) (msg Msg, err error) { func readMsg(r byteReader) (msg Msg, err error) {
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)

View File

@ -11,7 +11,7 @@ import (
"time" "time"
) )
type Handlers map[string]func() Protocol type Handlers map[string]Protocol
type proto struct { type proto struct {
in chan Msg in chan Msg
@ -23,6 +23,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode { if msg.Code >= rw.maxcode {
return NewPeerError(InvalidMsgCode, "not handled") return NewPeerError(InvalidMsgCode, "not handled")
} }
msg.Code += rw.offset
return rw.messenger.writeMsg(msg) return rw.messenger.writeMsg(msg)
} }
@ -31,12 +32,13 @@ func (rw *proto) ReadMsg() (Msg, error) {
if !ok { if !ok {
return msg, io.EOF return msg, io.EOF
} }
msg.Code -= rw.offset
return msg, nil return msg, nil
} }
// eofSignal is used to 'lend' the network connection // eofSignal wraps a reader with eof signaling.
// to a protocol. when the protocol's read loop has read the // the eof channel is closed when the wrapped reader
// whole payload, the done channel is closed. // reaches EOF.
type eofSignal struct { type eofSignal struct {
wrapped io.Reader wrapped io.Reader
eof chan struct{} eof chan struct{}
@ -119,7 +121,6 @@ func (m *messenger) readLoop() {
m.err <- err m.err <- err
return return
} }
msg.Code -= proto.offset
if msg.Size <= wholePayloadSize { if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all // optimization: msg is small enough, read all
// of it and move on to the next message // of it and move on to the next message
@ -185,11 +186,10 @@ func (m *messenger) setRemoteProtocols(protocols []string) {
defer m.protocolLock.Unlock() defer m.protocolLock.Unlock()
offset := baseProtocolOffset offset := baseProtocolOffset
for _, name := range protocols { for _, name := range protocols {
protocolFunc, ok := m.handlers[name] inst, ok := m.handlers[name]
if !ok { if !ok {
continue // not handled continue // not handled
} }
inst := protocolFunc()
m.protocols[name] = m.startProto(offset, name, inst) m.protocols[name] = m.startProto(offset, name, inst)
offset += inst.Offset() offset += inst.Offset()
} }

View File

@ -11,14 +11,14 @@ import (
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil" logpkg "github.com/ethereum/go-ethereum/logger"
) )
func init() { func init() {
ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel)) logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
} }
func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) { func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key") id := NewSimpleClientIdentity("test", "0", "0", "public key")
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist()) server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
@ -33,7 +33,7 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error {
return fmt.Errorf("read error: %v", err) return fmt.Errorf("read error: %v", err)
} }
if msg.Code != handshakeMsg { if msg.Code != handshakeMsg {
return fmt.Errorf("first message should be handshake, got %x", msg.Code) return fmt.Errorf("first message should be handshake, got %d", msg.Code)
} }
if err := msg.Discard(); err != nil { if err := msg.Discard(); err != nil {
return err return err
@ -44,56 +44,102 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error {
return writeMsg(w, msg) return writeMsg(w, msg)
} }
type testMsg struct { type testProtocol struct {
code MsgCode offset MsgCode
data *ethutil.Value f func(MsgReadWriter)
} }
type testProto struct { func (p *testProtocol) Offset() MsgCode {
recv chan testMsg return p.offset
} }
func (*testProto) Offset() MsgCode { return 5 } func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
p.f(rw)
func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error {
return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error {
logger.Debugf("testprotocol got msg: %d\n", code)
tp.recv <- testMsg{code, data}
return nil return nil
})
} }
func TestRead(t *testing.T) { func TestRead(t *testing.T) {
testProtocol := &testProto{make(chan testMsg)} done := make(chan struct{})
handlers := Handlers{"a": func() Protocol { return testProtocol }} handlers := Handlers{
net, peer, mess := setupMessenger(handlers) "a": &testProtocol{5, func(rw MsgReadWriter) {
bufr := bufio.NewReader(net) msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
}
data, err := msg.Data()
if err != nil {
t.Errorf("data decoding error: %v", err)
}
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", data.Slice())
}
close(done)
}},
}
net, peer, m := testMessenger(handlers)
defer peer.Stop() defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil { if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err) t.Fatalf("handshake failed: %v", err)
} }
m.setRemoteProtocols([]string{"a"})
mess.setRemoteProtocols([]string{"a"}) writeMsg(net, NewMsg(18, 1, "000"))
writeMsg(net, NewMsg(17, uint32(1), "000"))
select { select {
case msg := <-testProtocol.recv: case <-done:
if msg.code != 1 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.code)
}
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(msg.data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", msg.data.Slice())
}
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Errorf("receive timeout") t.Errorf("receive timeout")
} }
} }
func TestWriteProtoMsg(t *testing.T) { func TestWriteFromProto(t *testing.T) {
handlers := make(Handlers) handlers := Handlers{
testProtocol := &testProto{recv: make(chan testMsg, 1)} "a": &testProtocol{2, func(rw MsgReadWriter) {
handlers["a"] = func() Protocol { return testProtocol } if err := rw.WriteMsg(NewMsg(2)); err == nil {
net, peer, mess := setupMessenger(handlers) t.Error("expected error for out-of-range msg code, got nil")
}
if err := rw.WriteMsg(NewMsg(1)); err != nil {
t.Errorf("write error: %v", err)
}
}},
}
net, peer, mess := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v")
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
for {
msg, err := rw.ReadMsg()
if err != nil {
return
}
if err = msg.Discard(); err != nil {
return
}
}
}}
func TestMessengerWriteProtoMsg(t *testing.T) {
handlers := Handlers{"a": discardProto}
net, peer, mess := testMessenger(handlers)
defer peer.Stop() defer peer.Stop()
bufr := bufio.NewReader(net) bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil { if err := performTestHandshake(bufr, net); err != nil {
@ -120,13 +166,13 @@ func TestWriteProtoMsg(t *testing.T) {
read <- msg read <- msg
} }
}() }()
if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil { if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err) t.Errorf("expect no error for known protocol: %v", err)
} }
select { select {
case msg := <-read: case msg := <-read:
if msg.Code != 19 { if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 19) t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
} }
msg.Discard() msg.Discard()
case err := <-readerr: case err := <-readerr:
@ -135,7 +181,7 @@ func TestWriteProtoMsg(t *testing.T) {
} }
func TestPulse(t *testing.T) { func TestPulse(t *testing.T) {
net, peer, _ := setupMessenger(nil) net, peer, _ := testMessenger(nil)
defer peer.Stop() defer peer.Stop()
bufr := bufio.NewReader(net) bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil { if err := performTestHandshake(bufr, net); err != nil {
@ -149,7 +195,7 @@ func TestPulse(t *testing.T) {
} }
after := time.Now() after := time.Now()
if msg.Code != pingMsg { if msg.Code != pingMsg {
t.Errorf("expected ping message, got %x", msg.Code) t.Errorf("expected ping message, got %d", msg.Code)
} }
if d := after.Sub(before); d < pingTimeout { if d := after.Sub(before); d < pingTimeout {
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout) t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)

View File

@ -143,9 +143,6 @@ func (d DiscReason) String() string {
return discReasonToString[d] return discReasonToString[d]
} }
func (bp *baseProtocol) Ping() {
}
func (bp *baseProtocol) Offset() MsgCode { func (bp *baseProtocol) Offset() MsgCode {
return baseProtocolOffset return baseProtocolOffset
} }
@ -287,7 +284,7 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
// self connect detection // self connect detection
if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 { if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
return NewPeerError(PubkeyForbidden, "not allowed to connect to bp") return NewPeerError(PubkeyForbidden, "not allowed to connect to self")
} }
// register pubkey on server. this also sets the pubkey on the peer (need lock) // register pubkey on server. this also sets the pubkey on the peer (need lock)