p2p: improve disconnect signaling at handshake time

As of this commit, p2p will disconnect nodes directly after the
encryption handshake if too many peer connections are active.
Errors in the protocol handshake packet are now handled more politely
by sending a disconnect packet before closing the connection.
This commit is contained in:
Felix Lange 2015-04-10 13:25:35 +02:00
parent 99a1db2d40
commit b3c058a9e4
4 changed files with 111 additions and 35 deletions

View File

@ -68,50 +68,61 @@ type protoHandshake struct {
// setupConn starts a protocol session on the given connection. // setupConn starts a protocol session on the given connection.
// It runs the encryption handshake and the protocol handshake. // It runs the encryption handshake and the protocol handshake.
// If dial is non-nil, the local node is the initiator of the connection. // If dial is non-nil, the local node is the initiator of the connection.
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { // If atcap is true, the connection will be disconnected with DiscTooManyPeers
// after the key exchange.
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
if dial == nil { if dial == nil {
return setupInboundConn(fd, prv, our) return setupInboundConn(fd, prv, our, atcap)
} else { } else {
return setupOutboundConn(fd, prv, our, dial) return setupOutboundConn(fd, prv, our, dial, atcap)
} }
} }
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) { func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, atcap bool) (*conn, error) {
secrets, err := receiverEncHandshake(fd, prv, nil) secrets, err := receiverEncHandshake(fd, prv, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("encryption handshake failed: %v", err) return nil, fmt.Errorf("encryption handshake failed: %v", err)
} }
// Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW(fd, secrets) rw := newRlpxFrameRW(fd, secrets)
rhs, err := readProtocolHandshake(rw, our) if atcap {
SendItems(rw, discMsg, DiscTooManyPeers)
return nil, errors.New("we have too many peers")
}
// Run the protocol handshake using authenticated messages.
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if rhs.ID != secrets.RemoteID {
return nil, errors.New("node ID in protocol handshake does not match encryption handshake")
}
// TODO: validate that handshake node ID matches
if err := Send(rw, handshakeMsg, our); err != nil { if err := Send(rw, handshakeMsg, our); err != nil {
return nil, fmt.Errorf("protocol write error: %v", err) return nil, fmt.Errorf("protocol handshake write error: %v", err)
} }
return &conn{rw, rhs}, nil return &conn{rw, rhs}, nil
} }
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil) secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("encryption handshake failed: %v", err) return nil, fmt.Errorf("encryption handshake failed: %v", err)
} }
// Run the protocol handshake using authenticated messages.
rw := newRlpxFrameRW(fd, secrets) rw := newRlpxFrameRW(fd, secrets)
if err := Send(rw, handshakeMsg, our); err != nil { if atcap {
return nil, fmt.Errorf("protocol write error: %v", err) SendItems(rw, discMsg, DiscTooManyPeers)
return nil, errors.New("we have too many peers")
} }
rhs, err := readProtocolHandshake(rw, our) // Run the protocol handshake using authenticated messages.
//
// Note that even though writing the handshake is first, we prefer
// returning the handshake read error. If the remote side
// disconnects us early with a valid reason, we should return it
// as the error so it can be tracked elsewhere.
werr := make(chan error)
go func() { werr <- Send(rw, handshakeMsg, our) }()
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
if err != nil { if err != nil {
return nil, fmt.Errorf("protocol handshake read error: %v", err) return nil, err
}
if err := <-werr; err != nil {
return nil, fmt.Errorf("protocol handshake write error: %v", err)
} }
if rhs.ID != dial.ID { if rhs.ID != dial.ID {
return nil, errors.New("dialed node id mismatch") return nil, errors.New("dialed node id mismatch")
@ -398,18 +409,17 @@ func xor(one, other []byte) (xor []byte) {
return xor return xor
} }
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) { func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
// read and handle remote handshake msg, err := rw.ReadMsg()
msg, err := r.ReadMsg()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if msg.Code == discMsg { if msg.Code == discMsg {
// disconnect before protocol handshake is valid according to the // disconnect before protocol handshake is valid according to the
// spec and we send it ourselves if Server.addPeer fails. // spec and we send it ourselves if Server.addPeer fails.
var reason DiscReason var reason [1]DiscReason
rlp.Decode(msg.Payload, &reason) rlp.Decode(msg.Payload, &reason)
return nil, reason return nil, reason[0]
} }
if msg.Code != handshakeMsg { if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code) return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
@ -423,10 +433,16 @@ func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, e
} }
// validate handshake info // validate handshake info
if hs.Version != our.Version { if hs.Version != our.Version {
return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version) SendItems(rw, discMsg, DiscIncompatibleVersion)
return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
} }
if (hs.ID == discover.NodeID{}) { if (hs.ID == discover.NodeID{}) {
return nil, newPeerError(errPubkeyInvalid, "missing") SendItems(rw, discMsg, DiscInvalidIdentity)
return nil, errors.New("invalid public key in handshake")
}
if hs.ID != wantID {
SendItems(rw, discMsg, DiscUnexpectedIdentity)
return nil, errors.New("handshake node ID does not match encryption handshake")
} }
return &hs, nil return &hs, nil
} }

