diff --git a/eth/handler.go b/eth/handler.go index 3f10750ab..cd1653804 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -287,7 +287,7 @@ func (h *handler) runEthPeer(peer *eth.Peer, handler eth.Handler) error { peer.Log().Error("Ethereum peer registration failed", "err", err) return err } - defer h.removePeer(peer.ID()) + defer h.unregisterPeer(peer.ID()) p := h.peers.peer(peer.ID()) if p == nil { @@ -354,9 +354,16 @@ func (h *handler) runSnapExtension(peer *snap.Peer, handler snap.Handler) error return handler(peer) } -// removePeer unregisters a peer from the downloader and fetchers, removes it from -// the set of tracked peers and closes the network connection to it. +// removePeer requests disconnection of a peer. func (h *handler) removePeer(id string) { + peer := h.peers.peer(id) + if peer != nil { + peer.Peer.Disconnect(p2p.DiscUselessPeer) + } +} + +// unregisterPeer removes a peer from the downloader, fetchers and main peer set. +func (h *handler) unregisterPeer(id string) { // Create a custom logger to avoid printing the entire id var logger log.Logger if len(id) < 16 { @@ -384,8 +391,6 @@ func (h *handler) removePeer(id string) { if err := h.peers.unregisterPeer(id); err != nil { logger.Error("Ethereum peer removal failed", "err", err) } - // Hard disconnect at the networking layer - peer.Peer.Disconnect(p2p.DiscUselessPeer) } func (h *handler) Start(maxPeers int) { diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go index 1d38e3b66..038de4699 100644 --- a/eth/handler_eth_test.go +++ b/eth/handler_eth_test.go @@ -144,8 +144,8 @@ func testForkIDSplit(t *testing.T, protocol uint) { defer p2pNoFork.Close() defer p2pProFork.Close() - peerNoFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil) - peerProFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil) + peerNoFork := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{1}, "", nil, p2pNoFork), p2pNoFork, nil) + peerProFork := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{2}, "", nil, p2pProFork), p2pProFork, nil) defer peerNoFork.Close() defer peerProFork.Close() @@ -206,8 +206,8 @@ func testForkIDSplit(t *testing.T, protocol uint) { defer p2pNoFork.Close() defer p2pProFork.Close() - peerNoFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil) - peerProFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil) + peerNoFork = eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{1}, "", nil, p2pNoFork), p2pNoFork, nil) + peerProFork = eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{2}, "", nil, p2pProFork), p2pProFork, nil) defer peerNoFork.Close() defer peerProFork.Close() @@ -257,8 +257,8 @@ func testRecvTransactions(t *testing.T, protocol uint) { defer p2pSrc.Close() defer p2pSink.Close() - src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool) - sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool) + src := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{1}, "", nil, p2pSrc), p2pSrc, handler.txpool) + sink := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{2}, "", nil, p2pSink), p2pSink, handler.txpool) defer src.Close() defer sink.Close() @@ -319,8 +319,8 @@ func testSendTransactions(t *testing.T, protocol uint) { defer p2pSrc.Close() defer p2pSink.Close() - src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool) - sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool) + src := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{1}, "", nil, p2pSrc), p2pSrc, handler.txpool) + sink := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{2}, "", nil, p2pSink), p2pSink, handler.txpool) defer src.Close() defer sink.Close() @@ -407,8 +407,8 @@ func testTransactionPropagation(t *testing.T, protocol uint) { defer sourcePipe.Close() defer sinkPipe.Close() - sourcePeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, source.txpool) - sinkPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, sink.txpool) + sourcePeer := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{byte(i)}, "", nil, sourcePipe), sourcePipe, source.txpool) + sinkPeer := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{0}, "", nil, sinkPipe), sinkPipe, sink.txpool) defer sourcePeer.Close() defer sinkPeer.Close() @@ -490,6 +490,8 @@ func TestCheckpointChallenge(t *testing.T) { } func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpoint bool, timeout bool, empty bool, match bool, drop bool) { + t.Parallel() + // Reduce the checkpoint handshake challenge timeout defer func(old time.Duration) { syncChallengeTimeout = old }(syncChallengeTimeout) syncChallengeTimeout = 250 * time.Millisecond @@ -513,20 +515,26 @@ func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpo handler.handler.checkpointNumber = number handler.handler.checkpointHash = response.Hash() } - // Create a challenger peer and a challenged one + + // Create a challenger peer and a challenged one. p2pLocal, p2pRemote := p2p.MsgPipe() defer p2pLocal.Close() defer p2pRemote.Close() - local := eth.NewPeer(eth.ETH65, p2p.NewPeer(enode.ID{1}, "", nil), p2pLocal, handler.txpool) - remote := eth.NewPeer(eth.ETH65, p2p.NewPeer(enode.ID{2}, "", nil), p2pRemote, handler.txpool) + local := eth.NewPeer(eth.ETH65, p2p.NewPeerPipe(enode.ID{1}, "", nil, p2pLocal), p2pLocal, handler.txpool) + remote := eth.NewPeer(eth.ETH65, p2p.NewPeerPipe(enode.ID{2}, "", nil, p2pRemote), p2pRemote, handler.txpool) defer local.Close() defer remote.Close() - go handler.handler.runEthPeer(local, func(peer *eth.Peer) error { - return eth.Handle((*ethHandler)(handler.handler), peer) - }) - // Run the handshake locally to avoid spinning up a remote handler + handlerDone := make(chan struct{}) + go func() { + defer close(handlerDone) + handler.handler.runEthPeer(local, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(handler.handler), peer) + }) + }() + + // Run the handshake locally to avoid spinning up a remote handler. var ( genesis = handler.chain.Genesis() head = handler.chain.CurrentBlock() @@ -535,12 +543,13 @@ func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpo if err := remote.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil { t.Fatalf("failed to run protocol handshake") } - // Connect a new peer and check that we receive the checkpoint challenge + + // Connect a new peer and check that we receive the checkpoint challenge. if checkpoint { if err := remote.ExpectRequestHeadersByNumber(response.Number.Uint64(), 1, 0, false); err != nil { t.Fatalf("challenge mismatch: %v", err) } - // Create a block to reply to the challenge if no timeout is simulated + // Create a block to reply to the challenge if no timeout is simulated. if !timeout { if empty { if err := remote.SendBlockHeaders([]*types.Header{}); err != nil { @@ -557,11 +566,13 @@ func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpo } } } + // Wait until the test timeout passes to ensure proper cleanup time.Sleep(syncChallengeTimeout + 300*time.Millisecond) - // Verify that the remote peer is maintained or dropped + // Verify that the remote peer is maintained or dropped. if drop { + <-handlerDone if peers := handler.handler.peers.len(); peers != 0 { t.Fatalf("peer count mismatch: have %d, want %d", peers, 0) } @@ -608,8 +619,8 @@ func testBroadcastBlock(t *testing.T, peers, bcasts int) { defer sourcePipe.Close() defer sinkPipe.Close() - sourcePeer := eth.NewPeer(eth.ETH65, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, nil) - sinkPeer := eth.NewPeer(eth.ETH65, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, nil) + sourcePeer := eth.NewPeer(eth.ETH65, p2p.NewPeerPipe(enode.ID{byte(i)}, "", nil, sourcePipe), sourcePipe, nil) + sinkPeer := eth.NewPeer(eth.ETH65, p2p.NewPeerPipe(enode.ID{0}, "", nil, sinkPipe), sinkPipe, nil) defer sourcePeer.Close() defer sinkPeer.Close() @@ -676,8 +687,8 @@ func testBroadcastMalformedBlock(t *testing.T, protocol uint) { defer p2pSrc.Close() defer p2pSink.Close() - src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, source.txpool) - sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, source.txpool) + src := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{1}, "", nil, p2pSrc), p2pSrc, source.txpool) + sink := eth.NewPeer(protocol, p2p.NewPeerPipe(enode.ID{2}, "", nil, p2pSink), p2pSink, source.txpool) defer src.Close() defer sink.Close() diff --git a/p2p/peer.go b/p2p/peer.go index 8ebc85839..b6d0dbd1a 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -115,7 +115,8 @@ type Peer struct { disc chan DiscReason // events receives message send / receive events if set - events *event.Feed + events *event.Feed + testPipe *MsgPipeRW // for testing } // NewPeer returns a peer for testing purposes. @@ -128,6 +129,15 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer { return peer } +// NewPeerPipe creates a peer for testing purposes. +// The message pipe given as the last parameter is closed when +// Disconnect is called on the peer. +func NewPeerPipe(id enode.ID, name string, caps []Cap, pipe *MsgPipeRW) *Peer { + p := NewPeer(id, name, caps) + p.testPipe = pipe + return p +} + // ID returns the node's public key. func (p *Peer) ID() enode.ID { return p.rw.node.ID() @@ -185,6 +195,10 @@ func (p *Peer) LocalAddr() net.Addr { // Disconnect terminates the peer connection with the given reason. // It returns immediately and does not wait until the connection is closed. func (p *Peer) Disconnect(reason DiscReason) { + if p.testPipe != nil { + p.testPipe.Close() + } + select { case p.disc <- reason: case <-p.closed: