p2p: use package rlp for baseProtocol

This commit is contained in:
Felix Lange 2014-11-25 12:25:31 +01:00
parent c1fca72552
commit 6049fcd52a
4 changed files with 71 additions and 58 deletions

View File

@ -41,14 +41,22 @@ func encodePayload(params ...interface{}) []byte {
return buf.Bytes() return buf.Bytes()
} }
// Data returns the decoded RLP payload items in a message. // Value returns the decoded RLP payload items in a message.
func (msg Msg) Data() (*ethutil.Value, error) { func (msg Msg) Value() (*ethutil.Value, error) {
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
var v []interface{} var v []interface{}
err := s.Decode(&v) err := msg.Decode(&v)
return ethutil.NewValue(v), err return ethutil.NewValue(v), err
} }
// Decode parse the RLP content of a message into
// the given value, which must be a pointer.
//
// For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error {
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
return s.Decode(val)
}
// Discard reads any remaining payload data into a black hole. // Discard reads any remaining payload data into a black hole.
func (msg Msg) Discard() error { func (msg Msg) Discard() error {
_, err := io.Copy(ioutil.Discard, msg.Payload) _, err := io.Copy(ioutil.Discard, msg.Payload)
@ -91,7 +99,7 @@ func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Valu
if msg.Size > maxsize { if msg.Size > maxsize {
return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize) return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
} }
value, err := msg.Data() value, err := msg.Value()
if err != nil { if err != nil {
return err return err
} }

View File

@ -42,7 +42,7 @@ func TestEncodeDecodeMsg(t *testing.T) {
if decmsg.Size != 5 { if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
} }
data, err := decmsg.Data() data, err := decmsg.Value()
if err != nil { if err != nil {
t.Fatalf("first payload item decode error: %v", err) t.Fatalf("first payload item decode error: %v", err)
} }

View File

@ -53,7 +53,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
if msg.Code != 2 { if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
} }
data, err := msg.Data() data, err := msg.Value()
if err != nil { if err != nil {
t.Errorf("data decoding error: %v", err) t.Errorf("data decoding error: %v", err)
} }

View File

@ -2,7 +2,6 @@ package p2p
import ( import (
"bytes" "bytes"
"net"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
@ -90,30 +89,18 @@ type baseProtocol struct {
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer} bp := &baseProtocol{rw, peer}
if err := bp.doHandshake(rw); err != nil {
// do handshake
if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
return err return err
} }
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
data, err := msg.Data()
if err != nil {
return newPeerError(errInvalidMsg, "%v", err)
}
if err := bp.handleHandshake(data); err != nil {
return err
}
// run main loop // run main loop
quit := make(chan error, 1) quit := make(chan error, 1)
go func() { go func() {
quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle) for {
if err := bp.handle(rw); err != nil {
quit <- err
break
}
}
}() }()
return bp.loop(quit) return bp.loop(quit)
} }
@ -151,13 +138,27 @@ func (bp *baseProtocol) loop(quit <-chan error) error {
return err return err
} }
func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error { func (bp *baseProtocol) handle(rw MsgReadWriter) error {
switch code { msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
// make sure that the payload has been fully consumed
defer msg.Discard()
switch msg.Code {
case handshakeMsg: case handshakeMsg:
return newPeerError(errProtocolBreach, "extra handshake received") return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg: case discMsg:
bp.peer.Disconnect(DiscReason(data.Get(0).Uint())) var reason DiscReason
if err := msg.Decode(&reason); err != nil {
return err
}
bp.peer.Disconnect(reason)
return nil return nil
case pingMsg: case pingMsg:
@ -178,35 +179,45 @@ func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
} }
case peersMsg: case peersMsg:
bp.handlePeers(data) var peers []*peerAddr
if err := msg.Decode(&peers); err != nil {
return err
}
for _, addr := range peers {
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
}
default: default:
return newPeerError(errInvalidMsgCode, "unknown message code %v", code) return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
} }
return nil return nil
} }
func (bp *baseProtocol) handlePeers(data *ethutil.Value) { func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
it := data.NewIterator() // send our handshake
for it.Next() { if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
addr := &peerAddr{ return err
IP: net.IP(it.Value().Get(0).Bytes()),
Port: it.Value().Get(1).Uint(),
Pubkey: it.Value().Get(2).Bytes(),
}
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
} }
}
func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error { // read and handle remote handshake
hs := handshake{ msg, err := rw.ReadMsg()
Version: c.Get(0).Uint(), if err != nil {
ID: c.Get(1).Str(), return err
Caps: nil, // decoded below
ListenPort: c.Get(3).Uint(),
NodeID: c.Get(4).Bytes(),
} }
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion { if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
baseProtocolVersion, hs.Version) baseProtocolVersion, hs.Version)
@ -228,14 +239,8 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
if err := bp.peer.pubkeyHook(pa); err != nil { if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err) return newPeerError(errPubkeyForbidden, "%v", err)
} }
capsIt := c.Get(2).NewIterator()
for capsIt.Next() { // TODO: remove Caps with empty name
cap := capsIt.Value()
name := cap.Get(0).Str()
if name != "" {
hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
}
}
var addr *peerAddr var addr *peerAddr
if hs.ListenPort != 0 { if hs.ListenPort != 0 {