p2p: enforce connection retry limit on server side (#19684)

The dialer limits itself to one attempt every 30s. Apply the same limit
in Server and reject peers which try to connect too eagerly. The check
against the limit happens right after accepting the connection.

Further changes in this commit ensure we pass the Server logger
down to Peer instances, discovery and dialState. Unit test logging now
works in all Server tests.
This commit is contained in:
Felix Lange 2019-06-11 12:45:33 +02:00 committed by GitHub
parent c0a034ec89
commit c420dcb39c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 518 additions and 292 deletions

View File

@ -17,7 +17,6 @@
package p2p
import (
"container/heap"
"errors"
"fmt"
"net"
@ -29,9 +28,10 @@ import (
)
const (
// This is the amount of time spent waiting in between
// redialing a certain node.
dialHistoryExpiration = 30 * time.Second
// This is the amount of time spent waiting in between redialing a certain node. The
// limit is a bit higher than inboundThrottleTime to prevent failing dials in small
// private networks.
dialHistoryExpiration = inboundThrottleTime + 5*time.Second
// Discovery lookups are throttled and can only run
// once every few seconds.
@ -72,16 +72,16 @@ type dialstate struct {
ntab discoverTable
netrestrict *netutil.Netlist
self enode.ID
bootnodes []*enode.Node // default dials when there are no peers
log log.Logger
start time.Time // time when the dialer was first used
lookupRunning bool
dialing map[enode.ID]connFlag
lookupBuf []*enode.Node // current discovery lookup results
randomNodes []*enode.Node // filled from Table
static map[enode.ID]*dialTask
hist *dialHistory
start time.Time // time when the dialer was first used
bootnodes []*enode.Node // default dials when there are no peers
hist expHeap
}
type discoverTable interface {
@ -91,15 +91,6 @@ type discoverTable interface {
ReadRandomNodes([]*enode.Node) int
}
// the dial history remembers recent dials.
type dialHistory []pastDial
// pastDial is an entry in the dial history.
type pastDial struct {
id enode.ID
exp time.Time
}
type task interface {
Do(*Server)
}
@ -126,20 +117,23 @@ type waitExpireTask struct {
time.Duration
}
func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
self: self,
netrestrict: netrestrict,
netrestrict: cfg.NetRestrict,
log: cfg.Logger,
static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag),
bootnodes: make([]*enode.Node, len(bootnodes)),
bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
randomNodes: make([]*enode.Node, maxdyn/2),
hist: new(dialHistory),
}
copy(s.bootnodes, bootnodes)
for _, n := range static {
copy(s.bootnodes, cfg.BootstrapNodes)
if s.log == nil {
s.log = log.Root()
}
for _, n := range cfg.StaticNodes {
s.addStatic(n)
}
return s
@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) {
func (s *dialstate) removeStatic(n *enode.Node) {
// This removes a task so future attempts to connect will not be made.
delete(s.static, n.ID())
// This removes a previous dial timestamp so that application
// can force a server to reconnect with chosen peer immediately.
s.hist.remove(n.ID())
}
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
var newtasks []task
addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil {
log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
return false
}
s.dialing[n.ID()] = flag
@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
err := s.checkDial(t.dest, peers)
switch err {
case errNotWhitelisted, errSelf:
log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
delete(s.static, t.dest.ID())
case nil:
s.dialing[id] = t.flags
@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
// This should prevent cases where the dialer logic is not ticked
// because there are no pending events.
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
t := &waitExpireTask{s.hist.min().exp.Sub(now)}
t := &waitExpireTask{s.hist.nextExpiry().Sub(now)}
newtasks = append(newtasks, t)
}
return newtasks
@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted
case s.hist.contains(n.ID()):
case s.hist.contains(string(n.ID().Bytes())):
return errRecentlyDialed
}
return nil
@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) {
case *dialTask:
s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration))
s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration))
delete(s.dialing, t.dest.ID())
case *discoverTask:
s.lookupRunning = false
@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) {
}
err := t.dial(srv, t.dest)
if err != nil {
log.Trace("Dial error", "task", t, "err", err)
srv.log.Trace("Dial error", "task", t, "err", err)
// Try resolving the ID of static nodes if dialing failed.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
if t.resolve(srv) {
@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) {
// The backoff delay resets when the node is found.
func (t *dialTask) resolve(srv *Server) bool {
if srv.ntab == nil {
log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
return false
}
if t.resolveDelay == 0 {
@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool {
if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay
}
log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
return false
}
// The node was found.
t.resolveDelay = initialResolveDelay
t.dest = resolved
log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
return true
}
@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) {
func (t waitExpireTask) String() string {
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
}
// Use only these methods to access or modify dialHistory.
func (h dialHistory) min() pastDial {
return h[0]
}
func (h *dialHistory) add(id enode.ID, exp time.Time) {
heap.Push(h, pastDial{id, exp})
}
func (h *dialHistory) remove(id enode.ID) bool {
for i, v := range *h {
if v.id == id {
heap.Remove(h, i)
return true
}
}
return false
}
func (h dialHistory) contains(id enode.ID) bool {
for _, v := range h {
if v.id == id {
return true
}
}
return false
}
func (h *dialHistory) expire(now time.Time) {
for h.Len() > 0 && h.min().exp.Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h dialHistory) Len() int { return len(h) }
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *dialHistory) Push(x interface{}) {
*h = append(*h, x.(pastDial))
}
func (h *dialHistory) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

View File

@ -20,10 +20,13 @@ import (
"encoding/binary"
"net"
"reflect"
"strings"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
@ -67,10 +70,10 @@ func runDialTest(t *testing.T, test dialtest) {
new := test.init.newTasks(running, pm(round.peers), vtime)
if !sametasks(new, round.new) {
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
}
t.Log("tasks:", spew.Sdump(new))
t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new)))
// Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second)
@ -88,8 +91,9 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
// This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) {
config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
init: newDialState(enode.ID{}, fakeTable{}, 5, config),
rounds: []round{
// A discovery query is launched.
{
@ -153,7 +157,7 @@ func TestDialStateDynDial(t *testing.T) {
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&waitExpireTask{Duration: 14 * time.Second},
&waitExpireTask{Duration: 19 * time.Second},
},
},
// In this round, the peer with id 2 drops off. The query
@ -223,10 +227,13 @@ func TestDialStateDynDial(t *testing.T) {
// Tests that bootnodes are dialed if no peers are connectd, but not otherwise.
func TestDialStateDynDialBootnode(t *testing.T) {
bootnodes := []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
newNode(uintID(3), nil),
config := &Config{
BootstrapNodes: []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
newNode(uintID(3), nil),
},
Logger: testlog.Logger(t, log.LvlTrace),
}
table := fakeTable{
newNode(uintID(4), nil),
@ -236,7 +243,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
newNode(uintID(8), nil),
}
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
init: newDialState(enode.ID{}, table, 5, config),
rounds: []round{
// 2 dynamic dials attempted, bootnodes pending fallback interval
{
@ -259,25 +266,24 @@ func TestDialStateDynDialBootnode(t *testing.T) {
{
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, 2nd bootnode is attempted
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
// No dials succeed, 3rd bootnode is attempted
{
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
@ -288,21 +294,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{},
},
// Random dial succeeds, no more bootnodes are attempted
{
new: []task{
&waitExpireTask{3 * time.Second},
},
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
},
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
},
},
@ -324,7 +328,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
}
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
@ -430,7 +434,7 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
rounds: []round{
{
new: []task{
@ -444,16 +448,18 @@ func TestDialStateNetRestrict(t *testing.T) {
// This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) {
wantStatic := []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
newNode(uintID(3), nil),
newNode(uintID(4), nil),
newNode(uintID(5), nil),
config := &Config{
StaticNodes: []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
newNode(uintID(3), nil),
newNode(uintID(4), nil),
newNode(uintID(5), nil),
},
Logger: testlog.Logger(t, log.LvlTrace),
}
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
init: newDialState(enode.ID{}, fakeTable{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@ -495,7 +501,7 @@ func TestDialStateStaticDial(t *testing.T) {
&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
},
new: []task{
&waitExpireTask{Duration: 14 * time.Second},
&waitExpireTask{Duration: 19 * time.Second},
},
},
// Wait a round for dial history to expire, no new tasks should spawn.
@ -511,6 +517,9 @@ func TestDialStateStaticDial(t *testing.T) {
// If a static node is dropped, it should be immediately redialed,
// irrespective whether it was originally static or dynamic.
{
done: []task{
&waitExpireTask{Duration: 19 * time.Second},
},
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
@ -518,67 +527,24 @@ func TestDialStateStaticDial(t *testing.T) {
},
new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
},
},
},
})
}
// This test checks that static peers will be redialed immediately if they were re-added to a static list.
func TestDialStaticAfterReset(t *testing.T) {
wantStatic := []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
}
rounds := []round{
// Static dials are launched for the nodes that aren't yet connected.
{
peers: nil,
new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
},
},
// No new dial tasks, all peers are connected.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
},
new: []task{
&waitExpireTask{Duration: 30 * time.Second},
},
},
}
dTest := dialtest{
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: rounds,
}
runDialTest(t, dTest)
for _, n := range wantStatic {
dTest.init.removeStatic(n)
dTest.init.addStatic(n)
}
// without removing peers they will be considered recently dialed
runDialTest(t, dTest)
}
// This test checks that past dials are not retried for some time.
func TestDialStateCache(t *testing.T) {
wantStatic := []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
newNode(uintID(3), nil),
config := &Config{
StaticNodes: []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
newNode(uintID(3), nil),
},
Logger: testlog.Logger(t, log.LvlTrace),
}
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
init: newDialState(enode.ID{}, fakeTable{}, 0, config),
rounds: []round{
// Static dials are launched for the nodes that
// aren't yet connected.
@ -606,28 +572,37 @@ func TestDialStateCache(t *testing.T) {
// entry to expire.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
},
new: []task{
&waitExpireTask{Duration: 14 * time.Second},
&waitExpireTask{Duration: 19 * time.Second},
},
},
// Still waiting for node 3's entry to expire in the cache.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
},
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
},
// The cache entry for node 3 has expired and is retried.
{
done: []task{
&waitExpireTask{Duration: 19 * time.Second},
},
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
@ -638,9 +613,13 @@ func TestDialStateCache(t *testing.T) {
}
func TestDialResolve(t *testing.T) {
config := &Config{
Logger: testlog.Logger(t, log.LvlTrace),
Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
}
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved}
state := newDialState(enode.ID{}, nil, nil, table, 0, nil)
state := newDialState(enode.ID{}, table, 0, config)
// Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil)
@ -651,8 +630,7 @@ func TestDialResolve(t *testing.T) {
}
// Now run the task, it should resolve the ID once.
config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}}
srv := &Server{ntab: table, Config: config}
srv := &Server{ntab: table, log: config.Logger, Config: *config}
tasks[0].Do(srv)
if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)

