diff --git a/p2p/message.go b/p2p/message.go index 2ef84f99d..04b9e71f3 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -51,19 +51,8 @@ type Msg struct { // NewMsg creates an RLP-encoded message with the given code. func NewMsg(code uint64, params ...interface{}) Msg { - buf := new(bytes.Buffer) - for _, p := range params { - buf.Write(ethutil.Encode(p)) - } - return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} -} - -func encodePayload(params ...interface{}) []byte { - buf := new(bytes.Buffer) - for _, p := range params { - buf.Write(ethutil.Encode(p)) - } - return buf.Bytes() + p := bytes.NewReader(ethutil.Encode(params)) + return Msg{Code: code, Size: uint32(p.Len()), Payload: p} } // Decode parse the RLP content of a message into @@ -71,8 +60,7 @@ func encodePayload(params ...interface{}) []byte { // // For the decoding rules, please see package rlp. func (msg Msg) Decode(val interface{}) error { - s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) - if err := s.Decode(val); err != nil { + if err := rlp.Decode(msg.Payload, val); err != nil { return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err) } return nil diff --git a/p2p/message_test.go b/p2p/message_test.go index 1757cbe7a..31ed61d87 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -2,10 +2,12 @@ package p2p import ( "bytes" + "encoding/hex" "fmt" "io" "io/ioutil" "runtime" + "strings" "testing" "time" ) @@ -15,11 +17,11 @@ func TestNewMsg(t *testing.T) { if msg.Code != 3 { t.Errorf("incorrect code %d, want %d", msg.Code) } - if msg.Size != 5 { - t.Errorf("incorrect size %d, want %d", msg.Size, 5) + expect := unhex("c50183303030") + if msg.Size != uint32(len(expect)) { + t.Errorf("incorrect size %d, want %d", msg.Size, len(expect)) } pl, _ := ioutil.ReadAll(msg.Payload) - expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} if !bytes.Equal(pl, expect) { t.Errorf("incorrect payload content, got %x, want %x", pl, expect) } @@ -139,3 +141,11 @@ func TestEOFSignal(t *testing.T) { default: } } + +func unhex(str string) []byte { + b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1)) + if err != nil { + panic(fmt.Sprintf("invalid hex string: %q", str)) + } + return b +} diff --git a/p2p/peer.go b/p2p/peer.go index 4982c4612..025be4ba9 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -193,12 +193,12 @@ func (p *Peer) handle(msg Msg) error { msg.Discard() go EncodeMsg(p.rw, pongMsg) case msg.Code == discMsg: - var reason DiscReason + var reason [1]DiscReason // no need to discard or for error checking, we'll close the // connection after this. rlp.Decode(msg.Payload, &reason) p.Disconnect(DiscRequested) - return discRequestedError(reason) + return discRequestedError(reason[0]) case msg.Code < baseProtocolLength: // ignore other base protocol messages return msg.Discard() diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 1ba43bed5..cc9f1f0cd 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -85,41 +85,6 @@ func TestPeerProtoReadMsg(t *testing.T) { } } -func TestPeerProtoReadLargeMsg(t *testing.T) { - defer testlog(t).detach() - - msgsize := uint32(10 * 1024 * 1024) - done := make(chan struct{}) - proto := Protocol{ - Name: "a", - Length: 5, - Run: func(peer *Peer, rw MsgReadWriter) error { - msg, err := rw.ReadMsg() - if err != nil { - t.Errorf("read error: %v", err) - } - if msg.Size != msgsize+4 { - t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) - } - msg.Discard() - close(done) - return nil - }, - } - - closer, rw, _, errc := testPeer([]Protocol{proto}) - defer closer.Close() - - EncodeMsg(rw, 18, make([]byte, msgsize)) - select { - case <-done: - case err := <-errc: - t.Errorf("peer returned: %v", err) - case <-time.After(2 * time.Second): - t.Errorf("receive timeout") - } -} - func TestPeerProtoEncodeMsg(t *testing.T) { defer testlog(t).detach() @@ -246,13 +211,9 @@ func expectMsg(r MsgReader, code uint64, content interface{}) error { if err != nil { panic("content encode error: " + err.Error()) } - // skip over list header in encoded value. this is temporary. - contentEncR := bytes.NewReader(contentEnc) - if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil { - panic("content must encode as RLP list") + if int(msg.Size) != len(contentEnc) { + return fmt.Errorf("message size mismatch: got %d, want %d", msg.Size, len(contentEnc)) } - contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():] - actualContent, err := ioutil.ReadAll(msg.Payload) if err != nil { return err diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index 077dd1309..49354c7ed 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -3,8 +3,6 @@ package p2p import ( "bytes" "crypto/rand" - "encoding/hex" - "fmt" "io/ioutil" "strings" "testing" @@ -32,7 +30,7 @@ ba628a4ba590cb43f7848f41c4382885 `) // Check WriteMsg. This puts a message into the buffer. - if err := EncodeMsg(rw, 8, []interface{}{1, 2, 3, 4}); err != nil { + if err := EncodeMsg(rw, 8, 1, 2, 3, 4); err != nil { t.Fatalf("WriteMsg error: %v", err) } written := buf.Bytes() @@ -68,14 +66,6 @@ func (fakeHash) BlockSize() int { return 0 } func (h fakeHash) Size() int { return len(h) } func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } -func unhex(str string) []byte { - b, err := hex.DecodeString(strings.Replace(str, "\n", "", -1)) - if err != nil { - panic(fmt.Sprintf("invalid hex string: %q", str)) - } - return b -} - func TestRlpxFrameRW(t *testing.T) { var ( aesSecret = make([]byte, 16) @@ -112,7 +102,7 @@ func TestRlpxFrameRW(t *testing.T) { for i := 0; i < 10; i++ { // write message into conn buffer wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)} - err := EncodeMsg(rw1, uint64(i), wmsg) + err := EncodeMsg(rw1, uint64(i), wmsg...) if err != nil { t.Fatalf("WriteMsg error (i=%d): %v", i, err) } diff --git a/p2p/server.go b/p2p/server.go index e53e832aa..67d5514b4 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/nat" @@ -135,7 +136,7 @@ func (srv *Server) SuggestPeer(n *discover.Node) { func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { var payload []byte if data != nil { - payload = encodePayload(data...) + payload = ethutil.Encode(data) } srv.lock.RLock() defer srv.lock.RUnlock() diff --git a/p2p/server_test.go b/p2p/server_test.go index c348f5a9a..30447050c 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -150,7 +150,7 @@ func TestServerBroadcast(t *testing.T) { // broadcast one message srv.Broadcast("discard", 0, "foo") - golden := unhex("66e94e166f0a2c3b884cfa59ca34") + golden := unhex("66e94d166f0a2c3b884cfa59ca34") // check that the message has been written everywhere for i, conn := range conns {