p2p: use package rlp to encode messages

Message encoding functions have been renamed to catch any uses.
The switch to the new encoder can cause subtle incompatibilities.
If there are any users outside of our tree, they will at least be
alerted that there was a change.

NewMsg no longer exists. The replacements for EncodeMsg are called
Send and SendItems.
This commit is contained in:
Felix Lange 2015-03-19 15:11:02 +01:00
parent 4811f460e7
commit 5ba51594c7
8 changed files with 64 additions and 68 deletions

View File

@ -92,7 +92,7 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (
return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
}
// TODO: validate that handshake node ID matches
if err := writeProtocolHandshake(rw, our); err != nil {
if err := Send(rw, handshakeMsg, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err)
}
return &conn{rw, rhs}, nil
@ -106,7 +106,7 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake,
// Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW(fd, secrets)
if err := writeProtocolHandshake(rw, our); err != nil {
if err := Send(rw, handshakeMsg, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err)
}
rhs, err := readProtocolHandshake(rw, our)
@ -398,10 +398,6 @@ func xor(one, other []byte) (xor []byte) {
return xor
}
func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
}
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
// read and handle remote handshake
msg, err := r.ReadMsg()

View File

@ -11,7 +11,6 @@ import (
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
)
@ -28,13 +27,7 @@ type Msg struct {
Payload io.Reader
}
// NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code uint64, params ...interface{}) Msg {
p := bytes.NewReader(common.Encode(params))
return Msg{Code: code, Size: uint32(p.Len()), Payload: p}
}
// Decode parse the RLP content of a message into
// Decode parses the RLP content of a message into
// the given value, which must be a pointer.
//
// For the decoding rules, please see package rlp.
@ -76,10 +69,27 @@ type MsgReadWriter interface {
MsgWriter
}
// EncodeMsg writes an RLP-encoded message with the given code and
// data elements.
func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...))
// Send writes an RLP-encoded message with the given code.
// data should encode as an RLP list.
func Send(w MsgWriter, msgcode uint64, data interface{}) error {
size, r, err := rlp.EncodeToReader(data)
if err != nil {
return err
}
return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r})
}
// SendItems writes an RLP with the given code and data elements.
// For a call such as:
//
// SendItems(w, code, e1, e2, e3)
//
// the message payload will be an RLP list containing the items:
//
// [e1, e2, e3]
//
func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error {
return Send(w, msgcode, elems)
}
// netWrapper wraps a MsgReadWriter with locks around

View File

