p2p: use RLPx frames for messaging

This commit is contained in:
Felix Lange 2015-02-27 03:06:55 +00:00
parent 51e01cceca
commit 736e632215
6 changed files with 73 additions and 50 deletions

View File

@ -32,14 +32,10 @@ const (
) )
type conn struct { type conn struct {
*frameRW MsgReadWriter
*protoHandshake *protoHandshake
} }
func newConn(fd net.Conn, hs *protoHandshake) *conn {
return &conn{newFrameRW(fd, msgWriteTimeout), hs}
}
// encHandshake contains the state of the encryption handshake. // encHandshake contains the state of the encryption handshake.
type encHandshake struct { type encHandshake struct {
remoteID discover.NodeID remoteID discover.NodeID
@ -115,17 +111,16 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (
// Run the protocol handshake using authenticated messages. // Run the protocol handshake using authenticated messages.
// TODO: move buffering setup here (out of newFrameRW) // TODO: move buffering setup here (out of newFrameRW)
phsrw := newRlpxFrameRW(fd, secrets) rw := newRlpxFrameRW(fd, secrets)
rhs, err := readProtocolHandshake(phsrw, our) rhs, err := readProtocolHandshake(rw, our)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := writeProtocolHandshake(phsrw, our); err != nil { // TODO: validate that handshake node ID matches
if err := writeProtocolHandshake(rw, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err) return nil, fmt.Errorf("protocol write error: %v", err)
} }
return &conn{&lockedRW{wrapped: rw}, rhs}, nil
rw := newFrameRW(fd, msgWriteTimeout)
return &conn{rw, rhs}, nil
} }
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
@ -136,20 +131,18 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake,
// Run the protocol handshake using authenticated messages. // Run the protocol handshake using authenticated messages.
// TODO: move buffering setup here (out of newFrameRW) // TODO: move buffering setup here (out of newFrameRW)
phsrw := newRlpxFrameRW(fd, secrets) rw := newRlpxFrameRW(fd, secrets)
if err := writeProtocolHandshake(phsrw, our); err != nil { if err := writeProtocolHandshake(rw, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err) return nil, fmt.Errorf("protocol write error: %v", err)
} }
rhs, err := readProtocolHandshake(phsrw, our) rhs, err := readProtocolHandshake(rw, our)
if err != nil { if err != nil {
return nil, fmt.Errorf("protocol handshake read error: %v", err) return nil, fmt.Errorf("protocol handshake read error: %v", err)
} }
if rhs.ID != dial.ID { if rhs.ID != dial.ID {
return nil, errors.New("dialed node id mismatch") return nil, errors.New("dialed node id mismatch")
} }
return &conn{&lockedRW{wrapped: rw}, rhs}, nil
rw := newFrameRW(fd, msgWriteTimeout)
return &conn{rw, rhs}, nil
} }
// outboundEncHandshake negotiates a session token on conn. // outboundEncHandshake negotiates a session token on conn.

View File

@ -119,6 +119,25 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...)) return w.WriteMsg(NewMsg(code, data...))
} }
// lockedRW wraps a MsgReadWriter with locks around
// ReadMsg and WriteMsg.
type lockedRW struct {
rmu, wmu sync.Mutex
wrapped MsgReadWriter
}
func (rw *lockedRW) ReadMsg() (Msg, error) {
rw.rmu.Lock()
defer rw.rmu.Unlock()
return rw.wrapped.ReadMsg()
}
func (rw *lockedRW) WriteMsg(msg Msg) error {
rw.wmu.Lock()
defer rw.wmu.Unlock()
return rw.wrapped.WriteMsg(msg)
}
// frameRW is a MsgReadWriter that reads and writes devp2p message frames. // frameRW is a MsgReadWriter that reads and writes devp2p message frames.
// As required by the interface, ReadMsg and WriteMsg can be called from // As required by the interface, ReadMsg and WriteMsg can be called from
// multiple goroutines. // multiple goroutines.

View File

