p2p: use package rlp for baseProtocol
This commit is contained in:
parent
c1fca72552
commit
6049fcd52a
@ -41,14 +41,22 @@ func encodePayload(params ...interface{}) []byte {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// Data returns the decoded RLP payload items in a message.
|
||||
func (msg Msg) Data() (*ethutil.Value, error) {
|
||||
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
|
||||
// Value returns the decoded RLP payload items in a message.
|
||||
func (msg Msg) Value() (*ethutil.Value, error) {
|
||||
var v []interface{}
|
||||
err := s.Decode(&v)
|
||||
err := msg.Decode(&v)
|
||||
return ethutil.NewValue(v), err
|
||||
}
|
||||
|
||||
// Decode parse the RLP content of a message into
|
||||
// the given value, which must be a pointer.
|
||||
//
|
||||
// For the decoding rules, please see package rlp.
|
||||
func (msg Msg) Decode(val interface{}) error {
|
||||
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
|
||||
return s.Decode(val)
|
||||
}
|
||||
|
||||
// Discard reads any remaining payload data into a black hole.
|
||||
func (msg Msg) Discard() error {
|
||||
_, err := io.Copy(ioutil.Discard, msg.Payload)
|
||||
@ -91,7 +99,7 @@ func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Valu
|
||||
if msg.Size > maxsize {
|
||||
return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
|
||||
}
|
||||
value, err := msg.Data()
|
||||
value, err := msg.Value()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func TestEncodeDecodeMsg(t *testing.T) {
|
||||
if decmsg.Size != 5 {
|
||||
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
|
||||
}
|
||||
data, err := decmsg.Data()
|
||||
data, err := decmsg.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("first payload item decode error: %v", err)
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
||||
if msg.Code != 2 {
|
||||
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
|
||||
}
|
||||
data, err := msg.Data()
|
||||
data, err := msg.Value()
|
||||
if err != nil {
|
||||
t.Errorf("data decoding error: %v", err)
|
||||
}
|
||||
|
107
p2p/protocol.go
107
p2p/protocol.go
@ -2,7 +2,6 @@ package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil"
|
||||
@ -90,30 +89,18 @@ type baseProtocol struct {
|
||||
|
||||
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
|
||||
bp := &baseProtocol{rw, peer}
|
||||
|
||||
// do handshake
|
||||
if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
|
||||
if err := bp.doHandshake(rw); err != nil {
|
||||
return err
|
||||
}
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
|
||||
}
|
||||
data, err := msg.Data()
|
||||
if err != nil {
|
||||
return newPeerError(errInvalidMsg, "%v", err)
|
||||
}
|
||||
if err := bp.handleHandshake(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// run main loop
|
||||
quit := make(chan error, 1)
|
||||
go func() {
|
||||
quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle)
|
||||
for {
|
||||
if err := bp.handle(rw); err != nil {
|
||||
quit <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
return bp.loop(quit)
|
||||
}
|
||||
@ -151,13 +138,27 @@ func (bp *baseProtocol) loop(quit <-chan error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
|
||||
switch code {
|
||||
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return newPeerError(errMisc, "message too big")
|
||||
}
|
||||
// make sure that the payload has been fully consumed
|
||||
defer msg.Discard()
|
||||
|
||||
switch msg.Code {
|
||||
case handshakeMsg:
|
||||
return newPeerError(errProtocolBreach, "extra handshake received")
|
||||
|
||||
case discMsg:
|
||||
bp.peer.Disconnect(DiscReason(data.Get(0).Uint()))
|
||||
var reason DiscReason
|
||||
if err := msg.Decode(&reason); err != nil {
|
||||
return err
|
||||
}
|
||||
bp.peer.Disconnect(reason)
|
||||
return nil
|
||||
|
||||
case pingMsg:
|
||||
@ -178,35 +179,45 @@ func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
|
||||
}
|
||||
|
||||
case peersMsg:
|
||||
bp.handlePeers(data)
|
||||
var peers []*peerAddr
|
||||
if err := msg.Decode(&peers); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, addr := range peers {
|
||||
bp.peer.Debugf("received peer suggestion: %v", addr)
|
||||
bp.peer.newPeerAddr <- addr
|
||||
}
|
||||
|
||||
default:
|
||||
return newPeerError(errInvalidMsgCode, "unknown message code %v", code)
|
||||
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
|
||||
it := data.NewIterator()
|
||||
for it.Next() {
|
||||
addr := &peerAddr{
|
||||
IP: net.IP(it.Value().Get(0).Bytes()),
|
||||
Port: it.Value().Get(1).Uint(),
|
||||
Pubkey: it.Value().Get(2).Bytes(),
|
||||
}
|
||||
bp.peer.Debugf("received peer suggestion: %v", addr)
|
||||
bp.peer.newPeerAddr <- addr
|
||||
func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
|
||||
// send our handshake
|
||||
if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
|
||||
hs := handshake{
|
||||
Version: c.Get(0).Uint(),
|
||||
ID: c.Get(1).Str(),
|
||||
Caps: nil, // decoded below
|
||||
ListenPort: c.Get(3).Uint(),
|
||||
NodeID: c.Get(4).Bytes(),
|
||||
// read and handle remote handshake
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return newPeerError(errMisc, "message too big")
|
||||
}
|
||||
|
||||
var hs handshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// validate handshake info
|
||||
if hs.Version != baseProtocolVersion {
|
||||
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
|
||||
baseProtocolVersion, hs.Version)
|
||||
@ -228,14 +239,8 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
|
||||
if err := bp.peer.pubkeyHook(pa); err != nil {
|
||||
return newPeerError(errPubkeyForbidden, "%v", err)
|
||||
}
|
||||
capsIt := c.Get(2).NewIterator()
|
||||
for capsIt.Next() {
|
||||
cap := capsIt.Value()
|
||||
name := cap.Get(0).Str()
|
||||
if name != "" {
|
||||
hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove Caps with empty name
|
||||
|
||||
var addr *peerAddr
|
||||
if hs.ListenPort != 0 {
|
||||
|
Loading…
Reference in New Issue
Block a user