diff --git a/p2p/dial.go b/p2p/dial.go
index 936887a1a..8dee5063f 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -17,7 +17,6 @@
package p2p
import (
- "container/heap"
"errors"
"fmt"
"net"
@@ -29,9 +28,10 @@ import (
)
const (
- // This is the amount of time spent waiting in between
- // redialing a certain node.
- dialHistoryExpiration = 30 * time.Second
+ // This is the amount of time spent waiting in between redialing a certain node. The
+ // limit is a bit higher than inboundThrottleTime to prevent failing dials in small
+ // private networks.
+ dialHistoryExpiration = inboundThrottleTime + 5*time.Second
// Discovery lookups are throttled and can only run
// once every few seconds.
@@ -72,16 +72,16 @@ type dialstate struct {
ntab discoverTable
netrestrict *netutil.Netlist
self enode.ID
+ bootnodes []*enode.Node // default dials when there are no peers
+ log log.Logger
+ start time.Time // time when the dialer was first used
lookupRunning bool
dialing map[enode.ID]connFlag
lookupBuf []*enode.Node // current discovery lookup results
randomNodes []*enode.Node // filled from Table
static map[enode.ID]*dialTask
- hist *dialHistory
-
- start time.Time // time when the dialer was first used
- bootnodes []*enode.Node // default dials when there are no peers
+ hist expHeap
}
type discoverTable interface {
@@ -91,15 +91,6 @@ type discoverTable interface {
ReadRandomNodes([]*enode.Node) int
}
-// the dial history remembers recent dials.
-type dialHistory []pastDial
-
-// pastDial is an entry in the dial history.
-type pastDial struct {
- id enode.ID
- exp time.Time
-}
-
type task interface {
Do(*Server)
}
@@ -126,20 +117,23 @@ type waitExpireTask struct {
time.Duration
}
-func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
+func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
self: self,
- netrestrict: netrestrict,
+ netrestrict: cfg.NetRestrict,
+ log: cfg.Logger,
static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag),
- bootnodes: make([]*enode.Node, len(bootnodes)),
+ bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
randomNodes: make([]*enode.Node, maxdyn/2),
- hist: new(dialHistory),
}
- copy(s.bootnodes, bootnodes)
- for _, n := range static {
+ copy(s.bootnodes, cfg.BootstrapNodes)
+ if s.log == nil {
+ s.log = log.Root()
+ }
+ for _, n := range cfg.StaticNodes {
s.addStatic(n)
}
return s
@@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) {
func (s *dialstate) removeStatic(n *enode.Node) {
// This removes a task so future attempts to connect will not be made.
delete(s.static, n.ID())
- // This removes a previous dial timestamp so that application
- // can force a server to reconnect with chosen peer immediately.
- s.hist.remove(n.ID())
}
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
@@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
var newtasks []task
addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil {
- log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
+ s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
return false
}
s.dialing[n.ID()] = flag
@@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
err := s.checkDial(t.dest, peers)
switch err {
case errNotWhitelisted, errSelf:
- log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
+ s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
delete(s.static, t.dest.ID())
case nil:
s.dialing[id] = t.flags
@@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
// This should prevent cases where the dialer logic is not ticked
// because there are no pending events.
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
- t := &waitExpireTask{s.hist.min().exp.Sub(now)}
+ t := &waitExpireTask{s.hist.nextExpiry().Sub(now)}
newtasks = append(newtasks, t)
}
return newtasks
@@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted
- case s.hist.contains(n.ID()):
+ case s.hist.contains(string(n.ID().Bytes())):
return errRecentlyDialed
}
return nil
@@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) {
case *dialTask:
- s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration))
+ s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration))
delete(s.dialing, t.dest.ID())
case *discoverTask:
s.lookupRunning = false
@@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) {
}
err := t.dial(srv, t.dest)
if err != nil {
- log.Trace("Dial error", "task", t, "err", err)
+ srv.log.Trace("Dial error", "task", t, "err", err)
// Try resolving the ID of static nodes if dialing failed.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
if t.resolve(srv) {
@@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) {
// The backoff delay resets when the node is found.
func (t *dialTask) resolve(srv *Server) bool {
if srv.ntab == nil {
- log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
+ srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
return false
}
if t.resolveDelay == 0 {
@@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool {
if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay
}
- log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
+ srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
return false
}
// The node was found.
t.resolveDelay = initialResolveDelay
t.dest = resolved
- log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
+ srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
return true
}
@@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) {
func (t waitExpireTask) String() string {
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
}
-
-// Use only these methods to access or modify dialHistory.
-func (h dialHistory) min() pastDial {
- return h[0]
-}
-func (h *dialHistory) add(id enode.ID, exp time.Time) {
- heap.Push(h, pastDial{id, exp})
-
-}
-func (h *dialHistory) remove(id enode.ID) bool {
- for i, v := range *h {
- if v.id == id {
- heap.Remove(h, i)
- return true
- }
- }
- return false
-}
-func (h dialHistory) contains(id enode.ID) bool {
- for _, v := range h {
- if v.id == id {
- return true
- }
- }
- return false
-}
-func (h *dialHistory) expire(now time.Time) {
- for h.Len() > 0 && h.min().exp.Before(now) {
- heap.Pop(h)
- }
-}
-
-// heap.Interface boilerplate
-func (h dialHistory) Len() int { return len(h) }
-func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
-func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
-func (h *dialHistory) Push(x interface{}) {
- *h = append(*h, x.(pastDial))
-}
-func (h *dialHistory) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[0 : n-1]
- return x
-}
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index f41ab7752..de8fc4a6e 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -20,10 +20,13 @@ import (
"encoding/binary"
"net"
"reflect"
+ "strings"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
+ "github.com/ethereum/go-ethereum/internal/testlog"
+ "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
@@ -67,10 +70,10 @@ func runDialTest(t *testing.T, test dialtest) {
new := test.init.newTasks(running, pm(round.peers), vtime)
if !sametasks(new, round.new) {
- t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
+ t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
}
- t.Log("tasks:", spew.Sdump(new))
+ t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new)))
// Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second)
@@ -88,8 +91,9 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
+ config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
+ init: newDialState(enode.ID{}, fakeTable{}, 5, config),
rounds: []round{
// A discovery query is launched.
{
@@ -153,7 +157,7 @@ func TestDialStateDynDial(t *testing.T) {
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
- &waitExpireTask{Duration: 14 * time.Second},
+ &waitExpireTask{Duration: 19 * time.Second},
},
},
// In this round, the peer with id 2 drops off. The query
@@ -223,10 +227,13 @@ func TestDialStateDynDial(t *testing.T) {
// Tests that bootnodes are dialed if no peers are connectd, but not otherwise.
func TestDialStateDynDialBootnode(t *testing.T) {
- bootnodes := []*enode.Node{
- newNode(uintID(1), nil),
- newNode(uintID(2), nil),
- newNode(uintID(3), nil),
+ config := &Config{
+ BootstrapNodes: []*enode.Node{
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
+ newNode(uintID(3), nil),
+ },
+ Logger: testlog.Logger(t, log.LvlTrace),
}
table := fakeTable{
newNode(uintID(4), nil),
@@ -236,7 +243,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
newNode(uintID(8), nil),
}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
+ init: newDialState(enode.ID{}, table, 5, config),
rounds: []round{
// 2 dynamic dials attempted, bootnodes pending fallback interval
{
@@ -259,25 +266,24 @@ func TestDialStateDynDialBootnode(t *testing.T) {
{
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, 2nd bootnode is attempted
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, 3rd bootnode is attempted
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
+ &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
@@ -288,21 +294,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
},
- new: []task{
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
- },
+ new: []task{},
},
// Random dial succeeds, no more bootnodes are attempted
{
+ new: []task{
+ &waitExpireTask{3 * time.Second},
+ },
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
},
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
- &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
},
@@ -324,7 +328,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
}
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
+ init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
@@ -430,7 +434,7 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
+ init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
rounds: []round{
{
new: []task{
@@ -444,16 +448,18 @@ func TestDialStateNetRestrict(t *testing.T) {
// This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) {
- wantStatic := []*enode.Node{
- newNode(uintID(1), nil),
- newNode(uintID(2), nil),
- newNode(uintID(3), nil),
- newNode(uintID(4), nil),
- newNode(uintID(5), nil),
+ config := &Config{
+ StaticNodes: []*enode.Node{
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
+ newNode(uintID(3), nil),
+ newNode(uintID(4), nil),
+ newNode(uintID(5), nil),
+ },
+ Logger: testlog.Logger(t, log.LvlTrace),
}
-
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
+ init: newDialState(enode.ID{}, fakeTable{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -495,7 +501,7 @@ func TestDialStateStaticDial(t *testing.T) {
&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
- &waitExpireTask{Duration: 14 * time.Second},
+ &waitExpireTask{Duration: 19 * time.Second},
},
},
// Wait a round for dial history to expire, no new tasks should spawn.
@@ -511,6 +517,9 @@ func TestDialStateStaticDial(t *testing.T) {
// If a static node is dropped, it should be immediately redialed,
// irrespective whether it was originally static or dynamic.
{
+ done: []task{
+ &waitExpireTask{Duration: 19 * time.Second},
+ },
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
@@ -518,67 +527,24 @@ func TestDialStateStaticDial(t *testing.T) {
},
new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
- &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
},
},
},
})
}
-// This test checks that static peers will be redialed immediately if they were re-added to a static list.
-func TestDialStaticAfterReset(t *testing.T) {
- wantStatic := []*enode.Node{
- newNode(uintID(1), nil),
- newNode(uintID(2), nil),
- }
-
- rounds := []round{
- // Static dials are launched for the nodes that aren't yet connected.
- {
- peers: nil,
- new: []task{
- &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
- },
- },
- // No new dial tasks, all peers are connected.
- {
- peers: []*Peer{
- {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
- },
- done: []task{
- &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
- &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
- },
- new: []task{
- &waitExpireTask{Duration: 30 * time.Second},
- },
- },
- }
- dTest := dialtest{
- init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
- rounds: rounds,
- }
- runDialTest(t, dTest)
- for _, n := range wantStatic {
- dTest.init.removeStatic(n)
- dTest.init.addStatic(n)
- }
- // without removing peers they will be considered recently dialed
- runDialTest(t, dTest)
-}
-
// This test checks that past dials are not retried for some time.
func TestDialStateCache(t *testing.T) {
- wantStatic := []*enode.Node{
- newNode(uintID(1), nil),
- newNode(uintID(2), nil),
- newNode(uintID(3), nil),
+ config := &Config{
+ StaticNodes: []*enode.Node{
+ newNode(uintID(1), nil),
+ newNode(uintID(2), nil),
+ newNode(uintID(3), nil),
+ },
+ Logger: testlog.Logger(t, log.LvlTrace),
}
-
runDialTest(t, dialtest{
- init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
+ init: newDialState(enode.ID{}, fakeTable{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@@ -606,28 +572,37 @@ func TestDialStateCache(t *testing.T) {
// entry to expire.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
},
new: []task{
- &waitExpireTask{Duration: 14 * time.Second},
+ &waitExpireTask{Duration: 19 * time.Second},
},
},
// Still waiting for node 3's entry to expire in the cache.
{
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
+ },
+ },
+ {
+ peers: []*Peer{
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
},
// The cache entry for node 3 has expired and is retried.
{
+ done: []task{
+ &waitExpireTask{Duration: 19 * time.Second},
+ },
peers: []*Peer{
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
- {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
+ {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
@@ -638,9 +613,13 @@ func TestDialStateCache(t *testing.T) {
}
func TestDialResolve(t *testing.T) {
+ config := &Config{
+ Logger: testlog.Logger(t, log.LvlTrace),
+ Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
+ }
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved}
- state := newDialState(enode.ID{}, nil, nil, table, 0, nil)
+ state := newDialState(enode.ID{}, table, 0, config)
// Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil)
@@ -651,8 +630,7 @@ func TestDialResolve(t *testing.T) {
}
// Now run the task, it should resolve the ID once.
- config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}}
- srv := &Server{ntab: table, Config: config}
+ srv := &Server{ntab: table, log: config.Logger, Config: *config}
tasks[0].Do(srv)
if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)
diff --git a/p2p/netutil/addrutil.go b/p2p/netutil/addrutil.go
new file mode 100644
index 000000000..b261a5295
--- /dev/null
+++ b/p2p/netutil/addrutil.go
@@ -0,0 +1,33 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package netutil
+
+import "net"
+
+// AddrIP gets the IP address contained in addr. It returns nil if no address is present.
+func AddrIP(addr net.Addr) net.IP {
+ switch a := addr.(type) {
+ case *net.IPAddr:
+ return a.IP
+ case *net.TCPAddr:
+ return a.IP
+ case *net.UDPAddr:
+ return a.IP
+ default:
+ return nil
+ }
+}
diff --git a/p2p/peer.go b/p2p/peer.go
index af019d07a..98ea6835d 100644
--- a/p2p/peer.go
+++ b/p2p/peer.go
@@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe()
node := enode.SignNull(new(enr.Record), id)
conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
- peer := newPeer(conn, nil)
+ peer := newPeer(log.Root(), conn, nil)
close(peer.closed) // ensures Disconnect doesn't block
return peer
}
@@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool {
return p.rw.is(inboundConn)
}
-func newPeer(conn *conn, protocols []Protocol) *Peer {
+func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer {
protomap := matchProtocols(protocols, conn.caps, conn)
p := &Peer{
rw: conn,
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 5aa64a32e..984cc411a 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -24,6 +24,8 @@ import (
"reflect"
"testing"
"time"
+
+ "github.com/ethereum/go-ethereum/log"
)
var discard = Protocol{
@@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
c2.caps = append(c2.caps, p.cap())
}
- peer := newPeer(c1, protos)
+ peer := newPeer(log.Root(), c1, protos)
errc := make(chan error, 1)
go func() {
_, err := peer.run()
diff --git a/p2p/server.go b/p2p/server.go
index b3494cc88..a373904fc 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -22,6 +22,7 @@ import (
"crypto/ecdsa"
"encoding/hex"
"errors"
+ "fmt"
"net"
"sort"
"sync"
@@ -49,6 +50,9 @@ const (
defaultMaxPendingPeers = 50
defaultDialRatio = 3
+ // This time limits inbound connection attempts per source IP.
+ inboundThrottleTime = 30 * time.Second
+
// Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle.
frameReadTimeout = 30 * time.Second
@@ -158,6 +162,7 @@ type Server struct {
// the whole protocol stack.
newTransport func(net.Conn) transport
newPeerHook func(*Peer)
+ listenFunc func(network, addr string) (net.Listener, error)
lock sync.Mutex // protects running
running bool
@@ -167,24 +172,26 @@ type Server struct {
ntab discoverTable
listener net.Listener
ourHandshake *protoHandshake
- lastLookup time.Time
DiscV5 *discv5.Network
+ loopWG sync.WaitGroup // loop, listenLoop
+ peerFeed event.Feed
+ log log.Logger
- // These are for Peers, PeerCount (and nothing else).
- peerOp chan peerOpFunc
- peerOpDone chan struct{}
+ // Channels into the run loop.
+ quit chan struct{}
+ addstatic chan *enode.Node
+ removestatic chan *enode.Node
+ addtrusted chan *enode.Node
+ removetrusted chan *enode.Node
+ peerOp chan peerOpFunc
+ peerOpDone chan struct{}
+ delpeer chan peerDrop
+ checkpointPostHandshake chan *conn
+ checkpointAddPeer chan *conn
- quit chan struct{}
- addstatic chan *enode.Node
- removestatic chan *enode.Node
- addtrusted chan *enode.Node
- removetrusted chan *enode.Node
- posthandshake chan *conn
- addpeer chan *conn
- delpeer chan peerDrop
- loopWG sync.WaitGroup // loop, listenLoop
- peerFeed event.Feed
- log log.Logger
+ // State of run loop and listenLoop.
+ lastLookup time.Time
+ inboundHistory expHeap
}
type peerOpFunc func(map[enode.ID]*Peer)
@@ -415,7 +422,7 @@ func (srv *Server) Start() (err error) {
srv.running = true
srv.log = srv.Config.Logger
if srv.log == nil {
- srv.log = log.New()
+ srv.log = log.Root()
}
if srv.NoDial && srv.ListenAddr == "" {
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
@@ -428,13 +435,16 @@ func (srv *Server) Start() (err error) {
if srv.newTransport == nil {
srv.newTransport = newRLPX
}
+ if srv.listenFunc == nil {
+ srv.listenFunc = net.Listen
+ }
if srv.Dialer == nil {
srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
}
srv.quit = make(chan struct{})
- srv.addpeer = make(chan *conn)
srv.delpeer = make(chan peerDrop)
- srv.posthandshake = make(chan *conn)
+ srv.checkpointPostHandshake = make(chan *conn)
+ srv.checkpointAddPeer = make(chan *conn)
srv.addstatic = make(chan *enode.Node)
srv.removestatic = make(chan *enode.Node)
srv.addtrusted = make(chan *enode.Node)
@@ -455,7 +465,7 @@ func (srv *Server) Start() (err error) {
}
dynPeers := srv.maxDialedConns()
- dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
+ dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config)
srv.loopWG.Add(1)
go srv.run(dialer)
return nil
@@ -541,6 +551,7 @@ func (srv *Server) setupDiscovery() error {
NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled,
+ Log: srv.log,
}
ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
if err != nil {
@@ -569,27 +580,28 @@ func (srv *Server) setupDiscovery() error {
}
func (srv *Server) setupListening() error {
- // Launch the TCP listener.
- listener, err := net.Listen("tcp", srv.ListenAddr)
+ // Launch the listener.
+ listener, err := srv.listenFunc("tcp", srv.ListenAddr)
if err != nil {
return err
}
- laddr := listener.Addr().(*net.TCPAddr)
- srv.ListenAddr = laddr.String()
srv.listener = listener
- srv.localnode.Set(enr.TCP(laddr.Port))
+ srv.ListenAddr = listener.Addr().String()
+
+ // Update the local node record and map the TCP listening port if NAT is configured.
+ if tcp, ok := listener.Addr().(*net.TCPAddr); ok {
+ srv.localnode.Set(enr.TCP(tcp.Port))
+ if !tcp.IP.IsLoopback() && srv.NAT != nil {
+ srv.loopWG.Add(1)
+ go func() {
+ nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p")
+ srv.loopWG.Done()
+ }()
+ }
+ }
srv.loopWG.Add(1)
go srv.listenLoop()
-
- // Map the TCP listening port if NAT is configured.
- if !laddr.IP.IsLoopback() && srv.NAT != nil {
- srv.loopWG.Add(1)
- go func() {
- nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
- srv.loopWG.Done()
- }()
- }
return nil
}
@@ -657,12 +669,14 @@ running:
case <-srv.quit:
// The server was stopped. Run the cleanup logic.
break running
+
case n := <-srv.addstatic:
// This channel is used by AddPeer to add to the
// ephemeral static peer list. Add it to the dialer,
// it will keep the node connected.
srv.log.Trace("Adding static node", "node", n)
dialstate.addStatic(n)
+
case n := <-srv.removestatic:
// This channel is used by RemovePeer to send a
// disconnect request to a peer and begin the
@@ -672,6 +686,7 @@ running:
if p, ok := peers[n.ID()]; ok {
p.Disconnect(DiscRequested)
}
+
case n := <-srv.addtrusted:
// This channel is used by AddTrustedPeer to add an enode
// to the trusted node set.
@@ -681,6 +696,7 @@ running:
if p, ok := peers[n.ID()]; ok {
p.rw.set(trustedConn, true)
}
+
case n := <-srv.removetrusted:
// This channel is used by RemoveTrustedPeer to remove an enode
// from the trusted node set.
@@ -691,10 +707,12 @@ running:
if p, ok := peers[n.ID()]; ok {
p.rw.set(trustedConn, false)
}
+
case op := <-srv.peerOp:
// This channel is used by Peers and PeerCount.
op(peers)
srv.peerOpDone <- struct{}{}
+
case t := <-taskdone:
// A task got done. Tell dialstate about it so it
// can update its state and remove it from the active
@@ -702,7 +720,8 @@ running:
srv.log.Trace("Dial task done", "task", t)
dialstate.taskDone(t, time.Now())
delTask(t)
- case c := <-srv.posthandshake:
+
+ case c := <-srv.checkpointPostHandshake:
// A connection has passed the encryption handshake so
// the remote identity is known (but hasn't been verified yet).
if trusted[c.node.ID()] {
@@ -710,18 +729,15 @@ running:
c.flags |= trustedConn
}
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
- select {
- case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
- case <-srv.quit:
- break running
- }
- case c := <-srv.addpeer:
+ c.cont <- srv.postHandshakeChecks(peers, inboundCount, c)
+
+ case c := <-srv.checkpointAddPeer:
// At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified.
- err := srv.protoHandshakeChecks(peers, inboundCount, c)
+ err := srv.addPeerChecks(peers, inboundCount, c)
if err == nil {
// The handshakes are done and it passed all checks.
- p := newPeer(c, srv.Protocols)
+ p := newPeer(srv.log, c, srv.Protocols)
// If message events are enabled, pass the peerFeed
// to the peer
if srv.EnableMsgEvents {
@@ -738,11 +754,8 @@ running:
// The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or
// discarded. Unblock the task last.
- select {
- case c.cont <- err:
- case <-srv.quit:
- break running
- }
+ c.cont <- err
+
case pd := <-srv.delpeer:
// A peer disconnected.
d := common.PrettyDuration(mclock.Now() - pd.created)
@@ -777,17 +790,7 @@ running:
}
}
-func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
- // Drop connections with no matching protocols.
- if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
- return DiscUselessPeer
- }
- // Repeat the encryption handshake checks because the
- // peer set might have changed between the handshakes.
- return srv.encHandshakeChecks(peers, inboundCount, c)
-}
-
-func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
+func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers
@@ -802,9 +805,20 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
}
}
+func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
+ // Drop connections with no matching protocols.
+ if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
+ return DiscUselessPeer
+ }
+ // Repeat the post-handshake checks because the
+ // peer set might have changed since those checks were performed.
+ return srv.postHandshakeChecks(peers, inboundCount, c)
+}
+
func (srv *Server) maxInboundConns() int {
return srv.MaxPeers - srv.maxDialedConns()
}
+
func (srv *Server) maxDialedConns() int {
if srv.NoDiscovery || srv.NoDial {
return 0
@@ -832,7 +846,7 @@ func (srv *Server) listenLoop() {
}
for {
- // Wait for a handshake slot before accepting.
+ // Wait for a free slot before accepting.
<-slots
var (
@@ -851,21 +865,16 @@ func (srv *Server) listenLoop() {
break
}
- // Reject connections that do not match NetRestrict.
- if srv.NetRestrict != nil {
- if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
- srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr())
- fd.Close()
- slots <- struct{}{}
- continue
- }
+ remoteIP := netutil.AddrIP(fd.RemoteAddr())
+ if err := srv.checkInboundConn(fd, remoteIP); err != nil {
+ srv.log.Debug("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err)
+ fd.Close()
+ slots <- struct{}{}
+ continue
}
-
- var ip net.IP
- if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok {
- ip = tcp.IP
+ if remoteIP != nil {
+ fd = newMeteredConn(fd, true, remoteIP)
}
- fd = newMeteredConn(fd, true, ip)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
go func() {
srv.SetupConn(fd, inboundConn, nil)
@@ -874,6 +883,22 @@ func (srv *Server) listenLoop() {
}
}
+func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error {
+ if remoteIP != nil {
+ // Reject connections that do not match NetRestrict.
+ if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
+ return fmt.Errorf("not whitelisted in NetRestrict")
+ }
+ // Reject Internet peers that try too often.
+ srv.inboundHistory.expire(time.Now())
+ if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
+ return fmt.Errorf("too many attempts")
+ }
+ srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime))
+ }
+ return nil
+}
+
// SetupConn runs the handshakes and attempts to add the connection
// as a peer. It returns when the connection has been added as a peer
// or the handshakes have failed.
@@ -895,6 +920,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
if !running {
return errServerStopped
}
+
// If dialing, figure out the remote public key.
var dialPubkey *ecdsa.PublicKey
if dialDest != nil {
@@ -903,7 +929,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
return errors.New("dial destination doesn't have a secp256k1 public key")
}
}
- // Run the encryption handshake.
+
+ // Run the RLPx handshake.
remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
if err != nil {
srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
@@ -922,12 +949,13 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
conn.handshakeDone(c.node.ID())
}
clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
- err = srv.checkpoint(c, srv.posthandshake)
+ err = srv.checkpoint(c, srv.checkpointPostHandshake)
if err != nil {
- clog.Trace("Rejected peer before protocol handshake", "err", err)
+ clog.Trace("Rejected peer", "err", err)
return err
}
- // Run the protocol handshake
+
+ // Run the capability negotiation handshake.
phs, err := c.doProtoHandshake(srv.ourHandshake)
if err != nil {
clog.Trace("Failed proto handshake", "err", err)
@@ -938,14 +966,15 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
return DiscUnexpectedIdentity
}
c.caps, c.name = phs.Caps, phs.Name
- err = srv.checkpoint(c, srv.addpeer)
+ err = srv.checkpoint(c, srv.checkpointAddPeer)
if err != nil {
clog.Trace("Rejected peer", "err", err)
return err
}
- // If the checks completed successfully, runPeer has now been
- // launched by run.
- clog.Trace("connection set up", "inbound", dialDest == nil)
+
+ // If the checks completed successfully, the connection has been added as a peer and
+ // runPeer has been launched.
+ clog.Trace("Connection set up", "inbound", dialDest == nil)
return nil
}
@@ -974,12 +1003,7 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
case <-srv.quit:
return errServerStopped
}
- select {
- case err := <-c.cont:
- return err
- case <-srv.quit:
- return errServerStopped
- }
+ return <-c.cont
}
// runPeer runs in its own goroutine for each peer.
diff --git a/p2p/server_test.go b/p2p/server_test.go
index f665c1424..e8bc627e1 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -19,6 +19,7 @@ package p2p
import (
"crypto/ecdsa"
"errors"
+ "io"
"math/rand"
"net"
"reflect"
@@ -26,6 +27,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
@@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
MaxPeers: 10,
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(),
+ Logger: testlog.Logger(t, log.LvlTrace),
}
server := &Server{
Config: config,
@@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) {
PrivateKey: newkey(),
MaxPeers: 10,
NoDial: true,
+ NoDiscovery: true,
TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
},
}
@@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) {
// Inject a few connections to fill up the peer set.
for i := 0; i < 10; i++ {
c := newconn(randomID())
- if err := srv.checkpoint(c, srv.addpeer); err != nil {
+ if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
t.Fatalf("could not add conn %d: %v", i, err)
}
}
// Try inserting a non-trusted connection.
anotherID := randomID()
c := newconn(anotherID)
- if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
+ if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err)
}
// Try inserting a trusted connection.
c = newconn(trustedID)
- if err := srv.checkpoint(c, srv.posthandshake); err != nil {
+ if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err)
}
if !c.is(trustedConn) {
@@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) {
// Remove from trusted set and try again
srv.RemoveTrustedPeer(newNode(trustedID, nil))
c = newconn(trustedID)
- if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
+ if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err)
}
// Add anotherID to trusted set and try again
srv.AddTrustedPeer(newNode(anotherID, nil))
c = newconn(anotherID)
- if err := srv.checkpoint(c, srv.posthandshake); err != nil {
+ if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err)
}
if !c.is(trustedConn) {
@@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) {
srv := &Server{
Config: Config{
- PrivateKey: srvkey,
- MaxPeers: 0,
- NoDial: true,
- Protocols: []Protocol{discard},
+ PrivateKey: srvkey,
+ MaxPeers: 0,
+ NoDial: true,
+ NoDiscovery: true,
+ Protocols: []Protocol{discard},
},
newTransport: func(fd net.Conn) transport { return tp },
log: log.New(),
@@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) {
}
for i, test := range tests {
- srv := &Server{
- Config: Config{
- PrivateKey: srvkey,
- MaxPeers: 10,
- NoDial: true,
- Protocols: []Protocol{discard},
- },
- newTransport: func(fd net.Conn) transport { return test.tt },
- log: log.New(),
- }
- if !test.dontstart {
- if err := srv.Start(); err != nil {
- t.Fatalf("couldn't start server: %v", err)
+ t.Run(test.wantCalls, func(t *testing.T) {
+ cfg := Config{
+ PrivateKey: srvkey,
+ MaxPeers: 10,
+ NoDial: true,
+ NoDiscovery: true,
+ Protocols: []Protocol{discard},
+ Logger: testlog.Logger(t, log.LvlTrace),
}
- }
- p1, _ := net.Pipe()
- srv.SetupConn(p1, test.flags, test.dialDest)
- if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
- t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
- }
- if test.tt.calls != test.wantCalls {
- t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
- }
+ srv := &Server{
+ Config: cfg,
+ newTransport: func(fd net.Conn) transport { return test.tt },
+ log: cfg.Logger,
+ }
+ if !test.dontstart {
+ if err := srv.Start(); err != nil {
+ t.Fatalf("couldn't start server: %v", err)
+ }
+ defer srv.Stop()
+ }
+ p1, _ := net.Pipe()
+ srv.SetupConn(p1, test.flags, test.dialDest)
+ if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
+ t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
+ }
+ if test.tt.calls != test.wantCalls {
+ t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
+ }
+ })
}
}
@@ -616,3 +627,100 @@ func randomID() (id enode.ID) {
}
return id
}
+
+// This test checks that inbound connections are throttled by IP.
+func TestServerInboundThrottle(t *testing.T) {
+ const timeout = 5 * time.Second
+ newTransportCalled := make(chan struct{})
+ srv := &Server{
+ Config: Config{
+ PrivateKey: newkey(),
+ ListenAddr: "127.0.0.1:0",
+ MaxPeers: 10,
+ NoDial: true,
+ NoDiscovery: true,
+ Protocols: []Protocol{discard},
+ Logger: testlog.Logger(t, log.LvlTrace),
+ },
+ newTransport: func(fd net.Conn) transport {
+ newTransportCalled <- struct{}{}
+ return newRLPX(fd)
+ },
+ listenFunc: func(network, laddr string) (net.Listener, error) {
+ fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
+ return listenFakeAddr(network, laddr, fakeAddr)
+ },
+ }
+ if err := srv.Start(); err != nil {
+ t.Fatal("can't start: ", err)
+ }
+ defer srv.Stop()
+
+ // Dial the test server.
+ conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
+ if err != nil {
+ t.Fatalf("could not dial: %v", err)
+ }
+ select {
+ case <-newTransportCalled:
+ // OK
+ case <-time.After(timeout):
+ t.Error("newTransport not called")
+ }
+ conn.Close()
+
+ // Dial again. This time the server should close the connection immediately.
+ connClosed := make(chan struct{})
+ conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
+ if err != nil {
+ t.Fatalf("could not dial: %v", err)
+ }
+ defer conn.Close()
+ go func() {
+ conn.SetDeadline(time.Now().Add(timeout))
+ buf := make([]byte, 10)
+ if n, err := conn.Read(buf); err != io.EOF || n != 0 {
+ t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
+ }
+ connClosed <- struct{}{}
+ }()
+ select {
+ case <-connClosed:
+ // OK
+ case <-newTransportCalled:
+ t.Error("newTransport called for second attempt")
+ case <-time.After(timeout):
+ t.Error("connection not closed within timeout")
+ }
+}
+
+func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
+ l, err := net.Listen(network, laddr)
+ if err == nil {
+ l = &fakeAddrListener{l, remoteAddr}
+ }
+ return l, err
+}
+
+// fakeAddrListener is a listener that creates connections with a mocked remote address.
+type fakeAddrListener struct {
+ net.Listener
+ remoteAddr net.Addr
+}
+
+type fakeAddrConn struct {
+ net.Conn
+ remoteAddr net.Addr
+}
+
+func (l *fakeAddrListener) Accept() (net.Conn, error) {
+ c, err := l.Listener.Accept()
+ if err != nil {
+ return nil, err
+ }
+ return &fakeAddrConn{c, l.remoteAddr}, nil
+}
+
+func (c *fakeAddrConn) RemoteAddr() net.Addr {
+ return c.remoteAddr
+}
diff --git a/p2p/util.go b/p2p/util.go
new file mode 100644
index 000000000..2a6edf5ce
--- /dev/null
+++ b/p2p/util.go
@@ -0,0 +1,82 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package p2p
+
+import (
+ "container/heap"
+ "time"
+)
+
+// expHeap tracks strings and their expiry time.
+type expHeap []expItem
+
+// expItem is an entry in addrHistory.
+type expItem struct {
+ item string
+ exp time.Time
+}
+
+// nextExpiry returns the next expiry time.
+func (h *expHeap) nextExpiry() time.Time {
+ return (*h)[0].exp
+}
+
+// add adds an item and sets its expiry time.
+func (h *expHeap) add(item string, exp time.Time) {
+ heap.Push(h, expItem{item, exp})
+}
+
+// remove removes an item.
+func (h *expHeap) remove(item string) bool {
+ for i, v := range *h {
+ if v.item == item {
+ heap.Remove(h, i)
+ return true
+ }
+ }
+ return false
+}
+
+// contains checks whether an item is present.
+func (h expHeap) contains(item string) bool {
+ for _, v := range h {
+ if v.item == item {
+ return true
+ }
+ }
+ return false
+}
+
+// expire removes items with expiry time before 'now'.
+func (h *expHeap) expire(now time.Time) {
+ for h.Len() > 0 && h.nextExpiry().Before(now) {
+ heap.Pop(h)
+ }
+}
+
+// heap.Interface boilerplate
+func (h expHeap) Len() int { return len(h) }
+func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
+func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) }
+func (h *expHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[0 : n-1]
+ return x
+}
diff --git a/p2p/util_test.go b/p2p/util_test.go
new file mode 100644
index 000000000..c9f2648dc
--- /dev/null
+++ b/p2p/util_test.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package p2p
+
+import (
+ "testing"
+ "time"
+)
+
+func TestExpHeap(t *testing.T) {
+ var h expHeap
+
+ var (
+ basetime = time.Unix(4000, 0)
+ exptimeA = basetime.Add(2 * time.Second)
+ exptimeB = basetime.Add(3 * time.Second)
+ exptimeC = basetime.Add(4 * time.Second)
+ )
+ h.add("a", exptimeA)
+ h.add("b", exptimeB)
+ h.add("c", exptimeC)
+
+ if !h.nextExpiry().Equal(exptimeA) {
+ t.Fatal("wrong nextExpiry")
+ }
+ if !h.contains("a") || !h.contains("b") || !h.contains("c") {
+ t.Fatal("heap doesn't contain all live items")
+ }
+
+ h.expire(exptimeA.Add(1))
+ if !h.nextExpiry().Equal(exptimeB) {
+ t.Fatal("wrong nextExpiry")
+ }
+ if h.contains("a") {
+ t.Fatal("heap contains a even though it has already expired")
+ }
+ if !h.contains("b") || !h.contains("c") {
+ t.Fatal("heap doesn't contain all live items")
+ }
+}