p2p: enforce connection retry limit on server side (#19684)
The dialer limits itself to one attempt every 30s. Apply the same limit in Server and reject peers which try to connect too eagerly. The check against the limit happens right after accepting the connection. Further changes in this commit ensure we pass the Server logger down to Peer instances, discovery and dialState. Unit test logging now works in all Server tests.
This commit is contained in:
parent
c0a034ec89
commit
c420dcb39c
107
p2p/dial.go
107
p2p/dial.go
@ -17,7 +17,6 @@
|
||||
package p2p
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -29,9 +28,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// This is the amount of time spent waiting in between
|
||||
// redialing a certain node.
|
||||
dialHistoryExpiration = 30 * time.Second
|
||||
// This is the amount of time spent waiting in between redialing a certain node. The
|
||||
// limit is a bit higher than inboundThrottleTime to prevent failing dials in small
|
||||
// private networks.
|
||||
dialHistoryExpiration = inboundThrottleTime + 5*time.Second
|
||||
|
||||
// Discovery lookups are throttled and can only run
|
||||
// once every few seconds.
|
||||
@ -72,16 +72,16 @@ type dialstate struct {
|
||||
ntab discoverTable
|
||||
netrestrict *netutil.Netlist
|
||||
self enode.ID
|
||||
bootnodes []*enode.Node // default dials when there are no peers
|
||||
log log.Logger
|
||||
|
||||
start time.Time // time when the dialer was first used
|
||||
lookupRunning bool
|
||||
dialing map[enode.ID]connFlag
|
||||
lookupBuf []*enode.Node // current discovery lookup results
|
||||
randomNodes []*enode.Node // filled from Table
|
||||
static map[enode.ID]*dialTask
|
||||
hist *dialHistory
|
||||
|
||||
start time.Time // time when the dialer was first used
|
||||
bootnodes []*enode.Node // default dials when there are no peers
|
||||
hist expHeap
|
||||
}
|
||||
|
||||
type discoverTable interface {
|
||||
@ -91,15 +91,6 @@ type discoverTable interface {
|
||||
ReadRandomNodes([]*enode.Node) int
|
||||
}
|
||||
|
||||
// the dial history remembers recent dials.
|
||||
type dialHistory []pastDial
|
||||
|
||||
// pastDial is an entry in the dial history.
|
||||
type pastDial struct {
|
||||
id enode.ID
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
type task interface {
|
||||
Do(*Server)
|
||||
}
|
||||
@ -126,20 +117,23 @@ type waitExpireTask struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
|
||||
func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
|
||||
s := &dialstate{
|
||||
maxDynDials: maxdyn,
|
||||
ntab: ntab,
|
||||
self: self,
|
||||
netrestrict: netrestrict,
|
||||
netrestrict: cfg.NetRestrict,
|
||||
log: cfg.Logger,
|
||||
static: make(map[enode.ID]*dialTask),
|
||||
dialing: make(map[enode.ID]connFlag),
|
||||
bootnodes: make([]*enode.Node, len(bootnodes)),
|
||||
bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
|
||||
randomNodes: make([]*enode.Node, maxdyn/2),
|
||||
hist: new(dialHistory),
|
||||
}
|
||||
copy(s.bootnodes, bootnodes)
|
||||
for _, n := range static {
|
||||
copy(s.bootnodes, cfg.BootstrapNodes)
|
||||
if s.log == nil {
|
||||
s.log = log.Root()
|
||||
}
|
||||
for _, n := range cfg.StaticNodes {
|
||||
s.addStatic(n)
|
||||
}
|
||||
return s
|
||||
@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) {
|
||||
func (s *dialstate) removeStatic(n *enode.Node) {
|
||||
// This removes a task so future attempts to connect will not be made.
|
||||
delete(s.static, n.ID())
|
||||
// This removes a previous dial timestamp so that application
|
||||
// can force a server to reconnect with chosen peer immediately.
|
||||
s.hist.remove(n.ID())
|
||||
}
|
||||
|
||||
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
|
||||
@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
|
||||
var newtasks []task
|
||||
addDial := func(flag connFlag, n *enode.Node) bool {
|
||||
if err := s.checkDial(n, peers); err != nil {
|
||||
log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
|
||||
s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
|
||||
return false
|
||||
}
|
||||
s.dialing[n.ID()] = flag
|
||||
@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
|
||||
err := s.checkDial(t.dest, peers)
|
||||
switch err {
|
||||
case errNotWhitelisted, errSelf:
|
||||
log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
|
||||
s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
|
||||
delete(s.static, t.dest.ID())
|
||||
case nil:
|
||||
s.dialing[id] = t.flags
|
||||
@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
|
||||
// This should prevent cases where the dialer logic is not ticked
|
||||
// because there are no pending events.
|
||||
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
|
||||
t := &waitExpireTask{s.hist.min().exp.Sub(now)}
|
||||
t := &waitExpireTask{s.hist.nextExpiry().Sub(now)}
|
||||
newtasks = append(newtasks, t)
|
||||
}
|
||||
return newtasks
|
||||
@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
|
||||
return errSelf
|
||||
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
|
||||
return errNotWhitelisted
|
||||
case s.hist.contains(n.ID()):
|
||||
case s.hist.contains(string(n.ID().Bytes())):
|
||||
return errRecentlyDialed
|
||||
}
|
||||
return nil
|
||||
@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
|
||||
func (s *dialstate) taskDone(t task, now time.Time) {
|
||||
switch t := t.(type) {
|
||||
case *dialTask:
|
||||
s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration))
|
||||
s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration))
|
||||
delete(s.dialing, t.dest.ID())
|
||||
case *discoverTask:
|
||||
s.lookupRunning = false
|
||||
@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) {
|
||||
}
|
||||
err := t.dial(srv, t.dest)
|
||||
if err != nil {
|
||||
log.Trace("Dial error", "task", t, "err", err)
|
||||
srv.log.Trace("Dial error", "task", t, "err", err)
|
||||
// Try resolving the ID of static nodes if dialing failed.
|
||||
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
|
||||
if t.resolve(srv) {
|
||||
@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) {
|
||||
// The backoff delay resets when the node is found.
|
||||
func (t *dialTask) resolve(srv *Server) bool {
|
||||
if srv.ntab == nil {
|
||||
log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
|
||||
srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
|
||||
return false
|
||||
}
|
||||
if t.resolveDelay == 0 {
|
||||
@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool {
|
||||
if t.resolveDelay > maxResolveDelay {
|
||||
t.resolveDelay = maxResolveDelay
|
||||
}
|
||||
log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
|
||||
srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
|
||||
return false
|
||||
}
|
||||
// The node was found.
|
||||
t.resolveDelay = initialResolveDelay
|
||||
t.dest = resolved
|
||||
log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
|
||||
srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
|
||||
return true
|
||||
}
|
||||
|
||||
@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) {
|
||||
func (t waitExpireTask) String() string {
|
||||
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
|
||||
}
|
||||
|
||||
// Use only these methods to access or modify dialHistory.
|
||||
func (h dialHistory) min() pastDial {
|
||||
return h[0]
|
||||
}
|
||||
func (h *dialHistory) add(id enode.ID, exp time.Time) {
|
||||
heap.Push(h, pastDial{id, exp})
|
||||
|
||||
}
|
||||
func (h *dialHistory) remove(id enode.ID) bool {
|
||||
for i, v := range *h {
|
||||
if v.id == id {
|
||||
heap.Remove(h, i)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (h dialHistory) contains(id enode.ID) bool {
|
||||
for _, v := range h {
|
||||
if v.id == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (h *dialHistory) expire(now time.Time) {
|
||||
for h.Len() > 0 && h.min().exp.Before(now) {
|
||||
heap.Pop(h)
|
||||
}
|
||||
}
|
||||
|
||||
// heap.Interface boilerplate
|
||||
func (h dialHistory) Len() int { return len(h) }
|
||||
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
|
||||
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
func (h *dialHistory) Push(x interface{}) {
|
||||
*h = append(*h, x.(pastDial))
|
||||
}
|
||||
func (h *dialHistory) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
138
p2p/dial_test.go
138
p2p/dial_test.go
@ -20,10 +20,13 @@ import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/ethereum/go-ethereum/internal/testlog"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
"github.com/ethereum/go-ethereum/p2p/netutil"
|
||||
@ -67,10 +70,10 @@ func runDialTest(t *testing.T, test dialtest) {
|
||||
|
||||
new := test.init.newTasks(running, pm(round.peers), vtime)
|
||||
if !sametasks(new, round.new) {
|
||||
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
|
||||
t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
|
||||
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
|
||||
}
|
||||
t.Log("tasks:", spew.Sdump(new))
|
||||
t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new)))
|
||||
|
||||
// Time advances by 16 seconds on every round.
|
||||
vtime = vtime.Add(16 * time.Second)
|
||||
@ -88,8 +91,9 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
|
||||
|
||||
// This test checks that dynamic dials are launched from discovery results.
|
||||
func TestDialStateDynDial(t *testing.T) {
|
||||
config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
|
||||
init: newDialState(enode.ID{}, fakeTable{}, 5, config),
|
||||
rounds: []round{
|
||||
// A discovery query is launched.
|
||||
{
|
||||
@ -153,7 +157,7 @@ func TestDialStateDynDial(t *testing.T) {
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
new: []task{
|
||||
&waitExpireTask{Duration: 14 * time.Second},
|
||||
&waitExpireTask{Duration: 19 * time.Second},
|
||||
},
|
||||
},
|
||||
// In this round, the peer with id 2 drops off. The query
|
||||
@ -223,10 +227,13 @@ func TestDialStateDynDial(t *testing.T) {
|
||||
|
||||
// Tests that bootnodes are dialed if no peers are connectd, but not otherwise.
|
||||
func TestDialStateDynDialBootnode(t *testing.T) {
|
||||
bootnodes := []*enode.Node{
|
||||
config := &Config{
|
||||
BootstrapNodes: []*enode.Node{
|
||||
newNode(uintID(1), nil),
|
||||
newNode(uintID(2), nil),
|
||||
newNode(uintID(3), nil),
|
||||
},
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
}
|
||||
table := fakeTable{
|
||||
newNode(uintID(4), nil),
|
||||
@ -236,7 +243,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
|
||||
newNode(uintID(8), nil),
|
||||
}
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
|
||||
init: newDialState(enode.ID{}, table, 5, config),
|
||||
rounds: []round{
|
||||
// 2 dynamic dials attempted, bootnodes pending fallback interval
|
||||
{
|
||||
@ -259,25 +266,24 @@ func TestDialStateDynDialBootnode(t *testing.T) {
|
||||
{
|
||||
new: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
},
|
||||
// No dials succeed, 2nd bootnode is attempted
|
||||
{
|
||||
done: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
},
|
||||
// No dials succeed, 3rd bootnode is attempted
|
||||
{
|
||||
done: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
|
||||
@ -288,21 +294,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
|
||||
done: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
new: []task{},
|
||||
},
|
||||
// Random dial succeeds, no more bootnodes are attempted
|
||||
{
|
||||
new: []task{
|
||||
&waitExpireTask{3 * time.Second},
|
||||
},
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -324,7 +328,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
|
||||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
|
||||
init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
|
||||
rounds: []round{
|
||||
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
|
||||
{
|
||||
@ -430,7 +434,7 @@ func TestDialStateNetRestrict(t *testing.T) {
|
||||
restrict.Add("127.0.2.0/24")
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
|
||||
init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
|
||||
rounds: []round{
|
||||
{
|
||||
new: []task{
|
||||
@ -444,16 +448,18 @@ func TestDialStateNetRestrict(t *testing.T) {
|
||||
|
||||
// This test checks that static dials are launched.
|
||||
func TestDialStateStaticDial(t *testing.T) {
|
||||
wantStatic := []*enode.Node{
|
||||
config := &Config{
|
||||
StaticNodes: []*enode.Node{
|
||||
newNode(uintID(1), nil),
|
||||
newNode(uintID(2), nil),
|
||||
newNode(uintID(3), nil),
|
||||
newNode(uintID(4), nil),
|
||||
newNode(uintID(5), nil),
|
||||
},
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
||||
init: newDialState(enode.ID{}, fakeTable{}, 0, config),
|
||||
rounds: []round{
|
||||
// Static dials are launched for the nodes that
|
||||
// aren't yet connected.
|
||||
@ -495,7 +501,7 @@ func TestDialStateStaticDial(t *testing.T) {
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
|
||||
},
|
||||
new: []task{
|
||||
&waitExpireTask{Duration: 14 * time.Second},
|
||||
&waitExpireTask{Duration: 19 * time.Second},
|
||||
},
|
||||
},
|
||||
// Wait a round for dial history to expire, no new tasks should spawn.
|
||||
@ -511,6 +517,9 @@ func TestDialStateStaticDial(t *testing.T) {
|
||||
// If a static node is dropped, it should be immediately redialed,
|
||||
// irrespective whether it was originally static or dynamic.
|
||||
{
|
||||
done: []task{
|
||||
&waitExpireTask{Duration: 19 * time.Second},
|
||||
},
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
|
||||
@ -518,67 +527,24 @@ func TestDialStateStaticDial(t *testing.T) {
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// This test checks that static peers will be redialed immediately if they were re-added to a static list.
|
||||
func TestDialStaticAfterReset(t *testing.T) {
|
||||
wantStatic := []*enode.Node{
|
||||
newNode(uintID(1), nil),
|
||||
newNode(uintID(2), nil),
|
||||
}
|
||||
|
||||
rounds := []round{
|
||||
// Static dials are launched for the nodes that aren't yet connected.
|
||||
{
|
||||
peers: nil,
|
||||
new: []task{
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
|
||||
},
|
||||
},
|
||||
// No new dial tasks, all peers are connected.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
|
||||
},
|
||||
new: []task{
|
||||
&waitExpireTask{Duration: 30 * time.Second},
|
||||
},
|
||||
},
|
||||
}
|
||||
dTest := dialtest{
|
||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
||||
rounds: rounds,
|
||||
}
|
||||
runDialTest(t, dTest)
|
||||
for _, n := range wantStatic {
|
||||
dTest.init.removeStatic(n)
|
||||
dTest.init.addStatic(n)
|
||||
}
|
||||
// without removing peers they will be considered recently dialed
|
||||
runDialTest(t, dTest)
|
||||
}
|
||||
|
||||
// This test checks that past dials are not retried for some time.
|
||||
func TestDialStateCache(t *testing.T) {
|
||||
wantStatic := []*enode.Node{
|
||||
config := &Config{
|
||||
StaticNodes: []*enode.Node{
|
||||
newNode(uintID(1), nil),
|
||||
newNode(uintID(2), nil),
|
||||
newNode(uintID(3), nil),
|
||||
},
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
||||
init: newDialState(enode.ID{}, fakeTable{}, 0, config),
|
||||
rounds: []round{
|
||||
// Static dials are launched for the nodes that
|
||||
// aren't yet connected.
|
||||
@ -606,28 +572,37 @@ func TestDialStateCache(t *testing.T) {
|
||||
// entry to expire.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
|
||||
},
|
||||
new: []task{
|
||||
&waitExpireTask{Duration: 14 * time.Second},
|
||||
&waitExpireTask{Duration: 19 * time.Second},
|
||||
},
|
||||
},
|
||||
// Still waiting for node 3's entry to expire in the cache.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||
},
|
||||
},
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||
},
|
||||
},
|
||||
// The cache entry for node 3 has expired and is retried.
|
||||
{
|
||||
done: []task{
|
||||
&waitExpireTask{Duration: 19 * time.Second},
|
||||
},
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
|
||||
@ -638,9 +613,13 @@ func TestDialStateCache(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDialResolve(t *testing.T) {
|
||||
config := &Config{
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
|
||||
}
|
||||
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
|
||||
table := &resolveMock{answer: resolved}
|
||||
state := newDialState(enode.ID{}, nil, nil, table, 0, nil)
|
||||
state := newDialState(enode.ID{}, table, 0, config)
|
||||
|
||||
// Check that the task is generated with an incomplete ID.
|
||||
dest := newNode(uintID(1), nil)
|
||||
@ -651,8 +630,7 @@ func TestDialResolve(t *testing.T) {
|
||||
}
|
||||
|
||||
// Now run the task, it should resolve the ID once.
|
||||
config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}}
|
||||
srv := &Server{ntab: table, Config: config}
|
||||
srv := &Server{ntab: table, log: config.Logger, Config: *config}
|
||||
tasks[0].Do(srv)
|
||||
if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
|
||||
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)
|
||||
|
33
p2p/netutil/addrutil.go
Normal file
33
p2p/netutil/addrutil.go
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright 2016 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package netutil
|
||||
|
||||
import "net"
|
||||
|
||||
// AddrIP gets the IP address contained in addr. It returns nil if no address is present.
|
||||
func AddrIP(addr net.Addr) net.IP {
|
||||
switch a := addr.(type) {
|
||||
case *net.IPAddr:
|
||||
return a.IP
|
||||
case *net.TCPAddr:
|
||||
return a.IP
|
||||
case *net.UDPAddr:
|
||||
return a.IP
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
|
||||
pipe, _ := net.Pipe()
|
||||
node := enode.SignNull(new(enr.Record), id)
|
||||
conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
|
||||
peer := newPeer(conn, nil)
|
||||
peer := newPeer(log.Root(), conn, nil)
|
||||
close(peer.closed) // ensures Disconnect doesn't block
|
||||
return peer
|
||||
}
|
||||
@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool {
|
||||
return p.rw.is(inboundConn)
|
||||
}
|
||||
|
||||
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
||||
func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer {
|
||||
protomap := matchProtocols(protocols, conn.caps, conn)
|
||||
p := &Peer{
|
||||
rw: conn,
|
||||
|
@ -24,6 +24,8 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
)
|
||||
|
||||
var discard = Protocol{
|
||||
@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
|
||||
c2.caps = append(c2.caps, p.cap())
|
||||
}
|
||||
|
||||
peer := newPeer(c1, protos)
|
||||
peer := newPeer(log.Root(), c1, protos)
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := peer.run()
|
||||
|
174
p2p/server.go
174
p2p/server.go
@ -22,6 +22,7 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
@ -49,6 +50,9 @@ const (
|
||||
defaultMaxPendingPeers = 50
|
||||
defaultDialRatio = 3
|
||||
|
||||
// This time limits inbound connection attempts per source IP.
|
||||
inboundThrottleTime = 30 * time.Second
|
||||
|
||||
// Maximum time allowed for reading a complete message.
|
||||
// This is effectively the amount of time a connection can be idle.
|
||||
frameReadTimeout = 30 * time.Second
|
||||
@ -158,6 +162,7 @@ type Server struct {
|
||||
// the whole protocol stack.
|
||||
newTransport func(net.Conn) transport
|
||||
newPeerHook func(*Peer)
|
||||
listenFunc func(network, addr string) (net.Listener, error)
|
||||
|
||||
lock sync.Mutex // protects running
|
||||
running bool
|
||||
@ -167,24 +172,26 @@ type Server struct {
|
||||
ntab discoverTable
|
||||
listener net.Listener
|
||||
ourHandshake *protoHandshake
|
||||
lastLookup time.Time
|
||||
DiscV5 *discv5.Network
|
||||
loopWG sync.WaitGroup // loop, listenLoop
|
||||
peerFeed event.Feed
|
||||
log log.Logger
|
||||
|
||||
// These are for Peers, PeerCount (and nothing else).
|
||||
peerOp chan peerOpFunc
|
||||
peerOpDone chan struct{}
|
||||
|
||||
// Channels into the run loop.
|
||||
quit chan struct{}
|
||||
addstatic chan *enode.Node
|
||||
removestatic chan *enode.Node
|
||||
addtrusted chan *enode.Node
|
||||
removetrusted chan *enode.Node
|
||||
posthandshake chan *conn
|
||||
addpeer chan *conn
|
||||
peerOp chan peerOpFunc
|
||||
peerOpDone chan struct{}
|
||||
delpeer chan peerDrop
|
||||
loopWG sync.WaitGroup // loop, listenLoop
|
||||
peerFeed event.Feed
|
||||
log log.Logger
|
||||
checkpointPostHandshake chan *conn
|
||||
checkpointAddPeer chan *conn
|
||||
|
||||
// State of run loop and listenLoop.
|
||||
lastLookup time.Time
|
||||
inboundHistory expHeap
|
||||
}
|
||||
|
||||
type peerOpFunc func(map[enode.ID]*Peer)
|
||||
@ -415,7 +422,7 @@ func (srv *Server) Start() (err error) {
|
||||
srv.running = true
|
||||
srv.log = srv.Config.Logger
|
||||
if srv.log == nil {
|
||||
srv.log = log.New()
|
||||
srv.log = log.Root()
|
||||
}
|
||||
if srv.NoDial && srv.ListenAddr == "" {
|
||||
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
|
||||
@ -428,13 +435,16 @@ func (srv *Server) Start() (err error) {
|
||||
if srv.newTransport == nil {
|
||||
srv.newTransport = newRLPX
|
||||
}
|
||||
if srv.listenFunc == nil {
|
||||
srv.listenFunc = net.Listen
|
||||
}
|
||||
if srv.Dialer == nil {
|
||||
srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
|
||||
}
|
||||
srv.quit = make(chan struct{})
|
||||
srv.addpeer = make(chan *conn)
|
||||
srv.delpeer = make(chan peerDrop)
|
||||
srv.posthandshake = make(chan *conn)
|
||||
srv.checkpointPostHandshake = make(chan *conn)
|
||||
srv.checkpointAddPeer = make(chan *conn)
|
||||
srv.addstatic = make(chan *enode.Node)
|
||||
srv.removestatic = make(chan *enode.Node)
|
||||
srv.addtrusted = make(chan *enode.Node)
|
||||
@ -455,7 +465,7 @@ func (srv *Server) Start() (err error) {
|
||||
}
|
||||
|
||||
dynPeers := srv.maxDialedConns()
|
||||
dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
|
||||
dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config)
|
||||
srv.loopWG.Add(1)
|
||||
go srv.run(dialer)
|
||||
return nil
|
||||
@ -541,6 +551,7 @@ func (srv *Server) setupDiscovery() error {
|
||||
NetRestrict: srv.NetRestrict,
|
||||
Bootnodes: srv.BootstrapNodes,
|
||||
Unhandled: unhandled,
|
||||
Log: srv.log,
|
||||
}
|
||||
ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
|
||||
if err != nil {
|
||||
@ -569,27 +580,28 @@ func (srv *Server) setupDiscovery() error {
|
||||
}
|
||||
|
||||
func (srv *Server) setupListening() error {
|
||||
// Launch the TCP listener.
|
||||
listener, err := net.Listen("tcp", srv.ListenAddr)
|
||||
// Launch the listener.
|
||||
listener, err := srv.listenFunc("tcp", srv.ListenAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
laddr := listener.Addr().(*net.TCPAddr)
|
||||
srv.ListenAddr = laddr.String()
|
||||
srv.listener = listener
|
||||
srv.localnode.Set(enr.TCP(laddr.Port))
|
||||
srv.ListenAddr = listener.Addr().String()
|
||||
|
||||
srv.loopWG.Add(1)
|
||||
go srv.listenLoop()
|
||||
|
||||
// Map the TCP listening port if NAT is configured.
|
||||
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||
// Update the local node record and map the TCP listening port if NAT is configured.
|
||||
if tcp, ok := listener.Addr().(*net.TCPAddr); ok {
|
||||
srv.localnode.Set(enr.TCP(tcp.Port))
|
||||
if !tcp.IP.IsLoopback() && srv.NAT != nil {
|
||||
srv.loopWG.Add(1)
|
||||
go func() {
|
||||
nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
|
||||
nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p")
|
||||
srv.loopWG.Done()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
srv.loopWG.Add(1)
|
||||
go srv.listenLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -657,12 +669,14 @@ running:
|
||||
case <-srv.quit:
|
||||
// The server was stopped. Run the cleanup logic.
|
||||
break running
|
||||
|
||||
case n := <-srv.addstatic:
|
||||
// This channel is used by AddPeer to add to the
|
||||
// ephemeral static peer list. Add it to the dialer,
|
||||
// it will keep the node connected.
|
||||
srv.log.Trace("Adding static node", "node", n)
|
||||
dialstate.addStatic(n)
|
||||
|
||||
case n := <-srv.removestatic:
|
||||
// This channel is used by RemovePeer to send a
|
||||
// disconnect request to a peer and begin the
|
||||
@ -672,6 +686,7 @@ running:
|
||||
if p, ok := peers[n.ID()]; ok {
|
||||
p.Disconnect(DiscRequested)
|
||||
}
|
||||
|
||||
case n := <-srv.addtrusted:
|
||||
// This channel is used by AddTrustedPeer to add an enode
|
||||
// to the trusted node set.
|
||||
@ -681,6 +696,7 @@ running:
|
||||
if p, ok := peers[n.ID()]; ok {
|
||||
p.rw.set(trustedConn, true)
|
||||
}
|
||||
|
||||
case n := <-srv.removetrusted:
|
||||
// This channel is used by RemoveTrustedPeer to remove an enode
|
||||
// from the trusted node set.
|
||||
@ -691,10 +707,12 @@ running:
|
||||
if p, ok := peers[n.ID()]; ok {
|
||||
p.rw.set(trustedConn, false)
|
||||
}
|
||||
|
||||
case op := <-srv.peerOp:
|
||||
// This channel is used by Peers and PeerCount.
|
||||
op(peers)
|
||||
srv.peerOpDone <- struct{}{}
|
||||
|
||||
case t := <-taskdone:
|
||||
// A task got done. Tell dialstate about it so it
|
||||
// can update its state and remove it from the active
|
||||
@ -702,7 +720,8 @@ running:
|
||||
srv.log.Trace("Dial task done", "task", t)
|
||||
dialstate.taskDone(t, time.Now())
|
||||
delTask(t)
|
||||
case c := <-srv.posthandshake:
|
||||
|
||||
case c := <-srv.checkpointPostHandshake:
|
||||
// A connection has passed the encryption handshake so
|
||||
// the remote identity is known (but hasn't been verified yet).
|
||||
if trusted[c.node.ID()] {
|
||||
@ -710,18 +729,15 @@ running:
|
||||
c.flags |= trustedConn
|
||||
}
|
||||
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
|
||||
select {
|
||||
case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
|
||||
case <-srv.quit:
|
||||
break running
|
||||
}
|
||||
case c := <-srv.addpeer:
|
||||
c.cont <- srv.postHandshakeChecks(peers, inboundCount, c)
|
||||
|
||||
case c := <-srv.checkpointAddPeer:
|
||||
// At this point the connection is past the protocol handshake.
|
||||
// Its capabilities are known and the remote identity is verified.
|
||||
err := srv.protoHandshakeChecks(peers, inboundCount, c)
|
||||
err := srv.addPeerChecks(peers, inboundCount, c)
|
||||
if err == nil {
|
||||
// The handshakes are done and it passed all checks.
|
||||
p := newPeer(c, srv.Protocols)
|
||||
p := newPeer(srv.log, c, srv.Protocols)
|
||||
// If message events are enabled, pass the peerFeed
|
||||
// to the peer
|
||||
if srv.EnableMsgEvents {
|
||||
@ -738,11 +754,8 @@ running:
|
||||
// The dialer logic relies on the assumption that
|
||||
// dial tasks complete after the peer has been added or
|
||||
// discarded. Unblock the task last.
|
||||
select {
|
||||
case c.cont <- err:
|
||||
case <-srv.quit:
|
||||
break running
|
||||
}
|
||||
c.cont <- err
|
||||
|
||||
case pd := <-srv.delpeer:
|
||||
// A peer disconnected.
|
||||
d := common.PrettyDuration(mclock.Now() - pd.created)
|
||||
@ -777,17 +790,7 @@ running:
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
||||
// Drop connections with no matching protocols.
|
||||
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
|
||||
return DiscUselessPeer
|
||||
}
|
||||
// Repeat the encryption handshake checks because the
|
||||
// peer set might have changed between the handshakes.
|
||||
return srv.encHandshakeChecks(peers, inboundCount, c)
|
||||
}
|
||||
|
||||
func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
||||
func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
||||
switch {
|
||||
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
|
||||
return DiscTooManyPeers
|
||||
@ -802,9 +805,20 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
||||
// Drop connections with no matching protocols.
|
||||
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
|
||||
return DiscUselessPeer
|
||||
}
|
||||
// Repeat the post-handshake checks because the
|
||||
// peer set might have changed since those checks were performed.
|
||||
return srv.postHandshakeChecks(peers, inboundCount, c)
|
||||
}
|
||||
|
||||
func (srv *Server) maxInboundConns() int {
|
||||
return srv.MaxPeers - srv.maxDialedConns()
|
||||
}
|
||||
|
||||
func (srv *Server) maxDialedConns() int {
|
||||
if srv.NoDiscovery || srv.NoDial {
|
||||
return 0
|
||||
@ -832,7 +846,7 @@ func (srv *Server) listenLoop() {
|
||||
}
|
||||
|
||||
for {
|
||||
// Wait for a handshake slot before accepting.
|
||||
// Wait for a free slot before accepting.
|
||||
<-slots
|
||||
|
||||
var (
|
||||
@ -851,21 +865,16 @@ func (srv *Server) listenLoop() {
|
||||
break
|
||||
}
|
||||
|
||||
// Reject connections that do not match NetRestrict.
|
||||
if srv.NetRestrict != nil {
|
||||
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
|
||||
srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr())
|
||||
remoteIP := netutil.AddrIP(fd.RemoteAddr())
|
||||
if err := srv.checkInboundConn(fd, remoteIP); err != nil {
|
||||
srv.log.Debug("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err)
|
||||
fd.Close()
|
||||
slots <- struct{}{}
|
||||
continue
|
||||
}
|
||||
if remoteIP != nil {
|
||||
fd = newMeteredConn(fd, true, remoteIP)
|
||||
}
|
||||
|
||||
var ip net.IP
|
||||
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok {
|
||||
ip = tcp.IP
|
||||
}
|
||||
fd = newMeteredConn(fd, true, ip)
|
||||
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
|
||||
go func() {
|
||||
srv.SetupConn(fd, inboundConn, nil)
|
||||
@ -874,6 +883,22 @@ func (srv *Server) listenLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error {
|
||||
if remoteIP != nil {
|
||||
// Reject connections that do not match NetRestrict.
|
||||
if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
|
||||
return fmt.Errorf("not whitelisted in NetRestrict")
|
||||
}
|
||||
// Reject Internet peers that try too often.
|
||||
srv.inboundHistory.expire(time.Now())
|
||||
if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
|
||||
return fmt.Errorf("too many attempts")
|
||||
}
|
||||
srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupConn runs the handshakes and attempts to add the connection
|
||||
// as a peer. It returns when the connection has been added as a peer
|
||||
// or the handshakes have failed.
|
||||
@ -895,6 +920,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
||||
if !running {
|
||||
return errServerStopped
|
||||
}
|
||||
|
||||
// If dialing, figure out the remote public key.
|
||||
var dialPubkey *ecdsa.PublicKey
|
||||
if dialDest != nil {
|
||||
@ -903,7 +929,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
||||
return errors.New("dial destination doesn't have a secp256k1 public key")
|
||||
}
|
||||
}
|
||||
// Run the encryption handshake.
|
||||
|
||||
// Run the RLPx handshake.
|
||||
remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
|
||||
if err != nil {
|
||||
srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
|
||||
@ -922,12 +949,13 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
||||
conn.handshakeDone(c.node.ID())
|
||||
}
|
||||
clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
|
||||
err = srv.checkpoint(c, srv.posthandshake)
|
||||
err = srv.checkpoint(c, srv.checkpointPostHandshake)
|
||||
if err != nil {
|
||||
clog.Trace("Rejected peer before protocol handshake", "err", err)
|
||||
clog.Trace("Rejected peer", "err", err)
|
||||
return err
|
||||
}
|
||||
// Run the protocol handshake
|
||||
|
||||
// Run the capability negotiation handshake.
|
||||
phs, err := c.doProtoHandshake(srv.ourHandshake)
|
||||
if err != nil {
|
||||
clog.Trace("Failed proto handshake", "err", err)
|
||||
@ -938,14 +966,15 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
||||
return DiscUnexpectedIdentity
|
||||
}
|
||||
c.caps, c.name = phs.Caps, phs.Name
|
||||
err = srv.checkpoint(c, srv.addpeer)
|
||||
err = srv.checkpoint(c, srv.checkpointAddPeer)
|
||||
if err != nil {
|
||||
clog.Trace("Rejected peer", "err", err)
|
||||
return err
|
||||
}
|
||||
// If the checks completed successfully, runPeer has now been
|
||||
// launched by run.
|
||||
clog.Trace("connection set up", "inbound", dialDest == nil)
|
||||
|
||||
// If the checks completed successfully, the connection has been added as a peer and
|
||||
// runPeer has been launched.
|
||||
clog.Trace("Connection set up", "inbound", dialDest == nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -974,12 +1003,7 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
|
||||
case <-srv.quit:
|
||||
return errServerStopped
|
||||
}
|
||||
select {
|
||||
case err := <-c.cont:
|
||||
return err
|
||||
case <-srv.quit:
|
||||
return errServerStopped
|
||||
}
|
||||
return <-c.cont
|
||||
}
|
||||
|
||||
// runPeer runs in its own goroutine for each peer.
|
||||
|
@ -19,6 +19,7 @@ package p2p
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
@ -26,6 +27,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/internal/testlog"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
|
||||
MaxPeers: 10,
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
}
|
||||
server := &Server{
|
||||
Config: config,
|
||||
@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) {
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
|
||||
},
|
||||
}
|
||||
@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) {
|
||||
// Inject a few connections to fill up the peer set.
|
||||
for i := 0; i < 10; i++ {
|
||||
c := newconn(randomID())
|
||||
if err := srv.checkpoint(c, srv.addpeer); err != nil {
|
||||
if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
|
||||
t.Fatalf("could not add conn %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
// Try inserting a non-trusted connection.
|
||||
anotherID := randomID()
|
||||
c := newconn(anotherID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
|
||||
t.Error("wrong error for insert:", err)
|
||||
}
|
||||
// Try inserting a trusted connection.
|
||||
c = newconn(trustedID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
|
||||
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||
}
|
||||
if !c.is(trustedConn) {
|
||||
@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) {
|
||||
// Remove from trusted set and try again
|
||||
srv.RemoveTrustedPeer(newNode(trustedID, nil))
|
||||
c = newconn(trustedID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
|
||||
t.Error("wrong error for insert:", err)
|
||||
}
|
||||
|
||||
// Add anotherID to trusted set and try again
|
||||
srv.AddTrustedPeer(newNode(anotherID, nil))
|
||||
c = newconn(anotherID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
|
||||
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||
}
|
||||
if !c.is(trustedConn) {
|
||||
@ -433,6 +437,7 @@ func TestServerPeerLimits(t *testing.T) {
|
||||
PrivateKey: srvkey,
|
||||
MaxPeers: 0,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
Protocols: []Protocol{discard},
|
||||
},
|
||||
newTransport: func(fd net.Conn) transport { return tp },
|
||||
@ -541,20 +546,25 @@ func TestServerSetupConn(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
srv := &Server{
|
||||
Config: Config{
|
||||
t.Run(test.wantCalls, func(t *testing.T) {
|
||||
cfg := Config{
|
||||
PrivateKey: srvkey,
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
Protocols: []Protocol{discard},
|
||||
},
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
}
|
||||
srv := &Server{
|
||||
Config: cfg,
|
||||
newTransport: func(fd net.Conn) transport { return test.tt },
|
||||
log: log.New(),
|
||||
log: cfg.Logger,
|
||||
}
|
||||
if !test.dontstart {
|
||||
if err := srv.Start(); err != nil {
|
||||
t.Fatalf("couldn't start server: %v", err)
|
||||
}
|
||||
defer srv.Stop()
|
||||
}
|
||||
p1, _ := net.Pipe()
|
||||
srv.SetupConn(p1, test.flags, test.dialDest)
|
||||
@ -564,6 +574,7 @@ func TestServerSetupConn(t *testing.T) {
|
||||
if test.tt.calls != test.wantCalls {
|
||||
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -616,3 +627,100 @@ func randomID() (id enode.ID) {
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// This test checks that inbound connections are throttled by IP.
|
||||
func TestServerInboundThrottle(t *testing.T) {
|
||||
const timeout = 5 * time.Second
|
||||
newTransportCalled := make(chan struct{})
|
||||
srv := &Server{
|
||||
Config: Config{
|
||||
PrivateKey: newkey(),
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
Protocols: []Protocol{discard},
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
},
|
||||
newTransport: func(fd net.Conn) transport {
|
||||
newTransportCalled <- struct{}{}
|
||||
return newRLPX(fd)
|
||||
},
|
||||
listenFunc: func(network, laddr string) (net.Listener, error) {
|
||||
fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
|
||||
return listenFakeAddr(network, laddr, fakeAddr)
|
||||
},
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
t.Fatal("can't start: ", err)
|
||||
}
|
||||
defer srv.Stop()
|
||||
|
||||
// Dial the test server.
|
||||
conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
|
||||
if err != nil {
|
||||
t.Fatalf("could not dial: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-newTransportCalled:
|
||||
// OK
|
||||
case <-time.After(timeout):
|
||||
t.Error("newTransport not called")
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
// Dial again. This time the server should close the connection immediately.
|
||||
connClosed := make(chan struct{})
|
||||
conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
|
||||
if err != nil {
|
||||
t.Fatalf("could not dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
go func() {
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
buf := make([]byte, 10)
|
||||
if n, err := conn.Read(buf); err != io.EOF || n != 0 {
|
||||
t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
|
||||
}
|
||||
connClosed <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-connClosed:
|
||||
// OK
|
||||
case <-newTransportCalled:
|
||||
t.Error("newTransport called for second attempt")
|
||||
case <-time.After(timeout):
|
||||
t.Error("connection not closed within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
|
||||
l, err := net.Listen(network, laddr)
|
||||
if err == nil {
|
||||
l = &fakeAddrListener{l, remoteAddr}
|
||||
}
|
||||
return l, err
|
||||
}
|
||||
|
||||
// fakeAddrListener is a listener that creates connections with a mocked remote address.
|
||||
type fakeAddrListener struct {
|
||||
net.Listener
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
type fakeAddrConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (l *fakeAddrListener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fakeAddrConn{c, l.remoteAddr}, nil
|
||||
}
|
||||
|
||||
func (c *fakeAddrConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
82
p2p/util.go
Normal file
82
p2p/util.go
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright 2019 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package p2p
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"time"
|
||||
)
|
||||
|
||||
// expHeap tracks strings and their expiry time.
|
||||
type expHeap []expItem
|
||||
|
||||
// expItem is an entry in addrHistory.
|
||||
type expItem struct {
|
||||
item string
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
// nextExpiry returns the next expiry time.
|
||||
func (h *expHeap) nextExpiry() time.Time {
|
||||
return (*h)[0].exp
|
||||
}
|
||||
|
||||
// add adds an item and sets its expiry time.
|
||||
func (h *expHeap) add(item string, exp time.Time) {
|
||||
heap.Push(h, expItem{item, exp})
|
||||
}
|
||||
|
||||
// remove removes an item.
|
||||
func (h *expHeap) remove(item string) bool {
|
||||
for i, v := range *h {
|
||||
if v.item == item {
|
||||
heap.Remove(h, i)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// contains checks whether an item is present.
|
||||
func (h expHeap) contains(item string) bool {
|
||||
for _, v := range h {
|
||||
if v.item == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// expire removes items with expiry time before 'now'.
|
||||
func (h *expHeap) expire(now time.Time) {
|
||||
for h.Len() > 0 && h.nextExpiry().Before(now) {
|
||||
heap.Pop(h)
|
||||
}
|
||||
}
|
||||
|
||||
// heap.Interface boilerplate
|
||||
func (h expHeap) Len() int { return len(h) }
|
||||
func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
|
||||
func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) }
|
||||
func (h *expHeap) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
54
p2p/util_test.go
Normal file
54
p2p/util_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
// Copyright 2019 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package p2p
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExpHeap(t *testing.T) {
|
||||
var h expHeap
|
||||
|
||||
var (
|
||||
basetime = time.Unix(4000, 0)
|
||||
exptimeA = basetime.Add(2 * time.Second)
|
||||
exptimeB = basetime.Add(3 * time.Second)
|
||||
exptimeC = basetime.Add(4 * time.Second)
|
||||
)
|
||||
h.add("a", exptimeA)
|
||||
h.add("b", exptimeB)
|
||||
h.add("c", exptimeC)
|
||||
|
||||
if !h.nextExpiry().Equal(exptimeA) {
|
||||
t.Fatal("wrong nextExpiry")
|
||||
}
|
||||
if !h.contains("a") || !h.contains("b") || !h.contains("c") {
|
||||
t.Fatal("heap doesn't contain all live items")
|
||||
}
|
||||
|
||||
h.expire(exptimeA.Add(1))
|
||||
if !h.nextExpiry().Equal(exptimeB) {
|
||||
t.Fatal("wrong nextExpiry")
|
||||
}
|
||||
if h.contains("a") {
|
||||
t.Fatal("heap contains a even though it has already expired")
|
||||
}
|
||||
if !h.contains("b") || !h.contains("c") {
|
||||
t.Fatal("heap doesn't contain all live items")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user