33
p2p/netutil/addrutil.go Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package netutil
import "net"
// AddrIP gets the IP address contained in addr. It returns nil if no address is present.
func AddrIP(addr net.Addr) net.IP {
switch a := addr.(type) {
case *net.IPAddr:
return a.IP
case *net.TCPAddr:
return a.IP
case *net.UDPAddr:
return a.IP
default:
return nil
}
}

View File

@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe()
node := enode.SignNull(new(enr.Record), id)
conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
peer := newPeer(conn, nil)
peer := newPeer(log.Root(), conn, nil)
close(peer.closed) // ensures Disconnect doesn't block
return peer
}
@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool {
return p.rw.is(inboundConn)
}
func newPeer(conn *conn, protocols []Protocol) *Peer {
func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer {
protomap := matchProtocols(protocols, conn.caps, conn)
p := &Peer{
rw: conn,

View File

@ -24,6 +24,8 @@ import (
"reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/log"
)
var discard = Protocol{
@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
c2.caps = append(c2.caps, p.cap())
}
peer := newPeer(c1, protos)
peer := newPeer(log.Root(), c1, protos)
errc := make(chan error, 1)
go func() {
_, err := peer.run()

View File

@ -22,6 +22,7 @@ import (
"crypto/ecdsa"
"encoding/hex"
"errors"
"fmt"
"net"
"sort"
"sync"
@ -49,6 +50,9 @@ const (
defaultMaxPendingPeers = 50
defaultDialRatio = 3
// This time limits inbound connection attempts per source IP.
inboundThrottleTime = 30 * time.Second
// Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle.
frameReadTimeout = 30 * time.Second
@ -158,6 +162,7 @@ type Server struct {
// the whole protocol stack.
newTransport func(net.Conn) transport
newPeerHook func(*Peer)
listenFunc func(network, addr string) (net.Listener, error)
lock sync.Mutex // protects running
running bool
@ -167,24 +172,26 @@ type Server struct {
ntab discoverTable
listener net.Listener
ourHandshake *protoHandshake
lastLookup time.Time
DiscV5 *discv5.Network
loopWG sync.WaitGroup // loop, listenLoop
peerFeed event.Feed
log log.Logger
// These are for Peers, PeerCount (and nothing else).
peerOp chan peerOpFunc
peerOpDone chan struct{}
// Channels into the run loop.
quit chan struct{}
addstatic chan *enode.Node
removestatic chan *enode.Node
addtrusted chan *enode.Node
removetrusted chan *enode.Node
peerOp chan peerOpFunc
peerOpDone chan struct{}
delpeer chan peerDrop
checkpointPostHandshake chan *conn
checkpointAddPeer chan *conn
quit chan struct{}
addstatic chan *enode.Node
removestatic chan *enode.Node
addtrusted chan *enode.Node
removetrusted chan *enode.Node
posthandshake chan *conn
addpeer chan *conn
delpeer chan peerDrop
loopWG sync.WaitGroup // loop, listenLoop
peerFeed event.Feed
log log.Logger
// State of run loop and listenLoop.
lastLookup time.Time
inboundHistory expHeap
}
type peerOpFunc func(map[enode.ID]*Peer)
@ -415,7 +422,7 @@ func (srv *Server) Start() (err error) {
srv.running = true
srv.log = srv.Config.Logger
if srv.log == nil {
srv.log = log.New()
srv.log = log.Root()
}
if srv.NoDial && srv.ListenAddr == "" {
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
@ -428,13 +435,16 @@ func (srv *Server) Start() (err error) {
if srv.newTransport == nil {
srv.newTransport = newRLPX
}
if srv.listenFunc == nil {
srv.listenFunc = net.Listen
}
if srv.Dialer == nil {
srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
}
srv.quit = make(chan struct{})
srv.addpeer = make(chan *conn)
srv.delpeer = make(chan peerDrop)
srv.posthandshake = make(chan *conn)
srv.checkpointPostHandshake = make(chan *conn)
srv.checkpointAddPeer = make(chan *conn)
srv.addstatic = make(chan *enode.Node)
srv.removestatic = make(chan *enode.Node)
srv.addtrusted = make(chan *enode.Node)
@ -455,7 +465,7 @@ func (srv *Server) Start() (err error) {
}
dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config)
srv.loopWG.Add(1)
go srv.run(dialer)
return nil
@ -541,6 +551,7 @@ func (srv *Server) setupDiscovery() error {
NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled,
Log: srv.log,
}
ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
if err != nil {
@ -569,27 +580,28 @@ func (srv *Server) setupDiscovery() error {
}
func (srv *Server) setupListening() error {
// Launch the TCP listener.
listener, err := net.Listen("tcp", srv.ListenAddr)
// Launch the listener.
listener, err := srv.listenFunc("tcp", srv.ListenAddr)
if err != nil {
return err
}
laddr := listener.Addr().(*net.TCPAddr)
srv.ListenAddr = laddr.String()
srv.listener = listener
srv.localnode.Set(enr.TCP(laddr.Port))
srv.ListenAddr = listener.Addr().String()
// Update the local node record and map the TCP listening port if NAT is configured.
if tcp, ok := listener.Addr().(*net.TCPAddr); ok {
srv.localnode.Set(enr.TCP(tcp.Port))
if !tcp.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1)
go func() {
nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p")
srv.loopWG.Done()
}()
}
}
srv.loopWG.Add(1)
go srv.listenLoop()
// Map the TCP listening port if NAT is configured.
if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1)
go func() {
nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
srv.loopWG.Done()
}()
}
return nil
}
@ -657,12 +669,14 @@ running:
case <-srv.quit:
// The server was stopped. Run the cleanup logic.
break running
case n := <-srv.addstatic:
// This channel is used by AddPeer to add to the
// ephemeral static peer list. Add it to the dialer,
// it will keep the node connected.
srv.log.Trace("Adding static node", "node", n)
dialstate.addStatic(n)
case n := <-srv.removestatic:
// This channel is used by RemovePeer to send a
// disconnect request to a peer and begin the
@ -672,6 +686,7 @@ running:
if p, ok := peers[n.ID()]; ok {
p.Disconnect(DiscRequested)
}
case n := <-srv.addtrusted:
// This channel is used by AddTrustedPeer to add an enode
// to the trusted node set.
@ -681,6 +696,7 @@ running:
if p, ok := peers[n.ID()]; ok {
p.rw.set(trustedConn, true)
}
case n := <-srv.removetrusted:
// This channel is used by RemoveTrustedPeer to remove an enode
// from the trusted node set.
@ -691,10 +707,12 @@ running:
if p, ok := peers[n.ID()]; ok {
p.rw.set(trustedConn, false)
}
case op := <-srv.peerOp:
// This channel is used by Peers and PeerCount.
op(peers)
srv.peerOpDone <- struct{}{}
case t := <-taskdone:
// A task got done. Tell dialstate about it so it
// can update its state and remove it from the active
@ -702,7 +720,8 @@ running:
srv.log.Trace("Dial task done", "task", t)
dialstate.taskDone(t, time.Now())
delTask(t)
case c := <-srv.posthandshake:
case c := <-srv.checkpointPostHandshake:
// A connection has passed the encryption handshake so
// the remote identity is known (but hasn't been verified yet).
if trusted[c.node.ID()] {
@ -710,18 +729,15 @@ running:
c.flags |= trustedConn
}
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
select {
case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
case <-srv.quit:
break running
}
case c := <-srv.addpeer:
c.cont <- srv.postHandshakeChecks(peers, inboundCount, c)
case c := <-srv.checkpointAddPeer:
// At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified.
err := srv.protoHandshakeChecks(peers, inboundCount, c)
err := srv.addPeerChecks(peers, inboundCount, c)
if err == nil {
// The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols)
p := newPeer(srv.log, c, srv.Protocols)
// If message events are enabled, pass the peerFeed
// to the peer
if srv.EnableMsgEvents {
@ -738,11 +754,8 @@ running:
// The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or
// discarded. Unblock the task last.
select {
case c.cont <- err:
case <-srv.quit:
break running
}
c.cont <- err
case pd := <-srv.delpeer:
// A peer disconnected.
d := common.PrettyDuration(mclock.Now() - pd.created)
@ -777,17 +790,7 @@ running:
}
}
func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
}
// Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes.
return srv.encHandshakeChecks(peers, inboundCount, c)
}
func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers
@ -802,9 +805,20 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
}
}
func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
}
// Repeat the post-handshake checks because the
// peer set might have changed since those checks were performed.
return srv.postHandshakeChecks(peers, inboundCount, c)
}
func (srv *Server) maxInboundConns() int {
return srv.MaxPeers - srv.maxDialedConns()
}
func (srv *Server) maxDialedConns() int {
if srv.NoDiscovery || srv.NoDial {
return 0
@ -832,7 +846,7 @@ func (srv *Server) listenLoop() {
}
for {
// Wait for a handshake slot before accepting.
// Wait for a free slot before accepting.
<-slots
var (
@ -851,21 +865,16 @@ func (srv *Server) listenLoop() {
break
}
// Reject connections that do not match NetRestrict.
if srv.NetRestrict != nil {
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr())
fd.Close()
slots <- struct{}{}
continue
}
remoteIP := netutil.AddrIP(fd.RemoteAddr())
if err := srv.checkInboundConn(fd, remoteIP); err != nil {
srv.log.Debug("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err)
fd.Close()
slots <- struct{}{}
continue
}
var ip net.IP
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok {
ip = tcp.IP
if remoteIP != nil {
fd = newMeteredConn(fd, true, remoteIP)
}
fd = newMeteredConn(fd, true, ip)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
go func() {
srv.SetupConn(fd, inboundConn, nil)
@ -874,6 +883,22 @@ func (srv *Server) listenLoop() {
}
}
func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error {
if remoteIP != nil {
// Reject connections that do not match NetRestrict.
if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
return fmt.Errorf("not whitelisted in NetRestrict")
}
// Reject Internet peers that try too often.
srv.inboundHistory.expire(time.Now())
if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
return fmt.Errorf("too many attempts")
}
srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime))
}
return nil
}
// SetupConn runs the handshakes and attempts to add the connection
// as a peer. It returns when the connection has been added as a peer
// or the handshakes have failed.
@ -895,6 +920,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
if !running {
return errServerStopped
}
// If dialing, figure out the remote public key.
var dialPubkey *ecdsa.PublicKey
if dialDest != nil {
@ -903,7 +929,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
return errors.New("dial destination doesn't have a secp256k1 public key")
}
}
// Run the encryption handshake.
// Run the RLPx handshake.
remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
if err != nil {
srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
@ -922,12 +949,13 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
conn.handshakeDone(c.node.ID())
}
clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
err = srv.checkpoint(c, srv.posthandshake)
err = srv.checkpoint(c, srv.checkpointPostHandshake)
if err != nil {
clog.Trace("Rejected peer before protocol handshake", "err", err)
clog.Trace("Rejected peer", "err", err)
return err
}
// Run the protocol handshake
// Run the capability negotiation handshake.
phs, err := c.doProtoHandshake(srv.ourHandshake)
if err != nil {
clog.Trace("Failed proto handshake", "err", err)
@ -938,14 +966,15 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
return DiscUnexpectedIdentity
}
c.caps, c.name = phs.Caps, phs.Name
err = srv.checkpoint(c, srv.addpeer)
err = srv.checkpoint(c, srv.checkpointAddPeer)
if err != nil {
clog.Trace("Rejected peer", "err", err)
return err
}
// If the checks completed successfully, runPeer has now been
// launched by run.
clog.Trace("connection set up", "inbound", dialDest == nil)
// If the checks completed successfully, the connection has been added as a peer and
// runPeer has been launched.
clog.Trace("Connection set up", "inbound", dialDest == nil)
return nil
}
@ -974,12 +1003,7 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
case <-srv.quit:
return errServerStopped
}
select {
case err := <-c.cont:
return err
case <-srv.quit:
return errServerStopped
}
return <-c.cont
}
// runPeer runs in its own goroutine for each peer.

