p2p: use package rlp to encode messages

Message encoding functions have been renamed to catch any uses.
The switch to the new encoder can cause subtle incompatibilities.
If there are any users outside of our tree, they will at least be
alerted that there was a change.

NewMsg no longer exists. The replacements for EncodeMsg are called
Send and SendItems.
This commit is contained in:
Felix Lange 2015-03-19 15:11:02 +01:00
parent 4811f460e7
commit 5ba51594c7
8 changed files with 64 additions and 68 deletions

View File

@ -92,7 +92,7 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (
return nil, errors.New("node ID in protocol handshake does not match encryption handshake") return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
} }
// TODO: validate that handshake node ID matches // TODO: validate that handshake node ID matches
if err := writeProtocolHandshake(rw, our); err != nil { if err := Send(rw, handshakeMsg, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err) return nil, fmt.Errorf("protocol write error: %v", err)
} }
return &conn{rw, rhs}, nil return &conn{rw, rhs}, nil
@ -106,7 +106,7 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake,
// Run the protocol handshake using authenticated messages. // Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW(fd, secrets) rw := newRlpxFrameRW(fd, secrets)
if err := writeProtocolHandshake(rw, our); err != nil { if err := Send(rw, handshakeMsg, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err) return nil, fmt.Errorf("protocol write error: %v", err)
} }
rhs, err := readProtocolHandshake(rw, our) rhs, err := readProtocolHandshake(rw, our)
@ -398,10 +398,6 @@ func xor(one, other []byte) (xor []byte) {
return xor return xor
} }
func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
}
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) { func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
// read and handle remote handshake // read and handle remote handshake
msg, err := r.ReadMsg() msg, err := r.ReadMsg()

View File

@ -11,7 +11,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -28,13 +27,7 @@ type Msg struct {
Payload io.Reader Payload io.Reader
} }
// NewMsg creates an RLP-encoded message with the given code. // Decode parses the RLP content of a message into
func NewMsg(code uint64, params ...interface{}) Msg {
p := bytes.NewReader(common.Encode(params))
return Msg{Code: code, Size: uint32(p.Len()), Payload: p}
}
// Decode parse the RLP content of a message into
// the given value, which must be a pointer. // the given value, which must be a pointer.
// //
// For the decoding rules, please see package rlp. // For the decoding rules, please see package rlp.
@ -76,10 +69,27 @@ type MsgReadWriter interface {
MsgWriter MsgWriter
} }
// EncodeMsg writes an RLP-encoded message with the given code and // Send writes an RLP-encoded message with the given code.
// data elements. // data should encode as an RLP list.
func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error { func Send(w MsgWriter, msgcode uint64, data interface{}) error {
return w.WriteMsg(NewMsg(code, data...)) size, r, err := rlp.EncodeToReader(data)
if err != nil {
return err
}
return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r})
}
// SendItems writes an RLP with the given code and data elements.
// For a call such as:
//
// SendItems(w, code, e1, e2, e3)
//
// the message payload will be an RLP list containing the items:
//
// [e1, e2, e3]
//
func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error {
return Send(w, msgcode, elems)
} }
// netWrapper wraps a MsgReadWriter with locks around // netWrapper wraps a MsgReadWriter with locks around

View File