View File

@ -143,7 +143,7 @@ func TestSetupConn(t *testing.T) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
conn0, err := setupConn(fd0, prv0, hs0, node1) conn0, err := setupConn(fd0, prv0, hs0, node1, false)
if err != nil { if err != nil {
t.Errorf("outbound side error: %v", err) t.Errorf("outbound side error: %v", err)
return return
@ -156,7 +156,7 @@ func TestSetupConn(t *testing.T) {
} }
}() }()
conn1, err := setupConn(fd1, prv1, hs1, nil) conn1, err := setupConn(fd1, prv1, hs1, nil, false)
if err != nil { if err != nil {
t.Fatalf("inbound side error: %v", err) t.Fatalf("inbound side error: %v", err)
} }

View File

@ -99,7 +99,7 @@ type Server struct {
peerConnect chan *discover.Node peerConnect chan *discover.Node
} }
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error) type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, bool) (*conn, error)
type newPeerHook func(*Peer) type newPeerHook func(*Peer)
// Peers returns all connected peers. // Peers returns all connected peers.
@ -261,6 +261,11 @@ func (srv *Server) Stop() {
srv.peerWG.Wait() srv.peerWG.Wait()
} }
// Self returns the local node's endpoint information.
//
// It simply forwards to srv.ntab.Self and hands back whatever node that
// call reports.
func (srv *Server) Self() *discover.Node {
	self := srv.ntab.Self()
	return self
}
// main loop for adding connections via listening // main loop for adding connections via listening
func (srv *Server) listenLoop() { func (srv *Server) listenLoop() {
defer srv.loopWG.Done() defer srv.loopWG.Done()
@ -354,10 +359,6 @@ func (srv *Server) dialNode(dest *discover.Node) {
srv.startPeer(conn, dest) srv.startPeer(conn, dest)
} }
func (srv *Server) Self() *discover.Node {
return srv.ntab.Self()
}
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) { func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
// TODO: handle/store session token // TODO: handle/store session token
@ -366,7 +367,10 @@ func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
// returns during that exchange need to call peerWG.Done because // returns during that exchange need to call peerWG.Done because
// the callers of startPeer added the peer to the wait group already. // the callers of startPeer added the peer to the wait group already.
fd.SetDeadline(time.Now().Add(handshakeTimeout)) fd.SetDeadline(time.Now().Add(handshakeTimeout))
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest) srv.lock.RLock()
atcap := len(srv.peers) == srv.MaxPeers
srv.lock.RUnlock()
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, atcap)
if err != nil { if err != nil {
fd.Close() fd.Close()
glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err) glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)

View File

@ -22,7 +22,7 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(), PrivateKey: newkey(),
newPeerHook: pf, newPeerHook: pf,
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
id := randomID() id := randomID()
rw := newRlpxFrameRW(fd, secrets{ rw := newRlpxFrameRW(fd, secrets{
MAC: zero16, MAC: zero16,
@ -163,6 +163,62 @@ func TestServerBroadcast(t *testing.T) {
} }
} }
// TestServerDisconnectAtCap verifies that a server which already holds
// MaxPeers connections turns away one more connection right after the
// encryption handshake, instead of running the protocol handshake.
//
// Because it drives real TCP connections through setupConn, it also
// serves as a light-weight integration test.
func TestServerDisconnectAtCap(t *testing.T) {
	defer testlog(t).detach()

	// peerStarted is fed by newPeerHook, which signals that a peer was
	// actually started. Each successful dial below waits on it before the
	// next connection is attempted, so the peer count is deterministic.
	peerStarted := make(chan *Peer)
	srv := &Server{
		ListenAddr:  "127.0.0.1:0",
		PrivateKey:  newkey(),
		MaxPeers:    10,
		NoDial:      true,
		newPeerHook: func(p *Peer) { peerStarted <- p },
	}
	if err := srv.Start(); err != nil {
		t.Fatal(err)
	}
	defer srv.Stop()

	// Connections 0..MaxPeers-1 fill the server; connection number
	// MaxPeers is the one that must be rejected.
	overflow := srv.MaxPeers
	dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
	for i := 0; i <= overflow; i++ {
		conn, err := dialer.Dial("tcp", srv.ListenAddr)
		if err != nil {
			t.Fatalf("conn %d: dial error: %v", i, err)
		}
		// Deliberately deferred inside the loop: every connection stays
		// open until the test ends and is closed before the server is
		// shut down.
		defer conn.Close()

		// Run the handshakes just like a real peer would.
		key := newkey()
		hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
		_, err = setupConn(conn, key, hs, srv.Self(), false)

		if i == overflow {
			// The server is full at this point, so it should disconnect
			// immediately rather than run the protocol handshake.
			if err != DiscTooManyPeers {
				t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers)
			}
			continue
		}

		// Every earlier connection must complete the handshake.
		if err != nil {
			t.Fatalf("conn %d: unexpected error: %v", i, err)
		}
		// Wait for runPeer to be started.
		<-peerStarted
	}
}
func newkey() *ecdsa.PrivateKey { func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey() key, err := crypto.GenerateKey()
if err != nil { if err != nil {