forked from cerc-io/plugeth
p2p: msg.Payload contains list data
With RLPx frames, the message code is contained in the frame and is no longer part of the encoded data. EncodeMsg, Msg.Decode have been updated to match. Code that decodes RLP directly from Msg.Payload will need to change.
This commit is contained in:
parent
21649100b1
commit
7964f30dcb
@ -51,19 +51,8 @@ type Msg struct {
|
|||||||
|
|
||||||
// NewMsg creates an RLP-encoded message with the given code.
|
// NewMsg creates an RLP-encoded message with the given code.
|
||||||
func NewMsg(code uint64, params ...interface{}) Msg {
|
func NewMsg(code uint64, params ...interface{}) Msg {
|
||||||
buf := new(bytes.Buffer)
|
p := bytes.NewReader(ethutil.Encode(params))
|
||||||
for _, p := range params {
|
return Msg{Code: code, Size: uint32(p.Len()), Payload: p}
|
||||||
buf.Write(ethutil.Encode(p))
|
|
||||||
}
|
|
||||||
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodePayload(params ...interface{}) []byte {
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
for _, p := range params {
|
|
||||||
buf.Write(ethutil.Encode(p))
|
|
||||||
}
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode parse the RLP content of a message into
|
// Decode parse the RLP content of a message into
|
||||||
@ -71,8 +60,7 @@ func encodePayload(params ...interface{}) []byte {
|
|||||||
//
|
//
|
||||||
// For the decoding rules, please see package rlp.
|
// For the decoding rules, please see package rlp.
|
||||||
func (msg Msg) Decode(val interface{}) error {
|
func (msg Msg) Decode(val interface{}) error {
|
||||||
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
|
if err := rlp.Decode(msg.Payload, val); err != nil {
|
||||||
if err := s.Decode(val); err != nil {
|
|
||||||
return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err)
|
return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -2,10 +2,12 @@ package p2p
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -15,11 +17,11 @@ func TestNewMsg(t *testing.T) {
|
|||||||
if msg.Code != 3 {
|
if msg.Code != 3 {
|
||||||
t.Errorf("incorrect code %d, want %d", msg.Code)
|
t.Errorf("incorrect code %d, want %d", msg.Code)
|
||||||
}
|
}
|
||||||
if msg.Size != 5 {
|
expect := unhex("c50183303030")
|
||||||
t.Errorf("incorrect size %d, want %d", msg.Size, 5)
|
if msg.Size != uint32(len(expect)) {
|
||||||
|
t.Errorf("incorrect size %d, want %d", msg.Size, len(expect))
|
||||||
}
|
}
|
||||||
pl, _ := ioutil.ReadAll(msg.Payload)
|
pl, _ := ioutil.ReadAll(msg.Payload)
|
||||||
expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
|
|
||||||
if !bytes.Equal(pl, expect) {
|
if !bytes.Equal(pl, expect) {
|
||||||
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
|
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
|
||||||
}
|
}
|
||||||
@ -139,3 +141,11 @@ func TestEOFSignal(t *testing.T) {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unhex(str string) []byte {
|
||||||
|
b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1))
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid hex string: %q", str))
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
@ -193,12 +193,12 @@ func (p *Peer) handle(msg Msg) error {
|
|||||||
msg.Discard()
|
msg.Discard()
|
||||||
go EncodeMsg(p.rw, pongMsg)
|
go EncodeMsg(p.rw, pongMsg)
|
||||||
case msg.Code == discMsg:
|
case msg.Code == discMsg:
|
||||||
var reason DiscReason
|
var reason [1]DiscReason
|
||||||
// no need to discard or for error checking, we'll close the
|
// no need to discard or for error checking, we'll close the
|
||||||
// connection after this.
|
// connection after this.
|
||||||
rlp.Decode(msg.Payload, &reason)
|
rlp.Decode(msg.Payload, &reason)
|
||||||
p.Disconnect(DiscRequested)
|
p.Disconnect(DiscRequested)
|
||||||
return discRequestedError(reason)
|
return discRequestedError(reason[0])
|
||||||
case msg.Code < baseProtocolLength:
|
case msg.Code < baseProtocolLength:
|
||||||
// ignore other base protocol messages
|
// ignore other base protocol messages
|
||||||
return msg.Discard()
|
return msg.Discard()
|
||||||
|
@ -85,41 +85,6 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerProtoReadLargeMsg(t *testing.T) {
|
|
||||||
defer testlog(t).detach()
|
|
||||||
|
|
||||||
msgsize := uint32(10 * 1024 * 1024)
|
|
||||||
done := make(chan struct{})
|
|
||||||
proto := Protocol{
|
|
||||||
Name: "a",
|
|
||||||
Length: 5,
|
|
||||||
Run: func(peer *Peer, rw MsgReadWriter) error {
|
|
||||||
msg, err := rw.ReadMsg()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("read error: %v", err)
|
|
||||||
}
|
|
||||||
if msg.Size != msgsize+4 {
|
|
||||||
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
|
|
||||||
}
|
|
||||||
msg.Discard()
|
|
||||||
close(done)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
closer, rw, _, errc := testPeer([]Protocol{proto})
|
|
||||||
defer closer.Close()
|
|
||||||
|
|
||||||
EncodeMsg(rw, 18, make([]byte, msgsize))
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case err := <-errc:
|
|
||||||
t.Errorf("peer returned: %v", err)
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Errorf("receive timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPeerProtoEncodeMsg(t *testing.T) {
|
func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
@ -246,13 +211,9 @@ func expectMsg(r MsgReader, code uint64, content interface{}) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic("content encode error: " + err.Error())
|
panic("content encode error: " + err.Error())
|
||||||
}
|
}
|
||||||
// skip over list header in encoded value. this is temporary.
|
if int(msg.Size) != len(contentEnc) {
|
||||||
contentEncR := bytes.NewReader(contentEnc)
|
return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc))
|
||||||
if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
|
|
||||||
panic("content must encode as RLP list")
|
|
||||||
}
|
}
|
||||||
contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
|
|
||||||
|
|
||||||
actualContent, err := ioutil.ReadAll(msg.Payload)
|
actualContent, err := ioutil.ReadAll(msg.Payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -3,8 +3,6 @@ package p2p
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -32,7 +30,7 @@ ba628a4ba590cb43f7848f41c4382885
|
|||||||
`)
|
`)
|
||||||
|
|
||||||
// Check WriteMsg. This puts a message into the buffer.
|
// Check WriteMsg. This puts a message into the buffer.
|
||||||
if err := EncodeMsg(rw, 8, []interface{}{1, 2, 3, 4}); err != nil {
|
if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil {
|
||||||
t.Fatalf("WriteMsg error: %v", err)
|
t.Fatalf("WriteMsg error: %v", err)
|
||||||
}
|
}
|
||||||
written := buf.Bytes()
|
written := buf.Bytes()
|
||||||
@ -68,14 +66,6 @@ func (fakeHash) BlockSize() int { return 0 }
|
|||||||
func (h fakeHash) Size() int { return len(h) }
|
func (h fakeHash) Size() int { return len(h) }
|
||||||
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
|
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
|
||||||
|
|
||||||
func unhex(str string) []byte {
|
|
||||||
b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1))
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("invalid hex string: %q", str))
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRlpxFrameRW(t *testing.T) {
|
func TestRlpxFrameRW(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
aesSecret = make([]byte, 16)
|
aesSecret = make([]byte, 16)
|
||||||
@ -112,7 +102,7 @@ func TestRlpxFrameRW(t *testing.T) {
|
|||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
// write message into conn buffer
|
// write message into conn buffer
|
||||||
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
|
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
|
||||||
err := EncodeMsg(rw1, uint64(i), wmsg)
|
err := EncodeMsg(rw1, uint64(i), wmsg...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
|
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
@ -135,7 +136,7 @@ func (srv *Server) SuggestPeer(n *discover.Node) {
|
|||||||
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
|
func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if data != nil {
|
if data != nil {
|
||||||
payload = encodePayload(data...)
|
payload = ethutil.Encode(data)
|
||||||
}
|
}
|
||||||
srv.lock.RLock()
|
srv.lock.RLock()
|
||||||
defer srv.lock.RUnlock()
|
defer srv.lock.RUnlock()
|
||||||
|
@ -150,7 +150,7 @@ func TestServerBroadcast(t *testing.T) {
|
|||||||
|
|
||||||
// broadcast one message
|
// broadcast one message
|
||||||
srv.Broadcast("discard", 0, "foo")
|
srv.Broadcast("discard", 0, "foo")
|
||||||
golden := unhex("66e94e166f0a2c3b884cfa59ca34")
|
golden := unhex("66e94d166f0a2c3b884cfa59ca34")
|
||||||
|
|
||||||
// check that the message has been written everywhere
|
// check that the message has been written everywhere
|
||||||
for i, conn := range conns {
|
for i, conn := range conns {
|
||||||
|
Loading…
Reference in New Issue
Block a user