View File

@ -19,6 +19,7 @@ package p2p
import (
"crypto/ecdsa"
"errors"
"io"
"math/rand"
"net"
"reflect"
@ -26,6 +27,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
MaxPeers: 10,
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(),
Logger: testlog.Logger(t, log.LvlTrace),
}
server := &Server{
Config: config,
@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) {
PrivateKey: newkey(),
MaxPeers: 10,
NoDial: true,
NoDiscovery: true,
TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
},
}
@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) {
// Inject a few connections to fill up the peer set.
for i := 0; i < 10; i++ {
c := newconn(randomID())
if err := srv.checkpoint(c, srv.addpeer); err != nil {
if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
t.Fatalf("could not add conn %d: %v", i, err)
}
}
// Try inserting a non-trusted connection.
anotherID := randomID()
c := newconn(anotherID)
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err)
}
// Try inserting a trusted connection.
c = newconn(trustedID)
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err)
}
if !c.is(trustedConn) {
@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) {
// Remove from trusted set and try again
srv.RemoveTrustedPeer(newNode(trustedID, nil))
c = newconn(trustedID)
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err)
}
// Add anotherID to trusted set and try again
srv.AddTrustedPeer(newNode(anotherID, nil))
c = newconn(anotherID)
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err)
}
if !c.is(trustedConn) {
@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) {
srv := &Server{
Config: Config{
PrivateKey: srvkey,
MaxPeers: 0,
NoDial: true,
Protocols: []Protocol{discard},
PrivateKey: srvkey,
MaxPeers: 0,
NoDial: true,
NoDiscovery: true,
Protocols: []Protocol{discard},
},
newTransport: func(fd net.Conn) transport { return tp },
log: log.New(),
@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) {
}
for i, test := range tests {
srv := &Server{
Config: Config{
PrivateKey: srvkey,
MaxPeers: 10,
NoDial: true,
Protocols: []Protocol{discard},
},
newTransport: func(fd net.Conn) transport { return test.tt },
log: log.New(),
}
if !test.dontstart {
if err := srv.Start(); err != nil {
t.Fatalf("couldn't start server: %v", err)
t.Run(test.wantCalls, func(t *testing.T) {
cfg := Config{
PrivateKey: srvkey,
MaxPeers: 10,
NoDial: true,
NoDiscovery: true,
Protocols: []Protocol{discard},
Logger: testlog.Logger(t, log.LvlTrace),
}
}
p1, _ := net.Pipe()
srv.SetupConn(p1, test.flags, test.dialDest)
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
}
if test.tt.calls != test.wantCalls {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
}
srv := &Server{
Config: cfg,
newTransport: func(fd net.Conn) transport { return test.tt },
log: cfg.Logger,
}
if !test.dontstart {
if err := srv.Start(); err != nil {
t.Fatalf("couldn't start server: %v", err)
}
defer srv.Stop()
}
p1, _ := net.Pipe()
srv.SetupConn(p1, test.flags, test.dialDest)
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
}
if test.tt.calls != test.wantCalls {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
}
})
}
}
@ -616,3 +627,100 @@ func randomID() (id enode.ID) {
}
return id
}
// This test checks that inbound connections are throttled by IP.
func TestServerInboundThrottle(t *testing.T) {
const timeout = 5 * time.Second
newTransportCalled := make(chan struct{})
srv := &Server{
Config: Config{
PrivateKey: newkey(),
ListenAddr: "127.0.0.1:0",
MaxPeers: 10,
NoDial: true,
NoDiscovery: true,
Protocols: []Protocol{discard},
Logger: testlog.Logger(t, log.LvlTrace),
},
newTransport: func(fd net.Conn) transport {
newTransportCalled <- struct{}{}
return newRLPX(fd)
},
listenFunc: func(network, laddr string) (net.Listener, error) {
fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
return listenFakeAddr(network, laddr, fakeAddr)
},
}
if err := srv.Start(); err != nil {
t.Fatal("can't start: ", err)
}
defer srv.Stop()
// Dial the test server.
conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
select {
case <-newTransportCalled:
// OK
case <-time.After(timeout):
t.Error("newTransport not called")
}
conn.Close()
// Dial again. This time the server should close the connection immediately.
connClosed := make(chan struct{})
conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
defer conn.Close()
go func() {
conn.SetDeadline(time.Now().Add(timeout))
buf := make([]byte, 10)
if n, err := conn.Read(buf); err != io.EOF || n != 0 {
t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
}
connClosed <- struct{}{}
}()
select {
case <-connClosed:
// OK
case <-newTransportCalled:
t.Error("newTransport called for second attempt")
case <-time.After(timeout):
t.Error("connection not closed within timeout")
}
}
func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
l, err := net.Listen(network, laddr)
if err == nil {
l = &fakeAddrListener{l, remoteAddr}
}
return l, err
}
// fakeAddrListener is a listener that creates connections with a mocked remote address.
type fakeAddrListener struct {
net.Listener
remoteAddr net.Addr
}
type fakeAddrConn struct {
net.Conn
remoteAddr net.Addr
}
func (l *fakeAddrListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &fakeAddrConn{c, l.remoteAddr}, nil
}
func (c *fakeAddrConn) RemoteAddr() net.Addr {
return c.remoteAddr
}

