forked from cerc-io/plugeth
p2p: use package rlp to encode messages
Message encoding functions have been renamed to catch any uses. The switch to the new encoder can cause subtle incompatibilities. If there are any users outside of our tree, they will at least be alerted that there was a change. NewMsg no longer exists. The replacements for EncodeMsg are called Send and SendItems.
This commit is contained in:
parent
4811f460e7
commit
5ba51594c7
@ -92,7 +92,7 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (
|
||||
return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
|
||||
}
|
||||
// TODO: validate that handshake node ID matches
|
||||
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||
if err := Send(rw, handshakeMsg, our); err != nil {
|
||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||
}
|
||||
return &conn{rw, rhs}, nil
|
||||
@ -106,7 +106,7 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake,
|
||||
|
||||
// Run the protocol handshake using authenticated messages.
|
||||
rw := newRlpxFrameRW(fd, secrets)
|
||||
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||
if err := Send(rw, handshakeMsg, our); err != nil {
|
||||
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||
}
|
||||
rhs, err := readProtocolHandshake(rw, our)
|
||||
@ -398,10 +398,6 @@ func xor(one, other []byte) (xor []byte) {
|
||||
return xor
|
||||
}
|
||||
|
||||
func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
|
||||
return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
|
||||
}
|
||||
|
||||
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
||||
// read and handle remote handshake
|
||||
msg, err := r.ReadMsg()
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
@ -28,13 +27,7 @@ type Msg struct {
|
||||
Payload io.Reader
|
||||
}
|
||||
|
||||
// NewMsg creates an RLP-encoded message with the given code.
|
||||
func NewMsg(code uint64, params ...interface{}) Msg {
|
||||
p := bytes.NewReader(common.Encode(params))
|
||||
return Msg{Code: code, Size: uint32(p.Len()), Payload: p}
|
||||
}
|
||||
|
||||
// Decode parse the RLP content of a message into
|
||||
// Decode parses the RLP content of a message into
|
||||
// the given value, which must be a pointer.
|
||||
//
|
||||
// For the decoding rules, please see package rlp.
|
||||
@ -76,13 +69,30 @@ type MsgReadWriter interface {
|
||||
MsgWriter
|
||||
}
|
||||
|
||||
// EncodeMsg writes an RLP-encoded message with the given code and
|
||||
// data elements.
|
||||
func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
|
||||
return w.WriteMsg(NewMsg(code, data...))
|
||||
// Send writes an RLP-encoded message with the given code.
|
||||
// data should encode as an RLP list.
|
||||
func Send(w MsgWriter, msgcode uint64, data interface{}) error {
|
||||
size, r, err := rlp.EncodeToReader(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r})
|
||||
}
|
||||
|
||||
// netWrapper wrapsa MsgReadWriter with locks around
|
||||
// SendItems writes an RLP with the given code and data elements.
|
||||
// For a call such as:
|
||||
//
|
||||
// SendItems(w, code, e1, e2, e3)
|
||||
//
|
||||
// the message payload will be an RLP list containing the items:
|
||||
//
|
||||
// [e1, e2, e3]
|
||||
//
|
||||
func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error {
|
||||
return Send(w, msgcode, elems)
|
||||
}
|
||||
|
||||
// netWrapper wraps a MsgReadWriter with locks around
|
||||
// ReadMsg/WriteMsg and applies read/write deadlines.
|
||||
type netWrapper struct {
|
||||
rmu, wmu sync.Mutex
|
||||
|
@ -5,33 +5,17 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewMsg(t *testing.T) {
|
||||
msg := NewMsg(3, 1, "000")
|
||||
if msg.Code != 3 {
|
||||
t.Errorf("incorrect code %d, want %d", msg.Code)
|
||||
}
|
||||
expect := unhex("c50183303030")
|
||||
if msg.Size != uint32(len(expect)) {
|
||||
t.Errorf("incorrect size %d, want %d", msg.Size, len(expect))
|
||||
}
|
||||
pl, _ := ioutil.ReadAll(msg.Payload)
|
||||
if !bytes.Equal(pl, expect) {
|
||||
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleMsgPipe() {
|
||||
rw1, rw2 := MsgPipe()
|
||||
go func() {
|
||||
EncodeMsg(rw1, 8, []byte{0, 0})
|
||||
EncodeMsg(rw1, 5, []byte{1, 1})
|
||||
Send(rw1, 8, [][]byte{{0, 0}})
|
||||
Send(rw1, 5, [][]byte{{1, 1}})
|
||||
rw1.Close()
|
||||
}()
|
||||
|
||||
@ -40,7 +24,7 @@ func ExampleMsgPipe() {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
var data [1][]byte
|
||||
var data [][]byte
|
||||
msg.Decode(&data)
|
||||
fmt.Printf("msg: %d, %x\n", msg.Code, data[0])
|
||||
}
|
||||
@ -55,7 +39,7 @@ loop:
|
||||
rw1, rw2 := MsgPipe()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
if err := EncodeMsg(rw1, 1); err == nil {
|
||||
if err := SendItems(rw1, 1); err == nil {
|
||||
t.Error("EncodeMsg returned nil error")
|
||||
} else if err != ErrPipeClosed {
|
||||
t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed)
|
||||
|
@ -132,7 +132,7 @@ loop:
|
||||
select {
|
||||
case <-ping.C:
|
||||
go func() {
|
||||
if err := EncodeMsg(p.rw, pingMsg, nil); err != nil {
|
||||
if err := SendItems(p.rw, pingMsg); err != nil {
|
||||
p.protoErr <- err
|
||||
return
|
||||
}
|
||||
@ -161,7 +161,7 @@ loop:
|
||||
func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
EncodeMsg(p.rw, discMsg, uint(reason))
|
||||
SendItems(p.rw, discMsg, uint(reason))
|
||||
// Wait for the other side to close the connection.
|
||||
// Discard any data that they send until then.
|
||||
io.Copy(ioutil.Discard, p.conn)
|
||||
@ -192,7 +192,7 @@ func (p *Peer) handle(msg Msg) error {
|
||||
switch {
|
||||
case msg.Code == pingMsg:
|
||||
msg.Discard()
|
||||
go EncodeMsg(p.rw, pongMsg)
|
||||
go SendItems(p.rw, pongMsg)
|
||||
case msg.Code == discMsg:
|
||||
var reason [1]DiscReason
|
||||
// no need to discard or for error checking, we'll close the
|
||||
|
@ -4,13 +4,10 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
var discard = Protocol{
|
||||
@ -55,13 +52,13 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
||||
Name: "a",
|
||||
Length: 5,
|
||||
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||
if err := expectMsg(rw, 2, []uint{1}); err != nil {
|
||||
if err := ExpectMsg(rw, 2, []uint{1}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := expectMsg(rw, 3, []uint{2}); err != nil {
|
||||
if err := ExpectMsg(rw, 3, []uint{2}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := expectMsg(rw, 4, []uint{3}); err != nil {
|
||||
if err := ExpectMsg(rw, 4, []uint{3}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
close(done)
|
||||
@ -72,9 +69,9 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
||||
closer, rw, _, errc := testPeer([]Protocol{proto})
|
||||
defer closer.Close()
|
||||
|
||||
EncodeMsg(rw, baseProtocolLength+2, 1)
|
||||
EncodeMsg(rw, baseProtocolLength+3, 2)
|
||||
EncodeMsg(rw, baseProtocolLength+4, 3)
|
||||
Send(rw, baseProtocolLength+2, []uint{1})
|
||||
Send(rw, baseProtocolLength+3, []uint{2})
|
||||
Send(rw, baseProtocolLength+4, []uint{3})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
@ -92,10 +89,10 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||
Name: "a",
|
||||
Length: 2,
|
||||
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||
if err := EncodeMsg(rw, 2); err == nil {
|
||||
if err := SendItems(rw, 2); err == nil {
|
||||
t.Error("expected error for out-of-range msg code, got nil")
|
||||
}
|
||||
if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
|
||||
if err := SendItems(rw, 1, "foo", "bar"); err != nil {
|
||||
t.Errorf("write error: %v", err)
|
||||
}
|
||||
return nil
|
||||
@ -104,7 +101,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||
closer, rw, _, _ := testPeer([]Protocol{proto})
|
||||
defer closer.Close()
|
||||
|
||||
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||
if err := ExpectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
@ -115,11 +112,15 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
||||
closer, rw, peer, peerErr := testPeer([]Protocol{discard})
|
||||
defer closer.Close()
|
||||
|
||||
emptymsg := func(code uint64) Msg {
|
||||
return Msg{Code: code, Size: 0, Payload: bytes.NewReader(nil)}
|
||||
}
|
||||
|
||||
// test write errors
|
||||
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
||||
if err := peer.writeProtoMsg("b", emptymsg(3)); err == nil {
|
||||
t.Errorf("expected error for unknown protocol, got nil")
|
||||
}
|
||||
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
|
||||
if err := peer.writeProtoMsg("discard", emptymsg(8)); err == nil {
|
||||
t.Errorf("expected error for out-of-range msg code, got nil")
|
||||
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
|
||||
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
|
||||
@ -128,14 +129,14 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
||||
// setup for reading the message on the other end
|
||||
read := make(chan struct{})
|
||||
go func() {
|
||||
if err := expectMsg(rw, 16, nil); err != nil {
|
||||
if err := ExpectMsg(rw, 16, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
close(read)
|
||||
}()
|
||||
|
||||
// test successful write
|
||||
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
|
||||
if err := peer.writeProtoMsg("discard", emptymsg(0)); err != nil {
|
||||
t.Errorf("expect no error for known protocol: %v", err)
|
||||
}
|
||||
select {
|
||||
@ -150,10 +151,10 @@ func TestPeerPing(t *testing.T) {
|
||||
|
||||
closer, rw, _, _ := testPeer(nil)
|
||||
defer closer.Close()
|
||||
if err := EncodeMsg(rw, pingMsg); err != nil {
|
||||
if err := SendItems(rw, pingMsg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectMsg(rw, pongMsg, nil); err != nil {
|
||||
if err := ExpectMsg(rw, pongMsg, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
@ -163,10 +164,10 @@ func TestPeerDisconnect(t *testing.T) {
|
||||
|
||||
closer, rw, _, disc := testPeer(nil)
|
||||
defer closer.Close()
|
||||
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
||||
if err := SendItems(rw, discMsg, DiscQuitting); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
||||
if err := ExpectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
closer.Close() // make test end faster
|
||||
|
@ -30,7 +30,7 @@ ba628a4ba590cb43f7848f41c4382885
|
||||
`)
|
||||
|
||||
// Check WriteMsg. This puts a message into the buffer.
|
||||
if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil {
|
||||
if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil {
|
||||
t.Fatalf("WriteMsg error: %v", err)
|
||||
}
|
||||
written := buf.Bytes()
|
||||
@ -102,7 +102,7 @@ func TestRlpxFrameRW(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
// write message into conn buffer
|
||||
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
|
||||
err := EncodeMsg(rw1, uint64(i), wmsg...)
|
||||
err := Send(rw1, uint64(i), wmsg)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
|
||||
}
|
||||
|
@ -9,10 +9,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -129,10 +129,14 @@ func (srv *Server) SuggestPeer(n *discover.Node) {
|
||||
|
||||
// Broadcast sends an RLP-encoded message to all connected peers.
|
||||
// This method is deprecated and will be removed later.
|
||||
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
|
||||
func (srv *Server) Broadcast(protocol string, code uint64, data interface{}) error {
|
||||
var payload []byte
|
||||
if data != nil {
|
||||
payload = common.Encode(data)
|
||||
var err error
|
||||
payload, err = rlp.EncodeToBytes(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
srv.lock.RLock()
|
||||
defer srv.lock.RUnlock()
|
||||
@ -146,6 +150,7 @@ func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{})
|
||||
peer.writeProtoMsg(protocol, msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start starts running the server.
|
||||
|
@ -149,7 +149,7 @@ func TestServerBroadcast(t *testing.T) {
|
||||
connected.Wait()
|
||||
|
||||
// broadcast one message
|
||||
srv.Broadcast("discard", 0, "foo")
|
||||
srv.Broadcast("discard", 0, []string{"foo"})
|
||||
golden := unhex("66e94d166f0a2c3b884cfa59ca34")
|
||||
|
||||
// check that the message has been written everywhere
|
||||
|
Loading…
Reference in New Issue
Block a user