diff --git a/p2p/dial.go b/p2p/dial.go index 936887a1a..8dee5063f 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -17,7 +17,6 @@ package p2p import ( - "container/heap" "errors" "fmt" "net" @@ -29,9 +28,10 @@ import ( ) const ( - // This is the amount of time spent waiting in between - // redialing a certain node. - dialHistoryExpiration = 30 * time.Second + // This is the amount of time spent waiting in between redialing a certain node. The + // limit is a bit higher than inboundThrottleTime to prevent failing dials in small + // private networks. + dialHistoryExpiration = inboundThrottleTime + 5*time.Second // Discovery lookups are throttled and can only run // once every few seconds. @@ -72,16 +72,16 @@ type dialstate struct { ntab discoverTable netrestrict *netutil.Netlist self enode.ID + bootnodes []*enode.Node // default dials when there are no peers + log log.Logger + start time.Time // time when the dialer was first used lookupRunning bool dialing map[enode.ID]connFlag lookupBuf []*enode.Node // current discovery lookup results randomNodes []*enode.Node // filled from Table static map[enode.ID]*dialTask - hist *dialHistory - - start time.Time // time when the dialer was first used - bootnodes []*enode.Node // default dials when there are no peers + hist expHeap } type discoverTable interface { @@ -91,15 +91,6 @@ type discoverTable interface { ReadRandomNodes([]*enode.Node) int } -// the dial history remembers recent dials. -type dialHistory []pastDial - -// pastDial is an entry in the dial history. -type pastDial struct { - id enode.ID - exp time.Time -} - type task interface { Do(*Server) } @@ -126,20 +117,23 @@ type waitExpireTask struct { time.Duration } -func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { +func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate { s := &dialstate{ maxDynDials: maxdyn, ntab: ntab, self: self, - netrestrict: netrestrict, + netrestrict: cfg.NetRestrict, + log: cfg.Logger, static: make(map[enode.ID]*dialTask), dialing: make(map[enode.ID]connFlag), - bootnodes: make([]*enode.Node, len(bootnodes)), + bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), randomNodes: make([]*enode.Node, maxdyn/2), - hist: new(dialHistory), } - copy(s.bootnodes, bootnodes) - for _, n := range static { + copy(s.bootnodes, cfg.BootstrapNodes) + if s.log == nil { + s.log = log.Root() + } + for _, n := range cfg.StaticNodes { s.addStatic(n) } return s @@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) { func (s *dialstate) removeStatic(n *enode.Node) { // This removes a task so future attempts to connect will not be made. delete(s.static, n.ID()) - // This removes a previous dial timestamp so that application - // can force a server to reconnect with chosen peer immediately. - s.hist.remove(n.ID()) } func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { @@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti var newtasks []task addDial := func(flag connFlag, n *enode.Node) bool { if err := s.checkDial(n, peers); err != nil { - log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) + s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) return false } s.dialing[n.ID()] = flag @@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti err := s.checkDial(t.dest, peers) switch err { case errNotWhitelisted, errSelf: - log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) + s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) delete(s.static, t.dest.ID()) case nil: s.dialing[id] = t.flags @@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti // This should prevent cases where the dialer logic is not ticked // because there are no pending events. if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { - t := &waitExpireTask{s.hist.min().exp.Sub(now)} + t := &waitExpireTask{s.hist.nextExpiry().Sub(now)} newtasks = append(newtasks, t) } return newtasks @@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { return errSelf case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): return errNotWhitelisted - case s.hist.contains(n.ID()): + case s.hist.contains(string(n.ID().Bytes())): return errRecentlyDialed } return nil @@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { func (s *dialstate) taskDone(t task, now time.Time) { switch t := t.(type) { case *dialTask: - s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration)) + s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration)) delete(s.dialing, t.dest.ID()) case *discoverTask: s.lookupRunning = false @@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) { } err := t.dial(srv, t.dest) if err != nil { - log.Trace("Dial error", "task", t, "err", err) + srv.log.Trace("Dial error", "task", t, "err", err) // Try resolving the ID of static nodes if dialing failed. if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { if t.resolve(srv) { @@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) { // The backoff delay resets when the node is found. func (t *dialTask) resolve(srv *Server) bool { if srv.ntab == nil { - log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") + srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") return false } if t.resolveDelay == 0 { @@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool { if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) + srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay t.dest = resolved - log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) return true } @@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) { func (t waitExpireTask) String() string { return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) } - -// Use only these methods to access or modify dialHistory. -func (h dialHistory) min() pastDial { - return h[0] -} -func (h *dialHistory) add(id enode.ID, exp time.Time) { - heap.Push(h, pastDial{id, exp}) - -} -func (h *dialHistory) remove(id enode.ID) bool { - for i, v := range *h { - if v.id == id { - heap.Remove(h, i) - return true - } - } - return false -} -func (h dialHistory) contains(id enode.ID) bool { - for _, v := range h { - if v.id == id { - return true - } - } - return false -} -func (h *dialHistory) expire(now time.Time) { - for h.Len() > 0 && h.min().exp.Before(now) { - heap.Pop(h) - } -} - -// heap.Interface boilerplate -func (h dialHistory) Len() int { return len(h) } -func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } -func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] } -func (h *dialHistory) Push(x interface{}) { - *h = append(*h, x.(pastDial)) -} -func (h *dialHistory) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x -} diff --git a/p2p/dial_test.go b/p2p/dial_test.go index f41ab7752..de8fc4a6e 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -20,10 +20,13 @@ import ( "encoding/binary" "net" "reflect" + "strings" "testing" "time" "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" @@ -67,10 +70,10 @@ func runDialTest(t *testing.T, test dialtest) { new := test.init.newTasks(running, pm(round.peers), vtime) if !sametasks(new, round.new) { - t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n", + t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v", i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) } - t.Log("tasks:", spew.Sdump(new)) + t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new))) // Time advances by 16 seconds on every round. vtime = vtime.Add(16 * time.Second) @@ -88,8 +91,9 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { + config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil), + init: newDialState(enode.ID{}, fakeTable{}, 5, config), rounds: []round{ // A discovery query is launched. { @@ -153,7 +157,7 @@ func TestDialStateDynDial(t *testing.T) { &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ - &waitExpireTask{Duration: 14 * time.Second}, + &waitExpireTask{Duration: 19 * time.Second}, }, }, // In this round, the peer with id 2 drops off. The query @@ -223,10 +227,13 @@ func TestDialStateDynDial(t *testing.T) { // Tests that bootnodes are dialed if no peers are connectd, but not otherwise. func TestDialStateDynDialBootnode(t *testing.T) { - bootnodes := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), + config := &Config{ + BootstrapNodes: []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + }, + Logger: testlog.Logger(t, log.LvlTrace), } table := fakeTable{ newNode(uintID(4), nil), @@ -236,7 +243,7 @@ func TestDialStateDynDialBootnode(t *testing.T) { newNode(uintID(8), nil), } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil), + init: newDialState(enode.ID{}, table, 5, config), rounds: []round{ // 2 dynamic dials attempted, bootnodes pending fallback interval { @@ -259,25 +266,24 @@ func TestDialStateDynDialBootnode(t *testing.T) { { new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 2nd bootnode is attempted { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 3rd bootnode is attempted { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, @@ -288,21 +294,19 @@ func TestDialStateDynDialBootnode(t *testing.T) { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, + new: []task{}, }, // Random dial succeeds, no more bootnodes are attempted { + new: []task{ + &waitExpireTask{3 * time.Second}, + }, peers: []*Peer{ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, }, done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, }, @@ -324,7 +328,7 @@ func TestDialStateDynDialFromTable(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, nil, table, 10, nil), + init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}), rounds: []round{ // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. { @@ -430,7 +434,7 @@ func TestDialStateNetRestrict(t *testing.T) { restrict.Add("127.0.2.0/24") runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, nil, nil, table, 10, restrict), + init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}), rounds: []round{ { new: []task{ @@ -444,16 +448,18 @@ func TestDialStateNetRestrict(t *testing.T) { // This test checks that static dials are launched. func TestDialStateStaticDial(t *testing.T) { - wantStatic := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), + config := &Config{ + StaticNodes: []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + newNode(uintID(4), nil), + newNode(uintID(5), nil), + }, + Logger: testlog.Logger(t, log.LvlTrace), } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), + init: newDialState(enode.ID{}, fakeTable{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -495,7 +501,7 @@ func TestDialStateStaticDial(t *testing.T) { &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ - &waitExpireTask{Duration: 14 * time.Second}, + &waitExpireTask{Duration: 19 * time.Second}, }, }, // Wait a round for dial history to expire, no new tasks should spawn. @@ -511,6 +517,9 @@ func TestDialStateStaticDial(t *testing.T) { // If a static node is dropped, it should be immediately redialed, // irrespective whether it was originally static or dynamic. { + done: []task{ + &waitExpireTask{Duration: 19 * time.Second}, + }, peers: []*Peer{ {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, @@ -518,67 +527,24 @@ func TestDialStateStaticDial(t *testing.T) { }, new: []task{ &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, }, }, }, }) } -// This test checks that static peers will be redialed immediately if they were re-added to a static list. -func TestDialStaticAfterReset(t *testing.T) { - wantStatic := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - } - - rounds := []round{ - // Static dials are launched for the nodes that aren't yet connected. - { - peers: nil, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, - }, - // No new dial tasks, all peers are connected. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 30 * time.Second}, - }, - }, - } - dTest := dialtest{ - init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), - rounds: rounds, - } - runDialTest(t, dTest) - for _, n := range wantStatic { - dTest.init.removeStatic(n) - dTest.init.addStatic(n) - } - // without removing peers they will be considered recently dialed - runDialTest(t, dTest) -} - // This test checks that past dials are not retried for some time. func TestDialStateCache(t *testing.T) { - wantStatic := []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), + config := &Config{ + StaticNodes: []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + }, + Logger: testlog.Logger(t, log.LvlTrace), } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), + init: newDialState(enode.ID{}, fakeTable{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -606,28 +572,37 @@ func TestDialStateCache(t *testing.T) { // entry to expire. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, done: []task{ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, }, new: []task{ - &waitExpireTask{Duration: 14 * time.Second}, + &waitExpireTask{Duration: 19 * time.Second}, }, }, // Still waiting for node 3's entry to expire in the cache. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, + }, + }, + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, }, // The cache entry for node 3 has expired and is retried. { + done: []task{ + &waitExpireTask{Duration: 19 * time.Second}, + }, peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, new: []task{ &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, @@ -638,9 +613,13 @@ func TestDialStateCache(t *testing.T) { } func TestDialResolve(t *testing.T) { + config := &Config{ + Logger: testlog.Logger(t, log.LvlTrace), + Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, + } resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) table := &resolveMock{answer: resolved} - state := newDialState(enode.ID{}, nil, nil, table, 0, nil) + state := newDialState(enode.ID{}, table, 0, config) // Check that the task is generated with an incomplete ID. dest := newNode(uintID(1), nil) @@ -651,8 +630,7 @@ func TestDialResolve(t *testing.T) { } // Now run the task, it should resolve the ID once. - config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}} - srv := &Server{ntab: table, Config: config} + srv := &Server{ntab: table, log: config.Logger, Config: *config} tasks[0].Do(srv) if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) diff --git a/p2p/netutil/addrutil.go b/p2p/netutil/addrutil.go new file mode 100644 index 000000000..b261a5295 --- /dev/null +++ b/p2p/netutil/addrutil.go @@ -0,0 +1,33 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package netutil + +import "net" + +// AddrIP gets the IP address contained in addr. It returns nil if no address is present. +func AddrIP(addr net.Addr) net.IP { + switch a := addr.(type) { + case *net.IPAddr: + return a.IP + case *net.TCPAddr: + return a.IP + case *net.UDPAddr: + return a.IP + default: + return nil + } +} diff --git a/p2p/peer.go b/p2p/peer.go index af019d07a..98ea6835d 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer { pipe, _ := net.Pipe() node := enode.SignNull(new(enr.Record), id) conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name} - peer := newPeer(conn, nil) + peer := newPeer(log.Root(), conn, nil) close(peer.closed) // ensures Disconnect doesn't block return peer } @@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool { return p.rw.is(inboundConn) } -func newPeer(conn *conn, protocols []Protocol) *Peer { +func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer { protomap := matchProtocols(protocols, conn.caps, conn) p := &Peer{ rw: conn, diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 5aa64a32e..984cc411a 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -24,6 +24,8 @@ import ( "reflect" "testing" "time" + + "github.com/ethereum/go-ethereum/log" ) var discard = Protocol{ @@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { c2.caps = append(c2.caps, p.cap()) } - peer := newPeer(c1, protos) + peer := newPeer(log.Root(), c1, protos) errc := make(chan error, 1) go func() { _, err := peer.run() diff --git a/p2p/server.go b/p2p/server.go index b3494cc88..a373904fc 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -22,6 +22,7 @@ import ( "crypto/ecdsa" "encoding/hex" "errors" + "fmt" "net" "sort" "sync" @@ -49,6 +50,9 @@ const ( defaultMaxPendingPeers = 50 defaultDialRatio = 3 + // This time limits inbound connection attempts per source IP. + inboundThrottleTime = 30 * time.Second + // Maximum time allowed for reading a complete message. // This is effectively the amount of time a connection can be idle. frameReadTimeout = 30 * time.Second @@ -158,6 +162,7 @@ type Server struct { // the whole protocol stack. newTransport func(net.Conn) transport newPeerHook func(*Peer) + listenFunc func(network, addr string) (net.Listener, error) lock sync.Mutex // protects running running bool @@ -167,24 +172,26 @@ type Server struct { ntab discoverTable listener net.Listener ourHandshake *protoHandshake - lastLookup time.Time DiscV5 *discv5.Network + loopWG sync.WaitGroup // loop, listenLoop + peerFeed event.Feed + log log.Logger - // These are for Peers, PeerCount (and nothing else). - peerOp chan peerOpFunc - peerOpDone chan struct{} + // Channels into the run loop. + quit chan struct{} + addstatic chan *enode.Node + removestatic chan *enode.Node + addtrusted chan *enode.Node + removetrusted chan *enode.Node + peerOp chan peerOpFunc + peerOpDone chan struct{} + delpeer chan peerDrop + checkpointPostHandshake chan *conn + checkpointAddPeer chan *conn - quit chan struct{} - addstatic chan *enode.Node - removestatic chan *enode.Node - addtrusted chan *enode.Node - removetrusted chan *enode.Node - posthandshake chan *conn - addpeer chan *conn - delpeer chan peerDrop - loopWG sync.WaitGroup // loop, listenLoop - peerFeed event.Feed - log log.Logger + // State of run loop and listenLoop. + lastLookup time.Time + inboundHistory expHeap } type peerOpFunc func(map[enode.ID]*Peer) @@ -415,7 +422,7 @@ func (srv *Server) Start() (err error) { srv.running = true srv.log = srv.Config.Logger if srv.log == nil { - srv.log = log.New() + srv.log = log.Root() } if srv.NoDial && srv.ListenAddr == "" { srv.log.Warn("P2P server will be useless, neither dialing nor listening") @@ -428,13 +435,16 @@ func (srv *Server) Start() (err error) { if srv.newTransport == nil { srv.newTransport = newRLPX } + if srv.listenFunc == nil { + srv.listenFunc = net.Listen + } if srv.Dialer == nil { srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}} } srv.quit = make(chan struct{}) - srv.addpeer = make(chan *conn) srv.delpeer = make(chan peerDrop) - srv.posthandshake = make(chan *conn) + srv.checkpointPostHandshake = make(chan *conn) + srv.checkpointAddPeer = make(chan *conn) srv.addstatic = make(chan *enode.Node) srv.removestatic = make(chan *enode.Node) srv.addtrusted = make(chan *enode.Node) @@ -455,7 +465,7 @@ func (srv *Server) Start() (err error) { } dynPeers := srv.maxDialedConns() - dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) + dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config) srv.loopWG.Add(1) go srv.run(dialer) return nil @@ -541,6 +551,7 @@ func (srv *Server) setupDiscovery() error { NetRestrict: srv.NetRestrict, Bootnodes: srv.BootstrapNodes, Unhandled: unhandled, + Log: srv.log, } ntab, err := discover.ListenUDP(conn, srv.localnode, cfg) if err != nil { @@ -569,27 +580,28 @@ func (srv *Server) setupDiscovery() error { } func (srv *Server) setupListening() error { - // Launch the TCP listener. - listener, err := net.Listen("tcp", srv.ListenAddr) + // Launch the listener. + listener, err := srv.listenFunc("tcp", srv.ListenAddr) if err != nil { return err } - laddr := listener.Addr().(*net.TCPAddr) - srv.ListenAddr = laddr.String() srv.listener = listener - srv.localnode.Set(enr.TCP(laddr.Port)) + srv.ListenAddr = listener.Addr().String() + + // Update the local node record and map the TCP listening port if NAT is configured. + if tcp, ok := listener.Addr().(*net.TCPAddr); ok { + srv.localnode.Set(enr.TCP(tcp.Port)) + if !tcp.IP.IsLoopback() && srv.NAT != nil { + srv.loopWG.Add(1) + go func() { + nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p") + srv.loopWG.Done() + }() + } + } srv.loopWG.Add(1) go srv.listenLoop() - - // Map the TCP listening port if NAT is configured. - if !laddr.IP.IsLoopback() && srv.NAT != nil { - srv.loopWG.Add(1) - go func() { - nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p") - srv.loopWG.Done() - }() - } return nil } @@ -657,12 +669,14 @@ running: case <-srv.quit: // The server was stopped. Run the cleanup logic. break running + case n := <-srv.addstatic: // This channel is used by AddPeer to add to the // ephemeral static peer list. Add it to the dialer, // it will keep the node connected. srv.log.Trace("Adding static node", "node", n) dialstate.addStatic(n) + case n := <-srv.removestatic: // This channel is used by RemovePeer to send a // disconnect request to a peer and begin the @@ -672,6 +686,7 @@ running: if p, ok := peers[n.ID()]; ok { p.Disconnect(DiscRequested) } + case n := <-srv.addtrusted: // This channel is used by AddTrustedPeer to add an enode // to the trusted node set. @@ -681,6 +696,7 @@ running: if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, true) } + case n := <-srv.removetrusted: // This channel is used by RemoveTrustedPeer to remove an enode // from the trusted node set. @@ -691,10 +707,12 @@ running: if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, false) } + case op := <-srv.peerOp: // This channel is used by Peers and PeerCount. op(peers) srv.peerOpDone <- struct{}{} + case t := <-taskdone: // A task got done. Tell dialstate about it so it // can update its state and remove it from the active @@ -702,7 +720,8 @@ running: srv.log.Trace("Dial task done", "task", t) dialstate.taskDone(t, time.Now()) delTask(t) - case c := <-srv.posthandshake: + + case c := <-srv.checkpointPostHandshake: // A connection has passed the encryption handshake so // the remote identity is known (but hasn't been verified yet). if trusted[c.node.ID()] { @@ -710,18 +729,15 @@ running: c.flags |= trustedConn } // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. - select { - case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c): - case <-srv.quit: - break running - } - case c := <-srv.addpeer: + c.cont <- srv.postHandshakeChecks(peers, inboundCount, c) + + case c := <-srv.checkpointAddPeer: // At this point the connection is past the protocol handshake. // Its capabilities are known and the remote identity is verified. - err := srv.protoHandshakeChecks(peers, inboundCount, c) + err := srv.addPeerChecks(peers, inboundCount, c) if err == nil { // The handshakes are done and it passed all checks. - p := newPeer(c, srv.Protocols) + p := newPeer(srv.log, c, srv.Protocols) // If message events are enabled, pass the peerFeed // to the peer if srv.EnableMsgEvents { @@ -738,11 +754,8 @@ running: // The dialer logic relies on the assumption that // dial tasks complete after the peer has been added or // discarded. Unblock the task last. - select { - case c.cont <- err: - case <-srv.quit: - break running - } + c.cont <- err + case pd := <-srv.delpeer: // A peer disconnected. d := common.PrettyDuration(mclock.Now() - pd.created) @@ -777,17 +790,7 @@ running: } } -func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { - // Drop connections with no matching protocols. - if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { - return DiscUselessPeer - } - // Repeat the encryption handshake checks because the - // peer set might have changed between the handshakes. - return srv.encHandshakeChecks(peers, inboundCount, c) -} - -func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { +func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { switch { case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers @@ -802,9 +805,20 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int } } +func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { + // Drop connections with no matching protocols. + if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { + return DiscUselessPeer + } + // Repeat the post-handshake checks because the + // peer set might have changed since those checks were performed. + return srv.postHandshakeChecks(peers, inboundCount, c) +} + func (srv *Server) maxInboundConns() int { return srv.MaxPeers - srv.maxDialedConns() } + func (srv *Server) maxDialedConns() int { if srv.NoDiscovery || srv.NoDial { return 0 @@ -832,7 +846,7 @@ func (srv *Server) listenLoop() { } for { - // Wait for a handshake slot before accepting. + // Wait for a free slot before accepting. <-slots var ( @@ -851,21 +865,16 @@ func (srv *Server) listenLoop() { break } - // Reject connections that do not match NetRestrict. - if srv.NetRestrict != nil { - if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { - srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr()) - fd.Close() - slots <- struct{}{} - continue - } + remoteIP := netutil.AddrIP(fd.RemoteAddr()) + if err := srv.checkInboundConn(fd, remoteIP); err != nil { + srv.log.Debug("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err) + fd.Close() + slots <- struct{}{} + continue } - - var ip net.IP - if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok { - ip = tcp.IP + if remoteIP != nil { + fd = newMeteredConn(fd, true, remoteIP) } - fd = newMeteredConn(fd, true, ip) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) go func() { srv.SetupConn(fd, inboundConn, nil) @@ -874,6 +883,22 @@ func (srv *Server) listenLoop() { } } +func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error { + if remoteIP != nil { + // Reject connections that do not match NetRestrict. + if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) { + return fmt.Errorf("not whitelisted in NetRestrict") + } + // Reject Internet peers that try too often. + srv.inboundHistory.expire(time.Now()) + if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) { + return fmt.Errorf("too many attempts") + } + srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime)) + } + return nil +} + // SetupConn runs the handshakes and attempts to add the connection // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. @@ -895,6 +920,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro if !running { return errServerStopped } + // If dialing, figure out the remote public key. var dialPubkey *ecdsa.PublicKey if dialDest != nil { @@ -903,7 +929,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro return errors.New("dial destination doesn't have a secp256k1 public key") } } - // Run the encryption handshake. + + // Run the RLPx handshake. remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey) if err != nil { srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) @@ -922,12 +949,13 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro conn.handshakeDone(c.node.ID()) } clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags) - err = srv.checkpoint(c, srv.posthandshake) + err = srv.checkpoint(c, srv.checkpointPostHandshake) if err != nil { - clog.Trace("Rejected peer before protocol handshake", "err", err) + clog.Trace("Rejected peer", "err", err) return err } - // Run the protocol handshake + + // Run the capability negotiation handshake. phs, err := c.doProtoHandshake(srv.ourHandshake) if err != nil { clog.Trace("Failed proto handshake", "err", err) @@ -938,14 +966,15 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro return DiscUnexpectedIdentity } c.caps, c.name = phs.Caps, phs.Name - err = srv.checkpoint(c, srv.addpeer) + err = srv.checkpoint(c, srv.checkpointAddPeer) if err != nil { clog.Trace("Rejected peer", "err", err) return err } - // If the checks completed successfully, runPeer has now been - // launched by run. - clog.Trace("connection set up", "inbound", dialDest == nil) + + // If the checks completed successfully, the connection has been added as a peer and + // runPeer has been launched. + clog.Trace("Connection set up", "inbound", dialDest == nil) return nil } @@ -974,12 +1003,7 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error { case <-srv.quit: return errServerStopped } - select { - case err := <-c.cont: - return err - case <-srv.quit: - return errServerStopped - } + return <-c.cont } // runPeer runs in its own goroutine for each peer. diff --git a/p2p/server_test.go b/p2p/server_test.go index f665c1424..e8bc627e1 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -19,6 +19,7 @@ package p2p import ( "crypto/ecdsa" "errors" + "io" "math/rand" "net" "reflect" @@ -26,6 +27,7 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" @@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) * MaxPeers: 10, ListenAddr: "127.0.0.1:0", PrivateKey: newkey(), + Logger: testlog.Logger(t, log.LvlTrace), } server := &Server{ Config: config, @@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) { PrivateKey: newkey(), MaxPeers: 10, NoDial: true, + NoDiscovery: true, TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, }, } @@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) { // Inject a few connections to fill up the peer set. for i := 0; i < 10; i++ { c := newconn(randomID()) - if err := srv.checkpoint(c, srv.addpeer); err != nil { + if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil { t.Fatalf("could not add conn %d: %v", i, err) } } // Try inserting a non-trusted connection. anotherID := randomID() c := newconn(anotherID) - if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { t.Error("wrong error for insert:", err) } // Try inserting a trusted connection. c = newconn(trustedID) - if err := srv.checkpoint(c, srv.posthandshake); err != nil { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { t.Error("unexpected error for trusted conn @posthandshake:", err) } if !c.is(trustedConn) { @@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) { // Remove from trusted set and try again srv.RemoveTrustedPeer(newNode(trustedID, nil)) c = newconn(trustedID) - if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { t.Error("wrong error for insert:", err) } // Add anotherID to trusted set and try again srv.AddTrustedPeer(newNode(anotherID, nil)) c = newconn(anotherID) - if err := srv.checkpoint(c, srv.posthandshake); err != nil { + if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { t.Error("unexpected error for trusted conn @posthandshake:", err) } if !c.is(trustedConn) { @@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) { srv := &Server{ Config: Config{ - PrivateKey: srvkey, - MaxPeers: 0, - NoDial: true, - Protocols: []Protocol{discard}, + PrivateKey: srvkey, + MaxPeers: 0, + NoDial: true, + NoDiscovery: true, + Protocols: []Protocol{discard}, }, newTransport: func(fd net.Conn) transport { return tp }, log: log.New(), @@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) { } for i, test := range tests { - srv := &Server{ - Config: Config{ - PrivateKey: srvkey, - MaxPeers: 10, - NoDial: true, - Protocols: []Protocol{discard}, - }, - newTransport: func(fd net.Conn) transport { return test.tt }, - log: log.New(), - } - if !test.dontstart { - if err := srv.Start(); err != nil { - t.Fatalf("couldn't start server: %v", err) + t.Run(test.wantCalls, func(t *testing.T) { + cfg := Config{ + PrivateKey: srvkey, + MaxPeers: 10, + NoDial: true, + NoDiscovery: true, + Protocols: []Protocol{discard}, + Logger: testlog.Logger(t, log.LvlTrace), } - } - p1, _ := net.Pipe() - srv.SetupConn(p1, test.flags, test.dialDest) - if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { - t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) - } - if test.tt.calls != test.wantCalls { - t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) - } + srv := &Server{ + Config: cfg, + newTransport: func(fd net.Conn) transport { return test.tt }, + log: cfg.Logger, + } + if !test.dontstart { + if err := srv.Start(); err != nil { + t.Fatalf("couldn't start server: %v", err) + } + defer srv.Stop() + } + p1, _ := net.Pipe() + srv.SetupConn(p1, test.flags, test.dialDest) + if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { + t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) + } + if test.tt.calls != test.wantCalls { + t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) + } + }) } } @@ -616,3 +627,100 @@ func randomID() (id enode.ID) { } return id } + +// This test checks that inbound connections are throttled by IP. +func TestServerInboundThrottle(t *testing.T) { + const timeout = 5 * time.Second + newTransportCalled := make(chan struct{}) + srv := &Server{ + Config: Config{ + PrivateKey: newkey(), + ListenAddr: "127.0.0.1:0", + MaxPeers: 10, + NoDial: true, + NoDiscovery: true, + Protocols: []Protocol{discard}, + Logger: testlog.Logger(t, log.LvlTrace), + }, + newTransport: func(fd net.Conn) transport { + newTransportCalled <- struct{}{} + return newRLPX(fd) + }, + listenFunc: func(network, laddr string) (net.Listener, error) { + fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444} + return listenFakeAddr(network, laddr, fakeAddr) + }, + } + if err := srv.Start(); err != nil { + t.Fatal("can't start: ", err) + } + defer srv.Stop() + + // Dial the test server. + conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + select { + case <-newTransportCalled: + // OK + case <-time.After(timeout): + t.Error("newTransport not called") + } + conn.Close() + + // Dial again. This time the server should close the connection immediately. + connClosed := make(chan struct{}) + conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + defer conn.Close() + go func() { + conn.SetDeadline(time.Now().Add(timeout)) + buf := make([]byte, 10) + if n, err := conn.Read(buf); err != io.EOF || n != 0 { + t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n) + } + connClosed <- struct{}{} + }() + select { + case <-connClosed: + // OK + case <-newTransportCalled: + t.Error("newTransport called for second attempt") + case <-time.After(timeout): + t.Error("connection not closed within timeout") + } +} + +func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) { + l, err := net.Listen(network, laddr) + if err == nil { + l = &fakeAddrListener{l, remoteAddr} + } + return l, err +} + +// fakeAddrListener is a listener that creates connections with a mocked remote address. +type fakeAddrListener struct { + net.Listener + remoteAddr net.Addr +} + +type fakeAddrConn struct { + net.Conn + remoteAddr net.Addr +} + +func (l *fakeAddrListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return &fakeAddrConn{c, l.remoteAddr}, nil +} + +func (c *fakeAddrConn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/p2p/util.go b/p2p/util.go new file mode 100644 index 000000000..2a6edf5ce --- /dev/null +++ b/p2p/util.go @@ -0,0 +1,82 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package p2p + +import ( + "container/heap" + "time" +) + +// expHeap tracks strings and their expiry time. +type expHeap []expItem + +// expItem is an entry in addrHistory. +type expItem struct { + item string + exp time.Time +} + +// nextExpiry returns the next expiry time. +func (h *expHeap) nextExpiry() time.Time { + return (*h)[0].exp +} + +// add adds an item and sets its expiry time. +func (h *expHeap) add(item string, exp time.Time) { + heap.Push(h, expItem{item, exp}) +} + +// remove removes an item. +func (h *expHeap) remove(item string) bool { + for i, v := range *h { + if v.item == item { + heap.Remove(h, i) + return true + } + } + return false +} + +// contains checks whether an item is present. +func (h expHeap) contains(item string) bool { + for _, v := range h { + if v.item == item { + return true + } + } + return false +} + +// expire removes items with expiry time before 'now'. +func (h *expHeap) expire(now time.Time) { + for h.Len() > 0 && h.nextExpiry().Before(now) { + heap.Pop(h) + } +} + +// heap.Interface boilerplate +func (h expHeap) Len() int { return len(h) } +func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } +func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) } +func (h *expHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/p2p/util_test.go b/p2p/util_test.go new file mode 100644 index 000000000..c9f2648dc --- /dev/null +++ b/p2p/util_test.go @@ -0,0 +1,54 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package p2p + +import ( + "testing" + "time" +) + +func TestExpHeap(t *testing.T) { + var h expHeap + + var ( + basetime = time.Unix(4000, 0) + exptimeA = basetime.Add(2 * time.Second) + exptimeB = basetime.Add(3 * time.Second) + exptimeC = basetime.Add(4 * time.Second) + ) + h.add("a", exptimeA) + h.add("b", exptimeB) + h.add("c", exptimeC) + + if !h.nextExpiry().Equal(exptimeA) { + t.Fatal("wrong nextExpiry") + } + if !h.contains("a") || !h.contains("b") || !h.contains("c") { + t.Fatal("heap doesn't contain all live items") + } + + h.expire(exptimeA.Add(1)) + if !h.nextExpiry().Equal(exptimeB) { + t.Fatal("wrong nextExpiry") + } + if h.contains("a") { + t.Fatal("heap contains a even though it has already expired") + } + if !h.contains("b") || !h.contains("c") { + t.Fatal("heap doesn't contain all live items") + } +}