@ -5,33 +5,17 @@ import (
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"runtime"
"strings"
"testing"
"time"
)
func TestNewMsg(t *testing.T) {
msg := NewMsg(3, 1, "000")
if msg.Code != 3 {
t.Errorf("incorrect code %d, want %d", msg.Code)
}
expect := unhex("c50183303030")
if msg.Size != uint32(len(expect)) {
t.Errorf("incorrect size %d, want %d", msg.Size, len(expect))
}
pl, _ := ioutil.ReadAll(msg.Payload)
if !bytes.Equal(pl, expect) {
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
}
}
func ExampleMsgPipe() {
rw1, rw2 := MsgPipe()
go func() {
EncodeMsg(rw1, 8, []byte{0, 0})
EncodeMsg(rw1, 5, []byte{1, 1})
Send(rw1, 8, [][]byte{{0, 0}})
Send(rw1, 5, [][]byte{{1, 1}})
rw1.Close()
}()
@ -40,7 +24,7 @@ func ExampleMsgPipe() {
if err != nil {
break
}
var data [1][]byte
var data [][]byte
msg.Decode(&data)
fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
}
@ -55,7 +39,7 @@ loop:
rw1, rw2 := MsgPipe()
done := make(chan struct{})
go func() {
if err := EncodeMsg(rw1, 1); err == nil {
if err := SendItems(rw1, 1); err == nil {
t.Error("EncodeMsg returned nil error")
} else if err != ErrPipeClosed {
t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)

View File

@ -132,7 +132,7 @@ loop:
select {
case <-ping.C:
go func() {
if err := EncodeMsg(p.rw, pingMsg, nil); err != nil {
if err := SendItems(p.rw, pingMsg); err != nil {
p.protoErr <- err
return
}
@ -161,7 +161,7 @@ loop:
func (p *Peer) politeDisconnect(reason DiscReason) {
done := make(chan struct{})
go func() {
EncodeMsg(p.rw, discMsg, uint(reason))
SendItems(p.rw, discMsg, uint(reason))
// Wait for the other side to close the connection.
// Discard any data that they send until then.
io.Copy(ioutil.Discard, p.conn)
@ -192,7 +192,7 @@ func (p *Peer) handle(msg Msg) error {
switch {
case msg.Code == pingMsg:
msg.Discard()
go EncodeMsg(p.rw, pongMsg)
go SendItems(p.rw, pongMsg)
case msg.Code == discMsg:
var reason [1]DiscReason
// no need to discard or for error checking, we'll close the

View File

@ -4,13 +4,10 @@ import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/rlp"
)
var discard = Protocol{
@ -55,13 +52,13 @@ func TestPeerProtoReadMsg(t *testing.T) {
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
if err := expectMsg(rw, 2, []uint{1}); err != nil {
if err := ExpectMsg(rw, 2, []uint{1}); err != nil {
t.Error(err)
}
if err := expectMsg(rw, 3, []uint{2}); err != nil {
if err := ExpectMsg(rw, 3, []uint{2}); err != nil {
t.Error(err)
}
if err := expectMsg(rw, 4, []uint{3}); err != nil {
if err := ExpectMsg(rw, 4, []uint{3}); err != nil {
t.Error(err)
}
close(done)
@ -72,9 +69,9 @@ func TestPeerProtoReadMsg(t *testing.T) {
closer, rw, _, errc := testPeer([]Protocol{proto})
defer closer.Close()
EncodeMsg(rw, baseProtocolLength+2, 1)
EncodeMsg(rw, baseProtocolLength+3, 2)
EncodeMsg(rw, baseProtocolLength+4, 3)
Send(rw, baseProtocolLength+2, []uint{1})
Send(rw, baseProtocolLength+3, []uint{2})
Send(rw, baseProtocolLength+4, []uint{3})
select {
case <-done:
@ -92,10 +89,10 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
Name: "a",
Length: 2,
Run: func(peer *Peer, rw MsgReadWriter) error {
if err := EncodeMsg(rw, 2); err == nil {
if err := SendItems(rw, 2); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
}
if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
if err := SendItems(rw, 1, "foo", "bar"); err != nil {
t.Errorf("write error: %v", err)
}
return nil
@ -104,7 +101,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
closer, rw, _, _ := testPeer([]Protocol{proto})
defer closer.Close()
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
if err := ExpectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
t.Error(err)
}
}
@ -115,11 +112,15 @@ func TestPeerWriteForBroadcast(t *testing.T) {
closer, rw, peer, peerErr := testPeer([]Protocol{discard})
defer closer.Close()
emptymsg := func(code uint64) Msg {
return Msg{Code: code, Size: 0, Payload: bytes.NewReader(nil)}
}
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
if err := peer.writeProtoMsg("b", emptymsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
if err := peer.writeProtoMsg("discard", emptymsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
@ -128,14 +129,14 @@ func TestPeerWriteForBroadcast(t *testing.T) {
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
if err := expectMsg(rw, 16, nil); err != nil {
if err := ExpectMsg(rw, 16, nil); err != nil {
t.Error(err)
}
close(read)
}()
// test successful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
if err := peer.writeProtoMsg("discard", emptymsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
@ -150,10 +151,10 @@ func TestPeerPing(t *testing.T) {
closer, rw, _, _ := testPeer(nil)
defer closer.Close()
if err := EncodeMsg(rw, pingMsg); err != nil {
if err := SendItems(rw, pingMsg); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, pongMsg, nil); err != nil {
if err := ExpectMsg(rw, pongMsg, nil); err != nil {
t.Error(err)
}
}
@ -163,10 +164,10 @@ func TestPeerDisconnect(t *testing.T) {
closer, rw, _, disc := testPeer(nil)
defer closer.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
if err := SendItems(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
if err := ExpectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
t.Error(err)
}
closer.Close() // make test end faster

View File

@ -30,7 +30,7 @@ ba628a4ba590cb43f7848f41c4382885
`)
// Check WriteMsg. This puts a message into the buffer.
if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil {
if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil {
t.Fatalf("WriteMsg error: %v", err)
}
written := buf.Bytes()
@ -102,7 +102,7 @@ func TestRlpxFrameRW(t *testing.T) {
for i := 0; i < 10; i++ {
// write message into conn buffer
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
err := EncodeMsg(rw1, uint64(i), wmsg...)
err := Send(rw1, uint64(i), wmsg)
if err != nil {
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
}

View File

@ -9,10 +9,10 @@ import (
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/rlp"
)
const (
@ -129,10 +129,14 @@ func (srv *Server) SuggestPeer(n *discover.Node) {
// Broadcast sends an RLP-encoded message to all connected peers.
// This method is deprecated and will be removed later.
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
func (srv *Server) Broadcast(protocol string, code uint64, data interface{}) error {
var payload []byte
if data != nil {
payload = common.Encode(data)
var err error
payload, err = rlp.EncodeToBytes(data)
if err != nil {
return err
}
}
srv.lock.RLock()
defer srv.lock.RUnlock()
@ -146,6 +150,7 @@ func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{})
peer.writeProtoMsg(protocol, msg)
}
}
return nil
}
// Start starts running the server.

View File

@ -149,7 +149,7 @@ func TestServerBroadcast(t *testing.T) {
connected.Wait()
// broadcast one message
srv.Broadcast("discard", 0, "foo")
srv.Broadcast("discard", 0, []string{"foo"})
golden := unhex("66e94d166f0a2c3b884cfa59ca34")
// check that the message has been written everywhere