forked from cerc-io/plugeth
p2p: use package rlp to encode messages
Message encoding functions have been renamed to catch any uses. The switch to the new encoder can cause subtle incompatibilities. If there are any users outside of our tree, they will at least be alerted that there was a change. NewMsg no longer exists. The replacements for EncodeMsg are called Send and SendItems.
This commit is contained in:
parent
4811f460e7
commit
5ba51594c7
@ -92,7 +92,7 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (
|
|||||||
return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
|
return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
|
||||||
}
|
}
|
||||||
// TODO: validate that handshake node ID matches
|
// TODO: validate that handshake node ID matches
|
||||||
if err := writeProtocolHandshake(rw, our); err != nil {
|
if err := Send(rw, handshakeMsg, our); err != nil {
|
||||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||||
}
|
}
|
||||||
return &conn{rw, rhs}, nil
|
return &conn{rw, rhs}, nil
|
||||||
@ -106,7 +106,7 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake,
|
|||||||
|
|
||||||
// Run the protocol handshake using authenticated messages.
|
// Run the protocol handshake using authenticated messages.
|
||||||
rw := newRlpxFrameRW(fd, secrets)
|
rw := newRlpxFrameRW(fd, secrets)
|
||||||
if err := writeProtocolHandshake(rw, our); err != nil {
|
if err := Send(rw, handshakeMsg, our); err != nil {
|
||||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||||
}
|
}
|
||||||
rhs, err := readProtocolHandshake(rw, our)
|
rhs, err := readProtocolHandshake(rw, our)
|
||||||
@ -398,10 +398,6 @@ func xor(one, other []byte) (xor []byte) {
|
|||||||
return xor
|
return xor
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
|
|
||||||
return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
||||||
// read and handle remote handshake
|
// read and handle remote handshake
|
||||||
msg, err := r.ReadMsg()
|
msg, err := r.ReadMsg()
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/common"
|
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,13 +27,7 @@ type Msg struct {
|
|||||||
Payload io.Reader
|
Payload io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMsg creates an RLP-encoded message with the given code.
|
// Decode parses the RLP content of a message into
|
||||||
func NewMsg(code uint64, params ...interface{}) Msg {
|
|
||||||
p := bytes.NewReader(common.Encode(params))
|
|
||||||
return Msg{Code: code, Size: uint32(p.Len()), Payload: p}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode parse the RLP content of a message into
|
|
||||||
// the given value, which must be a pointer.
|
// the given value, which must be a pointer.
|
||||||
//
|
//
|
||||||
// For the decoding rules, please see package rlp.
|
// For the decoding rules, please see package rlp.
|
||||||
@ -76,13 +69,30 @@ type MsgReadWriter interface {
|
|||||||
MsgWriter
|
MsgWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncodeMsg writes an RLP-encoded message with the given code and
|
// Send writes an RLP-encoded message with the given code.
|
||||||
// data elements.
|
// data should encode as an RLP list.
|
||||||
func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
|
func Send(w MsgWriter, msgcode uint64, data interface{}) error {
|
||||||
return w.WriteMsg(NewMsg(code, data...))
|
size, r, err := rlp.EncodeToReader(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r})
|
||||||
}
|
}
|
||||||
|
|
||||||
// netWrapper wrapsa MsgReadWriter with locks around
|
// SendItems writes an RLP with the given code and data elements.
|
||||||
|
// For a call such as:
|
||||||
|
//
|
||||||
|
// SendItems(w, code, e1, e2, e3)
|
||||||
|
//
|
||||||
|
// the message payload will be an RLP list containing the items:
|
||||||
|
//
|
||||||
|
// [e1, e2, e3]
|
||||||
|
//
|
||||||
|
func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error {
|
||||||
|
return Send(w, msgcode, elems)
|
||||||
|
}
|
||||||
|
|
||||||
|
// netWrapper wraps a MsgReadWriter with locks around
|
||||||
// ReadMsg/WriteMsg and applies read/write deadlines.
|
// ReadMsg/WriteMsg and applies read/write deadlines.
|
||||||
type netWrapper struct {
|
type netWrapper struct {
|
||||||
rmu, wmu sync.Mutex
|
rmu, wmu sync.Mutex
|
||||||
|
@ -5,33 +5,17 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewMsg(t *testing.T) {
|
|
||||||
msg := NewMsg(3, 1, "000")
|
|
||||||
if msg.Code != 3 {
|
|
||||||
t.Errorf("incorrect code %d, want %d", msg.Code)
|
|
||||||
}
|
|
||||||
expect := unhex("c50183303030")
|
|
||||||
if msg.Size != uint32(len(expect)) {
|
|
||||||
t.Errorf("incorrect size %d, want %d", msg.Size, len(expect))
|
|
||||||
}
|
|
||||||
pl, _ := ioutil.ReadAll(msg.Payload)
|
|
||||||
if !bytes.Equal(pl, expect) {
|
|
||||||
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExampleMsgPipe() {
|
func ExampleMsgPipe() {
|
||||||
rw1, rw2 := MsgPipe()
|
rw1, rw2 := MsgPipe()
|
||||||
go func() {
|
go func() {
|
||||||
EncodeMsg(rw1, 8, []byte{0, 0})
|
Send(rw1, 8, [][]byte{{0, 0}})
|
||||||
EncodeMsg(rw1, 5, []byte{1, 1})
|
Send(rw1, 5, [][]byte{{1, 1}})
|
||||||
rw1.Close()
|
rw1.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -40,7 +24,7 @@ func ExampleMsgPipe() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
var data [1][]byte
|
var data [][]byte
|
||||||
msg.Decode(&data)
|
msg.Decode(&data)
|
||||||
fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
|
fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
|
||||||
}
|
}
|
||||||
@ -55,7 +39,7 @@ loop:
|
|||||||
rw1, rw2 := MsgPipe()
|
rw1, rw2 := MsgPipe()
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
if err := EncodeMsg(rw1, 1); err == nil {
|
if err := SendItems(rw1, 1); err == nil {
|
||||||
t.Error("EncodeMsg returned nil error")
|
t.Error("EncodeMsg returned nil error")
|
||||||
} else if err != ErrPipeClosed {
|
} else if err != ErrPipeClosed {
|
||||||
t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)
|
t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)
|
||||||
|
@ -132,7 +132,7 @@ loop:
|
|||||||
select {
|
select {
|
||||||
case <-ping.C:
|
case <-ping.C:
|
||||||
go func() {
|
go func() {
|
||||||
if err := EncodeMsg(p.rw, pingMsg, nil); err != nil {
|
if err := SendItems(p.rw, pingMsg); err != nil {
|
||||||
p.protoErr <- err
|
p.protoErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -161,7 +161,7 @@ loop:
|
|||||||
func (p *Peer) politeDisconnect(reason DiscReason) {
|
func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
EncodeMsg(p.rw, discMsg, uint(reason))
|
SendItems(p.rw, discMsg, uint(reason))
|
||||||
// Wait for the other side to close the connection.
|
// Wait for the other side to close the connection.
|
||||||
// Discard any data that they send until then.
|
// Discard any data that they send until then.
|
||||||
io.Copy(ioutil.Discard, p.conn)
|
io.Copy(ioutil.Discard, p.conn)
|
||||||
@ -192,7 +192,7 @@ func (p *Peer) handle(msg Msg) error {
|
|||||||
switch {
|
switch {
|
||||||
case msg.Code == pingMsg:
|
case msg.Code == pingMsg:
|
||||||
msg.Discard()
|
msg.Discard()
|
||||||
go EncodeMsg(p.rw, pongMsg)
|
go SendItems(p.rw, pongMsg)
|
||||||
case msg.Code == discMsg:
|
case msg.Code == discMsg:
|
||||||
var reason [1]DiscReason
|
var reason [1]DiscReason
|
||||||
// no need to discard or for error checking, we'll close the
|
// no need to discard or for error checking, we'll close the
|
||||||
|
@ -4,13 +4,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var discard = Protocol{
|
var discard = Protocol{
|
||||||
@ -55,13 +52,13 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
|||||||
Name: "a",
|
Name: "a",
|
||||||
Length: 5,
|
Length: 5,
|
||||||
Run: func(peer *Peer, rw MsgReadWriter) error {
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||||
if err := expectMsg(rw, 2, []uint{1}); err != nil {
|
if err := ExpectMsg(rw, 2, []uint{1}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := expectMsg(rw, 3, []uint{2}); err != nil {
|
if err := ExpectMsg(rw, 3, []uint{2}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := expectMsg(rw, 4, []uint{3}); err != nil {
|
if err := ExpectMsg(rw, 4, []uint{3}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
close(done)
|
close(done)
|
||||||
@ -72,9 +69,9 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
|||||||
closer, rw, _, errc := testPeer([]Protocol{proto})
|
closer, rw, _, errc := testPeer([]Protocol{proto})
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
|
|
||||||
EncodeMsg(rw, baseProtocolLength+2, 1)
|
Send(rw, baseProtocolLength+2, []uint{1})
|
||||||
EncodeMsg(rw, baseProtocolLength+3, 2)
|
Send(rw, baseProtocolLength+3, []uint{2})
|
||||||
EncodeMsg(rw, baseProtocolLength+4, 3)
|
Send(rw, baseProtocolLength+4, []uint{3})
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
@ -92,10 +89,10 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
|||||||
Name: "a",
|
Name: "a",
|
||||||
Length: 2,
|
Length: 2,
|
||||||
Run: func(peer *Peer, rw MsgReadWriter) error {
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||||
if err := EncodeMsg(rw, 2); err == nil {
|
if err := SendItems(rw, 2); err == nil {
|
||||||
t.Error("expected error for out-of-range msg code, got nil")
|
t.Error("expected error for out-of-range msg code, got nil")
|
||||||
}
|
}
|
||||||
if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
|
if err := SendItems(rw, 1, "foo", "bar"); err != nil {
|
||||||
t.Errorf("write error: %v", err)
|
t.Errorf("write error: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -104,7 +101,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
|||||||
closer, rw, _, _ := testPeer([]Protocol{proto})
|
closer, rw, _, _ := testPeer([]Protocol{proto})
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
|
|
||||||
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
if err := ExpectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,11 +112,15 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
|||||||
closer, rw, peer, peerErr := testPeer([]Protocol{discard})
|
closer, rw, peer, peerErr := testPeer([]Protocol{discard})
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
|
|
||||||
|
emptymsg := func(code uint64) Msg {
|
||||||
|
return Msg{Code: code, Size: 0, Payload: bytes.NewReader(nil)}
|
||||||
|
}
|
||||||
|
|
||||||
// test write errors
|
// test write errors
|
||||||
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
if err := peer.writeProtoMsg("b", emptymsg(3)); err == nil {
|
||||||
t.Errorf("expected error for unknown protocol, got nil")
|
t.Errorf("expected error for unknown protocol, got nil")
|
||||||
}
|
}
|
||||||
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
|
if err := peer.writeProtoMsg("discard", emptymsg(8)); err == nil {
|
||||||
t.Errorf("expected error for out-of-range msg code, got nil")
|
t.Errorf("expected error for out-of-range msg code, got nil")
|
||||||
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
|
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
|
||||||
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
|
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
|
||||||
@ -128,14 +129,14 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
|||||||
// setup for reading the message on the other end
|
// setup for reading the message on the other end
|
||||||
read := make(chan struct{})
|
read := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
if err := expectMsg(rw, 16, nil); err != nil {
|
if err := ExpectMsg(rw, 16, nil); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
close(read)
|
close(read)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// test successful write
|
// test successful write
|
||||||
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
|
if err := peer.writeProtoMsg("discard", emptymsg(0)); err != nil {
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
@ -150,10 +151,10 @@ func TestPeerPing(t *testing.T) {
|
|||||||
|
|
||||||
closer, rw, _, _ := testPeer(nil)
|
closer, rw, _, _ := testPeer(nil)
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
if err := EncodeMsg(rw, pingMsg); err != nil {
|
if err := SendItems(rw, pingMsg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := expectMsg(rw, pongMsg, nil); err != nil {
|
if err := ExpectMsg(rw, pongMsg, nil); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -163,10 +164,10 @@ func TestPeerDisconnect(t *testing.T) {
|
|||||||
|
|
||||||
closer, rw, _, disc := testPeer(nil)
|
closer, rw, _, disc := testPeer(nil)
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
if err := SendItems(rw, discMsg, DiscQuitting); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
if err := ExpectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
closer.Close() // make test end faster
|
closer.Close() // make test end faster
|
||||||
|
@ -30,7 +30,7 @@ ba628a4ba590cb43f7848f41c4382885
|
|||||||
`)
|
`)
|
||||||
|
|
||||||
// Check WriteMsg. This puts a message into the buffer.
|
// Check WriteMsg. This puts a message into the buffer.
|
||||||
if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil {
|
if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil {
|
||||||
t.Fatalf("WriteMsg error: %v", err)
|
t.Fatalf("WriteMsg error: %v", err)
|
||||||
}
|
}
|
||||||
written := buf.Bytes()
|
written := buf.Bytes()
|
||||||
@ -102,7 +102,7 @@ func TestRlpxFrameRW(t *testing.T) {
|
|||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
// write message into conn buffer
|
// write message into conn buffer
|
||||||
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
|
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
|
||||||
err := EncodeMsg(rw1, uint64(i), wmsg...)
|
err := Send(rw1, uint64(i), wmsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
|
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
|
||||||
}
|
}
|
||||||
|
@ -9,10 +9,10 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/common"
|
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -129,10 +129,14 @@ func (srv *Server) SuggestPeer(n *discover.Node) {
|
|||||||
|
|
||||||
// Broadcast sends an RLP-encoded message to all connected peers.
|
// Broadcast sends an RLP-encoded message to all connected peers.
|
||||||
// This method is deprecated and will be removed later.
|
// This method is deprecated and will be removed later.
|
||||||
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
|
func (srv *Server) Broadcast(protocol string, code uint64, data interface{}) error {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if data != nil {
|
if data != nil {
|
||||||
payload = common.Encode(data)
|
var err error
|
||||||
|
payload, err = rlp.EncodeToBytes(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
srv.lock.RLock()
|
srv.lock.RLock()
|
||||||
defer srv.lock.RUnlock()
|
defer srv.lock.RUnlock()
|
||||||
@ -146,6 +150,7 @@ func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{})
|
|||||||
peer.writeProtoMsg(protocol, msg)
|
peer.writeProtoMsg(protocol, msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts running the server.
|
// Start starts running the server.
|
||||||
|
@ -149,7 +149,7 @@ func TestServerBroadcast(t *testing.T) {
|
|||||||
connected.Wait()
|
connected.Wait()
|
||||||
|
|
||||||
// broadcast one message
|
// broadcast one message
|
||||||
srv.Broadcast("discard", 0, "foo")
|
srv.Broadcast("discard", 0, []string{"foo"})
|
||||||
golden := unhex("66e94d166f0a2c3b884cfa59ca34")
|
golden := unhex("66e94d166f0a2c3b884cfa59ca34")
|
||||||
|
|
||||||
// check that the message has been written everywhere
|
// check that the message has been written everywhere
|
||||||
|
Loading…
Reference in New Issue
Block a user