diff --git a/p2p/handshake.go b/p2p/handshake.go index 17f572dea..10ef970dc 100644 --- a/p2p/handshake.go +++ b/p2p/handshake.go @@ -32,14 +32,10 @@ const ( ) type conn struct { - *frameRW + MsgReadWriter *protoHandshake } -func newConn(fd net.Conn, hs *protoHandshake) *conn { - return &conn{newFrameRW(fd, msgWriteTimeout), hs} -} - // encHandshake contains the state of the encryption handshake. type encHandshake struct { remoteID discover.NodeID @@ -115,17 +111,16 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) ( // Run the protocol handshake using authenticated messages. // TODO: move buffering setup here (out of newFrameRW) - phsrw := newRlpxFrameRW(fd, secrets) - rhs, err := readProtocolHandshake(phsrw, our) + rw := newRlpxFrameRW(fd, secrets) + rhs, err := readProtocolHandshake(rw, our) if err != nil { return nil, err } - if err := writeProtocolHandshake(phsrw, our); err != nil { + // TODO: validate that handshake node ID matches + if err := writeProtocolHandshake(rw, our); err != nil { return nil, fmt.Errorf("protocol write error: %v", err) } - - rw := newFrameRW(fd, msgWriteTimeout) - return &conn{rw, rhs}, nil + return &conn{&lockedRW{wrapped: rw}, rhs}, nil } func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { @@ -136,20 +131,18 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, // Run the protocol handshake using authenticated messages. // TODO: move buffering setup here (out of newFrameRW) - phsrw := newRlpxFrameRW(fd, secrets) - if err := writeProtocolHandshake(phsrw, our); err != nil { + rw := newRlpxFrameRW(fd, secrets) + if err := writeProtocolHandshake(rw, our); err != nil { return nil, fmt.Errorf("protocol write error: %v", err) } - rhs, err := readProtocolHandshake(phsrw, our) + rhs, err := readProtocolHandshake(rw, our) if err != nil { return nil, fmt.Errorf("protocol handshake read error: %v", err) } if rhs.ID != dial.ID { return nil, errors.New("dialed node id mismatch") } - - rw := newFrameRW(fd, msgWriteTimeout) - return &conn{rw, rhs}, nil + return &conn{&lockedRW{wrapped: rw}, rhs}, nil } // outboundEncHandshake negotiates a session token on conn. diff --git a/p2p/message.go b/p2p/message.go index 7adad4b09..d61faad13 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -119,6 +119,25 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error { return w.WriteMsg(NewMsg(code, data...)) } +// lockedRW wraps a MsgReadWriter with locks around +// ReadMsg and WriteMsg. +type lockedRW struct { + rmu, wmu sync.Mutex + wrapped MsgReadWriter +} + +func (rw *lockedRW) ReadMsg() (Msg, error) { + rw.rmu.Lock() + defer rw.rmu.Unlock() + return rw.wrapped.ReadMsg() +} + +func (rw *lockedRW) WriteMsg(msg Msg) error { + rw.wmu.Lock() + defer rw.wmu.Unlock() + return rw.wrapped.WriteMsg(msg) +} + // frameRW is a MsgReadWriter that reads and writes devp2p message frames. // As required by the interface, ReadMsg and WriteMsg can be called from // multiple goroutines. diff --git a/p2p/peer.go b/p2p/peer.go index fb027c834..4982c4612 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -40,6 +40,7 @@ type Peer struct { // Use them to display messages related to the peer. *logger.Logger + conn net.Conn rw *conn running map[string]*protoRW @@ -52,8 +53,9 @@ type Peer struct { // NewPeer returns a peer for testing purposes. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { pipe, _ := net.Pipe() - conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps}) - peer := newPeer(conn, nil) + msgpipe, _ := MsgPipe() + conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}} + peer := newPeer(pipe, conn, nil) close(peer.closed) // ensures Disconnect doesn't block return peer } @@ -76,12 +78,12 @@ func (p *Peer) Caps() []Cap { // RemoteAddr returns the remote address of the network connection. func (p *Peer) RemoteAddr() net.Addr { - return p.rw.RemoteAddr() + return p.conn.RemoteAddr() } // LocalAddr returns the local address of the network connection. func (p *Peer) LocalAddr() net.Addr { - return p.rw.LocalAddr() + return p.conn.LocalAddr() } // Disconnect terminates the peer connection with the given reason. @@ -98,10 +100,11 @@ func (p *Peer) String() string { return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) } -func newPeer(conn *conn, protocols []Protocol) *Peer { - logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr()) +func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { + logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr()) p := &Peer{ Logger: logger.NewLogger(logtag), + conn: fd, rw: conn, running: matchProtocols(protocols, conn.Caps, conn), disc: make(chan DiscReason), @@ -138,7 +141,7 @@ loop: // We rely on protocols to abort if there is a write error. It // might be more robust to handle them here as well. p.DebugDetailf("Read error: %v\n", err) - p.rw.Close() + p.conn.Close() return DiscNetworkError case err := <-p.protoErr: reason = discReasonForError(err) @@ -161,14 +164,14 @@ func (p *Peer) politeDisconnect(reason DiscReason) { EncodeMsg(p.rw, discMsg, uint(reason)) // Wait for the other side to close the connection. // Discard any data that they send until then. - io.Copy(ioutil.Discard, p.rw) + io.Copy(ioutil.Discard, p.conn) close(done) }() select { case <-done: case <-time.After(disconnectGracePeriod): } - p.rw.Close() + p.conn.Close() } func (p *Peer) readLoop() error { diff --git a/p2p/peer_test.go b/p2p/peer_test.go index a1260adbd..1ba43bed5 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -3,6 +3,7 @@ package p2p import ( "bytes" "fmt" + "io" "io/ioutil" "net" "reflect" @@ -29,8 +30,8 @@ var discard = Protocol{ }, } -func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) { - fd1, fd2 := net.Pipe() +func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) { + fd1, _ := net.Pipe() hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} for _, p := range protos { @@ -38,11 +39,12 @@ func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) { hs2.Caps = append(hs2.Caps, p.cap()) } - peer := newPeer(newConn(fd1, hs1), protos) + p1, p2 := MsgPipe() + peer := newPeer(fd1, &conn{p1, hs1}, protos) errc := make(chan DiscReason, 1) go func() { errc <- peer.run() }() - return newConn(fd2, hs2), peer, errc + return p1, &conn{p2, hs2}, peer, errc } func TestPeerProtoReadMsg(t *testing.T) { @@ -67,8 +69,8 @@ func TestPeerProtoReadMsg(t *testing.T) { }, } - rw, _, errc := testPeer([]Protocol{proto}) - defer rw.Close() + closer, rw, _, errc := testPeer([]Protocol{proto}) + defer closer.Close() EncodeMsg(rw, baseProtocolLength+2, 1) EncodeMsg(rw, baseProtocolLength+3, 2) @@ -105,8 +107,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) { }, } - rw, _, errc := testPeer([]Protocol{proto}) - defer rw.Close() + closer, rw, _, errc := testPeer([]Protocol{proto}) + defer closer.Close() EncodeMsg(rw, 18, make([]byte, msgsize)) select { @@ -134,8 +136,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) { return nil }, } - rw, _, _ := testPeer([]Protocol{proto}) - defer rw.Close() + closer, rw, _, _ := testPeer([]Protocol{proto}) + defer closer.Close() if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { t.Error(err) @@ -145,8 +147,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) { defer testlog(t).detach() - rw, peer, peerErr := testPeer([]Protocol{discard}) - defer rw.Close() + closer, rw, peer, peerErr := testPeer([]Protocol{discard}) + defer closer.Close() // test write errors if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { @@ -181,8 +183,8 @@ func TestPeerWriteForBroadcast(t *testing.T) { func TestPeerPing(t *testing.T) { defer testlog(t).detach() - rw, _, _ := testPeer(nil) - defer rw.Close() + closer, rw, _, _ := testPeer(nil) + defer closer.Close() if err := EncodeMsg(rw, pingMsg); err != nil { t.Fatal(err) } @@ -194,15 +196,15 @@ func TestPeerPing(t *testing.T) { func TestPeerDisconnect(t *testing.T) { defer testlog(t).detach() - rw, _, disc := testPeer(nil) - defer rw.Close() + closer, rw, _, disc := testPeer(nil) + defer closer.Close() if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { t.Fatal(err) } if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { t.Error(err) } - rw.Close() // make test end faster + closer.Close() // make test end faster if reason := <-disc; reason != DiscRequested { t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) } diff --git a/p2p/server.go b/p2p/server.go index 3ea2538d1..e53e832aa 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -358,14 +358,15 @@ func (srv *Server) findPeers() { func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) { // TODO: handle/store session token - fd.SetDeadline(time.Now().Add(handshakeTimeout)) + // TODO: reenable deadlines + // fd.SetDeadline(time.Now().Add(handshakeTimeout)) conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest) if err != nil { fd.Close() srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err) return } - p := newPeer(conn, srv.Protocols) + p := newPeer(fd, conn, srv.Protocols) if ok, reason := srv.addPeer(conn.ID, p); !ok { srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason) p.politeDisconnect(reason) @@ -375,7 +376,7 @@ func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) { srvlog.Debugf("Added %v\n", p) srvjslog.LogJson(&logger.P2PConnected{ RemoteId: fmt.Sprintf("%x", conn.ID[:]), - RemoteAddress: conn.RemoteAddr().String(), + RemoteAddress: fd.RemoteAddr().String(), RemoteVersionString: conn.Name, NumConnections: srv.PeerCount(), }) diff --git a/p2p/server_test.go b/p2p/server_test.go index c109fffb9..c348f5a9a 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/p2p/discover" ) @@ -23,8 +24,14 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server { newPeerHook: pf, setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { id := randomID() + rw := newRlpxFrameRW(fd, secrets{ + MAC: zero16, + AES: zero16, + IngressMAC: sha3.NewKeccak256(), + EgressMAC: sha3.NewKeccak256(), + }) return &conn{ - frameRW: newFrameRW(fd, msgWriteTimeout), + MsgReadWriter: rw, protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion}, }, nil }, @@ -143,9 +150,7 @@ func TestServerBroadcast(t *testing.T) { // broadcast one message srv.Broadcast("discard", 0, "foo") - goldbuf := new(bytes.Buffer) - writeMsg(goldbuf, NewMsg(16, "foo")) - golden := goldbuf.Bytes() + golden := unhex("66e94e166f0a2c3b884cfa59ca34") // check that the message has been written everywhere for i, conn := range conns {