@ -5,33 +5,17 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
) )
func TestNewMsg(t *testing.T) {
msg := NewMsg(3, 1, "000")
if msg.Code != 3 {
t.Errorf("incorrect code %d, want %d", msg.Code)
}
expect := unhex("c50183303030")
if msg.Size != uint32(len(expect)) {
t.Errorf("incorrect size %d, want %d", msg.Size, len(expect))
}
pl, _ := ioutil.ReadAll(msg.Payload)
if !bytes.Equal(pl, expect) {
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
}
}
func ExampleMsgPipe() { func ExampleMsgPipe() {
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
go func() { go func() {
EncodeMsg(rw1, 8, []byte{0, 0}) Send(rw1, 8, [][]byte{{0, 0}})
EncodeMsg(rw1, 5, []byte{1, 1}) Send(rw1, 5, [][]byte{{1, 1}})
rw1.Close() rw1.Close()
}() }()
@ -40,7 +24,7 @@ func ExampleMsgPipe() {
if err != nil { if err != nil {
break break
} }
var data [1][]byte var data [][]byte
msg.Decode(&data) msg.Decode(&data)
fmt.Printf("msg: %d, %x\n", msg.Code, data[0]) fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
} }
@ -55,7 +39,7 @@ loop:
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
if err := EncodeMsg(rw1, 1); err == nil { if err := SendItems(rw1, 1); err == nil {
t.Error("EncodeMsg returned nil error") t.Error("EncodeMsg returned nil error")
} else if err != ErrPipeClosed { } else if err != ErrPipeClosed {
t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed) t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)

View File

@ -132,7 +132,7 @@ loop:
select { select {
case <-ping.C: case <-ping.C:
go func() { go func() {
if err := EncodeMsg(p.rw, pingMsg, nil); err != nil { if err := SendItems(p.rw, pingMsg); err != nil {
p.protoErr <- err p.protoErr <- err
return return
} }
@ -161,7 +161,7 @@ loop:
func (p *Peer) politeDisconnect(reason DiscReason) { func (p *Peer) politeDisconnect(reason DiscReason) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
EncodeMsg(p.rw, discMsg, uint(reason)) SendItems(p.rw, discMsg, uint(reason))
// Wait for the other side to close the connection. // Wait for the other side to close the connection.
// Discard any data that they send until then. // Discard any data that they send until then.
io.Copy(ioutil.Discard, p.conn) io.Copy(ioutil.Discard, p.conn)
@ -192,7 +192,7 @@ func (p *Peer) handle(msg Msg) error {
switch { switch {
case msg.Code == pingMsg: case msg.Code == pingMsg:
msg.Discard() msg.Discard()
go EncodeMsg(p.rw, pongMsg) go SendItems(p.rw, pongMsg)
case msg.Code == discMsg: case msg.Code == discMsg:
var reason [1]DiscReason var reason [1]DiscReason
// no need to discard or for error checking, we'll close the // no need to discard or for error checking, we'll close the

View File

@ -4,13 +4,10 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/rlp"
) )
var discard = Protocol{ var discard = Protocol{
@ -55,13 +52,13 @@ func TestPeerProtoReadMsg(t *testing.T) {
Name: "a", Name: "a",
Length: 5, Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error { Run: func(peer *Peer, rw MsgReadWriter) error {
if err := expectMsg(rw, 2, []uint{1}); err != nil { if err := ExpectMsg(rw, 2, []uint{1}); err != nil {
t.Error(err) t.Error(err)
} }
if err := expectMsg(rw, 3, []uint{2}); err != nil { if err := ExpectMsg(rw, 3, []uint{2}); err != nil {
t.Error(err) t.Error(err)
} }
if err := expectMsg(rw, 4, []uint{3}); err != nil { if err := ExpectMsg(rw, 4, []uint{3}); err != nil {
t.Error(err) t.Error(err)
} }
close(done) close(done)
@ -72,9 +69,9 @@ func TestPeerProtoReadMsg(t *testing.T) {
closer, rw, _, errc := testPeer([]Protocol{proto}) closer, rw, _, errc := testPeer([]Protocol{proto})
defer closer.Close() defer closer.Close()
EncodeMsg(rw, baseProtocolLength+2, 1) Send(rw, baseProtocolLength+2, []uint{1})
EncodeMsg(rw, baseProtocolLength+3, 2) Send(rw, baseProtocolLength+3, []uint{2})
EncodeMsg(rw, baseProtocolLength+4, 3) Send(rw, baseProtocolLength+4, []uint{3})
select { select {
case <-done: case <-done:
@ -92,10 +89,10 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
Name: "a", Name: "a",
Length: 2, Length: 2,
Run: func(peer *Peer, rw MsgReadWriter) error { Run: func(peer *Peer, rw MsgReadWriter) error {
if err := EncodeMsg(rw, 2); err == nil { if err := SendItems(rw, 2); err == nil {
t.Error("expected error for out-of-range msg code, got nil") t.Error("expected error for out-of-range msg code, got nil")
} }
if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil { if err := SendItems(rw, 1, "foo", "bar"); err != nil {
t.Errorf("write error: %v", err) t.Errorf("write error: %v", err)
} }
return nil return nil
@ -104,7 +101,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
closer, rw, _, _ := testPeer([]Protocol{proto}) closer, rw, _, _ := testPeer([]Protocol{proto})
defer closer.Close() defer closer.Close()
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { if err := ExpectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
t.Error(err) t.Error(err)
} }
} }
@ -115,11 +112,15 @@ func TestPeerWriteForBroadcast(t *testing.T) {
closer, rw, peer, peerErr := testPeer([]Protocol{discard}) closer, rw, peer, peerErr := testPeer([]Protocol{discard})
defer closer.Close() defer closer.Close()
emptymsg := func(code uint64) Msg {
return Msg{Code: code, Size: 0, Payload: bytes.NewReader(nil)}
}
// test write errors // test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { if err := peer.writeProtoMsg("b", emptymsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil") t.Errorf("expected error for unknown protocol, got nil")
} }
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { if err := peer.writeProtoMsg("discard", emptymsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil") t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v", err) t.Errorf("wrong error for out-of-range msg code, got %#v", err)
@ -128,14 +129,14 @@ func TestPeerWriteForBroadcast(t *testing.T) {
// setup for reading the message on the other end // setup for reading the message on the other end
read := make(chan struct{}) read := make(chan struct{})
go func() { go func() {
if err := expectMsg(rw, 16, nil); err != nil { if err := ExpectMsg(rw, 16, nil); err != nil {
t.Error(err) t.Error(err)
} }
close(read) close(read)
}() }()
// test successful write // test successful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { if err := peer.writeProtoMsg("discard", emptymsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err) t.Errorf("expect no error for known protocol: %v", err)
} }
select { select {
@ -150,10 +151,10 @@ func TestPeerPing(t *testing.T) {
closer, rw, _, _ := testPeer(nil) closer, rw, _, _ := testPeer(nil)
defer closer.Close() defer closer.Close()
if err := EncodeMsg(rw, pingMsg); err != nil { if err := SendItems(rw, pingMsg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := expectMsg(rw, pongMsg, nil); err != nil { if err := ExpectMsg(rw, pongMsg, nil); err != nil {
t.Error(err) t.Error(err)
} }
} }
@ -163,10 +164,10 @@ func TestPeerDisconnect(t *testing.T) {
closer, rw, _, disc := testPeer(nil) closer, rw, _, disc := testPeer(nil)
defer closer.Close() defer closer.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { if err := SendItems(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { if err := ExpectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
t.Error(err) t.Error(err)
} }
closer.Close() // make test end faster closer.Close() // make test end faster

View File

@ -30,7 +30,7 @@ ba628a4ba590cb43f7848f41c4382885
`) `)
// Check WriteMsg. This puts a message into the buffer. // Check WriteMsg. This puts a message into the buffer.
if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil { if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil {
t.Fatalf("WriteMsg error: %v", err) t.Fatalf("WriteMsg error: %v", err)
} }
written := buf.Bytes() written := buf.Bytes()
@ -102,7 +102,7 @@ func TestRlpxFrameRW(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
// write message into conn buffer // write message into conn buffer
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)} wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
err := EncodeMsg(rw1, uint64(i), wmsg...) err := Send(rw1, uint64(i), wmsg)
if err != nil { if err != nil {
t.Fatalf("WriteMsg error (i=%d): %v", i, err) t.Fatalf("WriteMsg error (i=%d): %v", i, err)
} }

View File

@ -9,10 +9,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/rlp"
) )
const ( const (
@ -129,10 +129,14 @@ func (srv *Server) SuggestPeer(n *discover.Node) {
// Broadcast sends an RLP-encoded message to all connected peers. // Broadcast sends an RLP-encoded message to all connected peers.
// This method is deprecated and will be removed later. // This method is deprecated and will be removed later.
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { func (srv *Server) Broadcast(protocol string, code uint64, data interface{}) error {
var payload []byte var payload []byte
if data != nil { if data != nil {
payload = common.Encode(data) var err error
payload, err = rlp.EncodeToBytes(data)
if err != nil {
return err
}
} }
srv.lock.RLock() srv.lock.RLock()
defer srv.lock.RUnlock() defer srv.lock.RUnlock()
@ -146,6 +150,7 @@ func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{})
peer.writeProtoMsg(protocol, msg) peer.writeProtoMsg(protocol, msg)
} }
} }
return nil
} }
// Start starts running the server. // Start starts running the server.

View File

@ -149,7 +149,7 @@ func TestServerBroadcast(t *testing.T) {
connected.Wait() connected.Wait()
// broadcast one message // broadcast one message
srv.Broadcast("discard", 0, "foo") srv.Broadcast("discard", 0, []string{"foo"})
golden := unhex("66e94d166f0a2c3b884cfa59ca34") golden := unhex("66e94d166f0a2c3b884cfa59ca34")
// check that the message has been written everywhere // check that the message has been written everywhere