82
p2p/util.go Normal file
View File

@ -0,0 +1,82 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package p2p
import (
"container/heap"
"time"
)
// expHeap tracks strings and their expiry time.
type expHeap []expItem
// expItem is an entry in addrHistory.
type expItem struct {
item string
exp time.Time
}
// nextExpiry returns the next expiry time.
func (h *expHeap) nextExpiry() time.Time {
return (*h)[0].exp
}
// add adds an item and sets its expiry time.
func (h *expHeap) add(item string, exp time.Time) {
heap.Push(h, expItem{item, exp})
}
// remove removes an item.
func (h *expHeap) remove(item string) bool {
for i, v := range *h {
if v.item == item {
heap.Remove(h, i)
return true
}
}
return false
}
// contains checks whether an item is present.
func (h expHeap) contains(item string) bool {
for _, v := range h {
if v.item == item {
return true
}
}
return false
}
// expire removes items with expiry time before 'now'.
func (h *expHeap) expire(now time.Time) {
for h.Len() > 0 && h.nextExpiry().Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h expHeap) Len() int { return len(h) }
func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) }
func (h *expHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

54
p2p/util_test.go Normal file
View File

@ -0,0 +1,54 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package p2p
import (
"testing"
"time"
)
func TestExpHeap(t *testing.T) {
var h expHeap
var (
basetime = time.Unix(4000, 0)
exptimeA = basetime.Add(2 * time.Second)
exptimeB = basetime.Add(3 * time.Second)
exptimeC = basetime.Add(4 * time.Second)
)
h.add("a", exptimeA)
h.add("b", exptimeB)
h.add("c", exptimeC)
if !h.nextExpiry().Equal(exptimeA) {
t.Fatal("wrong nextExpiry")
}
if !h.contains("a") || !h.contains("b") || !h.contains("c") {
t.Fatal("heap doesn't contain all live items")
}
h.expire(exptimeA.Add(1))
if !h.nextExpiry().Equal(exptimeB) {
t.Fatal("wrong nextExpiry")
}
if h.contains("a") {
t.Fatal("heap contains a even though it has already expired")
}
if !h.contains("b") || !h.contains("c") {
t.Fatal("heap doesn't contain all live items")
}
}