@ -40,6 +40,7 @@ type Peer struct {
// Use them to display messages related to the peer. // Use them to display messages related to the peer.
*logger.Logger *logger.Logger
conn net.Conn
rw *conn rw *conn
running map[string]*protoRW running map[string]*protoRW
@ -52,8 +53,9 @@ type Peer struct {
// NewPeer returns a peer for testing purposes. // NewPeer returns a peer for testing purposes.
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe() pipe, _ := net.Pipe()
conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps}) msgpipe, _ := MsgPipe()
peer := newPeer(conn, nil) conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
peer := newPeer(pipe, conn, nil)
close(peer.closed) // ensures Disconnect doesn't block close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
@ -76,12 +78,12 @@ func (p *Peer) Caps() []Cap {
// RemoteAddr returns the remote address of the network connection. // RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr { func (p *Peer) RemoteAddr() net.Addr {
return p.rw.RemoteAddr() return p.conn.RemoteAddr()
} }
// LocalAddr returns the local address of the network connection. // LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr { func (p *Peer) LocalAddr() net.Addr {
return p.rw.LocalAddr() return p.conn.LocalAddr()
} }
// Disconnect terminates the peer connection with the given reason. // Disconnect terminates the peer connection with the given reason.
@ -98,10 +100,11 @@ func (p *Peer) String() string {
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
} }
func newPeer(conn *conn, protocols []Protocol) *Peer { func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr()) logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr())
p := &Peer{ p := &Peer{
Logger: logger.NewLogger(logtag), Logger: logger.NewLogger(logtag),
conn: fd,
rw: conn, rw: conn,
running: matchProtocols(protocols, conn.Caps, conn), running: matchProtocols(protocols, conn.Caps, conn),
disc: make(chan DiscReason), disc: make(chan DiscReason),
@ -138,7 +141,7 @@ loop:
// We rely on protocols to abort if there is a write error. It // We rely on protocols to abort if there is a write error. It
// might be more robust to handle them here as well. // might be more robust to handle them here as well.
p.DebugDetailf("Read error: %v\n", err) p.DebugDetailf("Read error: %v\n", err)
p.rw.Close() p.conn.Close()
return DiscNetworkError return DiscNetworkError
case err := <-p.protoErr: case err := <-p.protoErr:
reason = discReasonForError(err) reason = discReasonForError(err)
@ -161,14 +164,14 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
EncodeMsg(p.rw, discMsg, uint(reason)) EncodeMsg(p.rw, discMsg, uint(reason))
// Wait for the other side to close the connection. // Wait for the other side to close the connection.
// Discard any data that they send until then. // Discard any data that they send until then.
io.Copy(ioutil.Discard, p.rw) io.Copy(ioutil.Discard, p.conn)
close(done) close(done)
}() }()
select { select {
case <-done: case <-done:
case <-time.After(disconnectGracePeriod): case <-time.After(disconnectGracePeriod):
} }
p.rw.Close() p.conn.Close()
} }
func (p *Peer) readLoop() error { func (p *Peer) readLoop() error {

View File

@ -3,6 +3,7 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net" "net"
"reflect" "reflect"
@ -29,8 +30,8 @@ var discard = Protocol{
}, },
} }
func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) { func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) {
fd1, fd2 := net.Pipe() fd1, _ := net.Pipe()
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
for _, p := range protos { for _, p := range protos {
@ -38,11 +39,12 @@ func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
hs2.Caps = append(hs2.Caps, p.cap()) hs2.Caps = append(hs2.Caps, p.cap())
} }
peer := newPeer(newConn(fd1, hs1), protos) p1, p2 := MsgPipe()
peer := newPeer(fd1, &conn{p1, hs1}, protos)
errc := make(chan DiscReason, 1) errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }() go func() { errc <- peer.run() }()
return newConn(fd2, hs2), peer, errc return p1, &conn{p2, hs2}, peer, errc
} }
func TestPeerProtoReadMsg(t *testing.T) { func TestPeerProtoReadMsg(t *testing.T) {
@ -67,8 +69,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
}, },
} }
rw, _, errc := testPeer([]Protocol{proto}) closer, rw, _, errc := testPeer([]Protocol{proto})
defer rw.Close() defer closer.Close()
EncodeMsg(rw, baseProtocolLength+2, 1) EncodeMsg(rw, baseProtocolLength+2, 1)
EncodeMsg(rw, baseProtocolLength+3, 2) EncodeMsg(rw, baseProtocolLength+3, 2)
@ -105,8 +107,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
}, },
} }
rw, _, errc := testPeer([]Protocol{proto}) closer, rw, _, errc := testPeer([]Protocol{proto})
defer rw.Close() defer closer.Close()
EncodeMsg(rw, 18, make([]byte, msgsize)) EncodeMsg(rw, 18, make([]byte, msgsize))
select { select {
@ -134,8 +136,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil return nil
}, },
} }
rw, _, _ := testPeer([]Protocol{proto}) closer, rw, _, _ := testPeer([]Protocol{proto})
defer rw.Close() defer closer.Close()
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
t.Error(err) t.Error(err)
@ -145,8 +147,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
func TestPeerWriteForBroadcast(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, peer, peerErr := testPeer([]Protocol{discard}) closer, rw, peer, peerErr := testPeer([]Protocol{discard})
defer rw.Close() defer closer.Close()
// test write errors // test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
@ -181,8 +183,8 @@ func TestPeerWriteForBroadcast(t *testing.T) {
func TestPeerPing(t *testing.T) { func TestPeerPing(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, _, _ := testPeer(nil) closer, rw, _, _ := testPeer(nil)
defer rw.Close() defer closer.Close()
if err := EncodeMsg(rw, pingMsg); err != nil { if err := EncodeMsg(rw, pingMsg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -194,15 +196,15 @@ func TestPeerPing(t *testing.T) {
func TestPeerDisconnect(t *testing.T) { func TestPeerDisconnect(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
rw, _, disc := testPeer(nil) closer, rw, _, disc := testPeer(nil)
defer rw.Close() defer closer.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
t.Error(err) t.Error(err)
} }
rw.Close() // make test end faster closer.Close() // make test end faster
if reason := <-disc; reason != DiscRequested { if reason := <-disc; reason != DiscRequested {
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
} }

View File

@ -358,14 +358,15 @@ func (srv *Server) findPeers() {
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) { func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
// TODO: handle/store session token // TODO: handle/store session token
fd.SetDeadline(time.Now().Add(handshakeTimeout)) // TODO: reenable deadlines
// fd.SetDeadline(time.Now().Add(handshakeTimeout))
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest) conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
if err != nil { if err != nil {
fd.Close() fd.Close()
srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err) srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
return return
} }
p := newPeer(conn, srv.Protocols) p := newPeer(fd, conn, srv.Protocols)
if ok, reason := srv.addPeer(conn.ID, p); !ok { if ok, reason := srv.addPeer(conn.ID, p); !ok {
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason) srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
p.politeDisconnect(reason) p.politeDisconnect(reason)
@ -375,7 +376,7 @@ func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
srvlog.Debugf("Added %v\n", p) srvlog.Debugf("Added %v\n", p)
srvjslog.LogJson(&logger.P2PConnected{ srvjslog.LogJson(&logger.P2PConnected{
RemoteId: fmt.Sprintf("%x", conn.ID[:]), RemoteId: fmt.Sprintf("%x", conn.ID[:]),
RemoteAddress: conn.RemoteAddr().String(), RemoteAddress: fd.RemoteAddr().String(),
RemoteVersionString: conn.Name, RemoteVersionString: conn.Name,
NumConnections: srv.PeerCount(), NumConnections: srv.PeerCount(),
}) })

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
) )
@ -23,8 +24,14 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
newPeerHook: pf, newPeerHook: pf,
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
id := randomID() id := randomID()
rw := newRlpxFrameRW(fd, secrets{
MAC: zero16,
AES: zero16,
IngressMAC: sha3.NewKeccak256(),
EgressMAC: sha3.NewKeccak256(),
})
return &conn{ return &conn{
frameRW: newFrameRW(fd, msgWriteTimeout), MsgReadWriter: rw,
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion}, protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
}, nil }, nil
}, },
@ -143,9 +150,7 @@ func TestServerBroadcast(t *testing.T) {
// broadcast one message // broadcast one message
srv.Broadcast("discard", 0, "foo") srv.Broadcast("discard", 0, "foo")
goldbuf := new(bytes.Buffer) golden := unhex("66e94e166f0a2c3b884cfa59ca34")
writeMsg(goldbuf, NewMsg(16, "foo"))
golden := goldbuf.Bytes()
// check that the message has been written everywhere // check that the message has been written everywhere
for i, conn := range conns { for i, conn := range conns {