p2p: enforce connection retry limit on server side (#19684)
The dialer limits itself to one attempt every 30s. Apply the same limit in Server and reject peers which try to connect too eagerly. The check against the limit happens right after accepting the connection. Further changes in this commit ensure we pass the Server logger down to Peer instances, discovery and dialState. Unit test logging now works in all Server tests.
This commit is contained in:
parent
c0a034ec89
commit
c420dcb39c
107
p2p/dial.go
107
p2p/dial.go
@ -17,7 +17,6 @@
|
|||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/heap"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -29,9 +28,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// This is the amount of time spent waiting in between
|
// This is the amount of time spent waiting in between redialing a certain node. The
|
||||||
// redialing a certain node.
|
// limit is a bit higher than inboundThrottleTime to prevent failing dials in small
|
||||||
dialHistoryExpiration = 30 * time.Second
|
// private networks.
|
||||||
|
dialHistoryExpiration = inboundThrottleTime + 5*time.Second
|
||||||
|
|
||||||
// Discovery lookups are throttled and can only run
|
// Discovery lookups are throttled and can only run
|
||||||
// once every few seconds.
|
// once every few seconds.
|
||||||
@ -72,16 +72,16 @@ type dialstate struct {
|
|||||||
ntab discoverTable
|
ntab discoverTable
|
||||||
netrestrict *netutil.Netlist
|
netrestrict *netutil.Netlist
|
||||||
self enode.ID
|
self enode.ID
|
||||||
|
bootnodes []*enode.Node // default dials when there are no peers
|
||||||
|
log log.Logger
|
||||||
|
|
||||||
|
start time.Time // time when the dialer was first used
|
||||||
lookupRunning bool
|
lookupRunning bool
|
||||||
dialing map[enode.ID]connFlag
|
dialing map[enode.ID]connFlag
|
||||||
lookupBuf []*enode.Node // current discovery lookup results
|
lookupBuf []*enode.Node // current discovery lookup results
|
||||||
randomNodes []*enode.Node // filled from Table
|
randomNodes []*enode.Node // filled from Table
|
||||||
static map[enode.ID]*dialTask
|
static map[enode.ID]*dialTask
|
||||||
hist *dialHistory
|
hist expHeap
|
||||||
|
|
||||||
start time.Time // time when the dialer was first used
|
|
||||||
bootnodes []*enode.Node // default dials when there are no peers
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type discoverTable interface {
|
type discoverTable interface {
|
||||||
@ -91,15 +91,6 @@ type discoverTable interface {
|
|||||||
ReadRandomNodes([]*enode.Node) int
|
ReadRandomNodes([]*enode.Node) int
|
||||||
}
|
}
|
||||||
|
|
||||||
// the dial history remembers recent dials.
|
|
||||||
type dialHistory []pastDial
|
|
||||||
|
|
||||||
// pastDial is an entry in the dial history.
|
|
||||||
type pastDial struct {
|
|
||||||
id enode.ID
|
|
||||||
exp time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type task interface {
|
type task interface {
|
||||||
Do(*Server)
|
Do(*Server)
|
||||||
}
|
}
|
||||||
@ -126,20 +117,23 @@ type waitExpireTask struct {
|
|||||||
time.Duration
|
time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
|
func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
|
||||||
s := &dialstate{
|
s := &dialstate{
|
||||||
maxDynDials: maxdyn,
|
maxDynDials: maxdyn,
|
||||||
ntab: ntab,
|
ntab: ntab,
|
||||||
self: self,
|
self: self,
|
||||||
netrestrict: netrestrict,
|
netrestrict: cfg.NetRestrict,
|
||||||
|
log: cfg.Logger,
|
||||||
static: make(map[enode.ID]*dialTask),
|
static: make(map[enode.ID]*dialTask),
|
||||||
dialing: make(map[enode.ID]connFlag),
|
dialing: make(map[enode.ID]connFlag),
|
||||||
bootnodes: make([]*enode.Node, len(bootnodes)),
|
bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
|
||||||
randomNodes: make([]*enode.Node, maxdyn/2),
|
randomNodes: make([]*enode.Node, maxdyn/2),
|
||||||
hist: new(dialHistory),
|
|
||||||
}
|
}
|
||||||
copy(s.bootnodes, bootnodes)
|
copy(s.bootnodes, cfg.BootstrapNodes)
|
||||||
for _, n := range static {
|
if s.log == nil {
|
||||||
|
s.log = log.Root()
|
||||||
|
}
|
||||||
|
for _, n := range cfg.StaticNodes {
|
||||||
s.addStatic(n)
|
s.addStatic(n)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) {
|
|||||||
func (s *dialstate) removeStatic(n *enode.Node) {
|
func (s *dialstate) removeStatic(n *enode.Node) {
|
||||||
// This removes a task so future attempts to connect will not be made.
|
// This removes a task so future attempts to connect will not be made.
|
||||||
delete(s.static, n.ID())
|
delete(s.static, n.ID())
|
||||||
// This removes a previous dial timestamp so that application
|
|
||||||
// can force a server to reconnect with chosen peer immediately.
|
|
||||||
s.hist.remove(n.ID())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
|
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
|
||||||
@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
|
|||||||
var newtasks []task
|
var newtasks []task
|
||||||
addDial := func(flag connFlag, n *enode.Node) bool {
|
addDial := func(flag connFlag, n *enode.Node) bool {
|
||||||
if err := s.checkDial(n, peers); err != nil {
|
if err := s.checkDial(n, peers); err != nil {
|
||||||
log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
|
s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
s.dialing[n.ID()] = flag
|
s.dialing[n.ID()] = flag
|
||||||
@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
|
|||||||
err := s.checkDial(t.dest, peers)
|
err := s.checkDial(t.dest, peers)
|
||||||
switch err {
|
switch err {
|
||||||
case errNotWhitelisted, errSelf:
|
case errNotWhitelisted, errSelf:
|
||||||
log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
|
s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
|
||||||
delete(s.static, t.dest.ID())
|
delete(s.static, t.dest.ID())
|
||||||
case nil:
|
case nil:
|
||||||
s.dialing[id] = t.flags
|
s.dialing[id] = t.flags
|
||||||
@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
|
|||||||
// This should prevent cases where the dialer logic is not ticked
|
// This should prevent cases where the dialer logic is not ticked
|
||||||
// because there are no pending events.
|
// because there are no pending events.
|
||||||
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
|
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
|
||||||
t := &waitExpireTask{s.hist.min().exp.Sub(now)}
|
t := &waitExpireTask{s.hist.nextExpiry().Sub(now)}
|
||||||
newtasks = append(newtasks, t)
|
newtasks = append(newtasks, t)
|
||||||
}
|
}
|
||||||
return newtasks
|
return newtasks
|
||||||
@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
|
|||||||
return errSelf
|
return errSelf
|
||||||
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
|
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
|
||||||
return errNotWhitelisted
|
return errNotWhitelisted
|
||||||
case s.hist.contains(n.ID()):
|
case s.hist.contains(string(n.ID().Bytes())):
|
||||||
return errRecentlyDialed
|
return errRecentlyDialed
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
|
|||||||
func (s *dialstate) taskDone(t task, now time.Time) {
|
func (s *dialstate) taskDone(t task, now time.Time) {
|
||||||
switch t := t.(type) {
|
switch t := t.(type) {
|
||||||
case *dialTask:
|
case *dialTask:
|
||||||
s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration))
|
s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration))
|
||||||
delete(s.dialing, t.dest.ID())
|
delete(s.dialing, t.dest.ID())
|
||||||
case *discoverTask:
|
case *discoverTask:
|
||||||
s.lookupRunning = false
|
s.lookupRunning = false
|
||||||
@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) {
|
|||||||
}
|
}
|
||||||
err := t.dial(srv, t.dest)
|
err := t.dial(srv, t.dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace("Dial error", "task", t, "err", err)
|
srv.log.Trace("Dial error", "task", t, "err", err)
|
||||||
// Try resolving the ID of static nodes if dialing failed.
|
// Try resolving the ID of static nodes if dialing failed.
|
||||||
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
|
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
|
||||||
if t.resolve(srv) {
|
if t.resolve(srv) {
|
||||||
@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) {
|
|||||||
// The backoff delay resets when the node is found.
|
// The backoff delay resets when the node is found.
|
||||||
func (t *dialTask) resolve(srv *Server) bool {
|
func (t *dialTask) resolve(srv *Server) bool {
|
||||||
if srv.ntab == nil {
|
if srv.ntab == nil {
|
||||||
log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
|
srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if t.resolveDelay == 0 {
|
if t.resolveDelay == 0 {
|
||||||
@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool {
|
|||||||
if t.resolveDelay > maxResolveDelay {
|
if t.resolveDelay > maxResolveDelay {
|
||||||
t.resolveDelay = maxResolveDelay
|
t.resolveDelay = maxResolveDelay
|
||||||
}
|
}
|
||||||
log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
|
srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// The node was found.
|
// The node was found.
|
||||||
t.resolveDelay = initialResolveDelay
|
t.resolveDelay = initialResolveDelay
|
||||||
t.dest = resolved
|
t.dest = resolved
|
||||||
log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
|
srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) {
|
|||||||
func (t waitExpireTask) String() string {
|
func (t waitExpireTask) String() string {
|
||||||
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
|
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use only these methods to access or modify dialHistory.
|
|
||||||
func (h dialHistory) min() pastDial {
|
|
||||||
return h[0]
|
|
||||||
}
|
|
||||||
func (h *dialHistory) add(id enode.ID, exp time.Time) {
|
|
||||||
heap.Push(h, pastDial{id, exp})
|
|
||||||
|
|
||||||
}
|
|
||||||
func (h *dialHistory) remove(id enode.ID) bool {
|
|
||||||
for i, v := range *h {
|
|
||||||
if v.id == id {
|
|
||||||
heap.Remove(h, i)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
func (h dialHistory) contains(id enode.ID) bool {
|
|
||||||
for _, v := range h {
|
|
||||||
if v.id == id {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
func (h *dialHistory) expire(now time.Time) {
|
|
||||||
for h.Len() > 0 && h.min().exp.Before(now) {
|
|
||||||
heap.Pop(h)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// heap.Interface boilerplate
|
|
||||||
func (h dialHistory) Len() int { return len(h) }
|
|
||||||
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
|
|
||||||
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
|
||||||
func (h *dialHistory) Push(x interface{}) {
|
|
||||||
*h = append(*h, x.(pastDial))
|
|
||||||
}
|
|
||||||
func (h *dialHistory) Pop() interface{} {
|
|
||||||
old := *h
|
|
||||||
n := len(old)
|
|
||||||
x := old[n-1]
|
|
||||||
*h = old[0 : n-1]
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
138
p2p/dial_test.go
138
p2p/dial_test.go
@ -20,10 +20,13 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"github.com/ethereum/go-ethereum/internal/testlog"
|
||||||
|
"github.com/ethereum/go-ethereum/log"
|
||||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||||
"github.com/ethereum/go-ethereum/p2p/netutil"
|
"github.com/ethereum/go-ethereum/p2p/netutil"
|
||||||
@ -67,10 +70,10 @@ func runDialTest(t *testing.T, test dialtest) {
|
|||||||
|
|
||||||
new := test.init.newTasks(running, pm(round.peers), vtime)
|
new := test.init.newTasks(running, pm(round.peers), vtime)
|
||||||
if !sametasks(new, round.new) {
|
if !sametasks(new, round.new) {
|
||||||
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
|
t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
|
||||||
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
|
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
|
||||||
}
|
}
|
||||||
t.Log("tasks:", spew.Sdump(new))
|
t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new)))
|
||||||
|
|
||||||
// Time advances by 16 seconds on every round.
|
// Time advances by 16 seconds on every round.
|
||||||
vtime = vtime.Add(16 * time.Second)
|
vtime = vtime.Add(16 * time.Second)
|
||||||
@ -88,8 +91,9 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
|
|||||||
|
|
||||||
// This test checks that dynamic dials are launched from discovery results.
|
// This test checks that dynamic dials are launched from discovery results.
|
||||||
func TestDialStateDynDial(t *testing.T) {
|
func TestDialStateDynDial(t *testing.T) {
|
||||||
|
config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
|
||||||
runDialTest(t, dialtest{
|
runDialTest(t, dialtest{
|
||||||
init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
|
init: newDialState(enode.ID{}, fakeTable{}, 5, config),
|
||||||
rounds: []round{
|
rounds: []round{
|
||||||
// A discovery query is launched.
|
// A discovery query is launched.
|
||||||
{
|
{
|
||||||
@ -153,7 +157,7 @@ func TestDialStateDynDial(t *testing.T) {
|
|||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||||
},
|
},
|
||||||
new: []task{
|
new: []task{
|
||||||
&waitExpireTask{Duration: 14 * time.Second},
|
&waitExpireTask{Duration: 19 * time.Second},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// In this round, the peer with id 2 drops off. The query
|
// In this round, the peer with id 2 drops off. The query
|
||||||
@ -223,10 +227,13 @@ func TestDialStateDynDial(t *testing.T) {
|
|||||||
|
|
||||||
// Tests that bootnodes are dialed if no peers are connectd, but not otherwise.
|
// Tests that bootnodes are dialed if no peers are connectd, but not otherwise.
|
||||||
func TestDialStateDynDialBootnode(t *testing.T) {
|
func TestDialStateDynDialBootnode(t *testing.T) {
|
||||||
bootnodes := []*enode.Node{
|
config := &Config{
|
||||||
|
BootstrapNodes: []*enode.Node{
|
||||||
newNode(uintID(1), nil),
|
newNode(uintID(1), nil),
|
||||||
newNode(uintID(2), nil),
|
newNode(uintID(2), nil),
|
||||||
newNode(uintID(3), nil),
|
newNode(uintID(3), nil),
|
||||||
|
},
|
||||||
|
Logger: testlog.Logger(t, log.LvlTrace),
|
||||||
}
|
}
|
||||||
table := fakeTable{
|
table := fakeTable{
|
||||||
newNode(uintID(4), nil),
|
newNode(uintID(4), nil),
|
||||||
@ -236,7 +243,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
|
|||||||
newNode(uintID(8), nil),
|
newNode(uintID(8), nil),
|
||||||
}
|
}
|
||||||
runDialTest(t, dialtest{
|
runDialTest(t, dialtest{
|
||||||
init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
|
init: newDialState(enode.ID{}, table, 5, config),
|
||||||
rounds: []round{
|
rounds: []round{
|
||||||
// 2 dynamic dials attempted, bootnodes pending fallback interval
|
// 2 dynamic dials attempted, bootnodes pending fallback interval
|
||||||
{
|
{
|
||||||
@ -259,25 +266,24 @@ func TestDialStateDynDialBootnode(t *testing.T) {
|
|||||||
{
|
{
|
||||||
new: []task{
|
new: []task{
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// No dials succeed, 2nd bootnode is attempted
|
// No dials succeed, 2nd bootnode is attempted
|
||||||
{
|
{
|
||||||
done: []task{
|
done: []task{
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
|
||||||
},
|
},
|
||||||
new: []task{
|
new: []task{
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
|
||||||
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
||||||
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// No dials succeed, 3rd bootnode is attempted
|
// No dials succeed, 3rd bootnode is attempted
|
||||||
{
|
{
|
||||||
done: []task{
|
done: []task{
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
|
||||||
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
||||||
},
|
},
|
||||||
new: []task{
|
new: []task{
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
|
||||||
@ -288,21 +294,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
|
|||||||
done: []task{
|
done: []task{
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
|
||||||
},
|
},
|
||||||
new: []task{
|
new: []task{},
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
// Random dial succeeds, no more bootnodes are attempted
|
// Random dial succeeds, no more bootnodes are attempted
|
||||||
{
|
{
|
||||||
|
new: []task{
|
||||||
|
&waitExpireTask{3 * time.Second},
|
||||||
|
},
|
||||||
peers: []*Peer{
|
peers: []*Peer{
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
|
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
|
||||||
},
|
},
|
||||||
done: []task{
|
done: []task{
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
|
||||||
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -324,7 +328,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
runDialTest(t, dialtest{
|
runDialTest(t, dialtest{
|
||||||
init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
|
init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
|
||||||
rounds: []round{
|
rounds: []round{
|
||||||
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
|
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
|
||||||
{
|
{
|
||||||
@ -430,7 +434,7 @@ func TestDialStateNetRestrict(t *testing.T) {
|
|||||||
restrict.Add("127.0.2.0/24")
|
restrict.Add("127.0.2.0/24")
|
||||||
|
|
||||||
runDialTest(t, dialtest{
|
runDialTest(t, dialtest{
|
||||||
init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
|
init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
|
||||||
rounds: []round{
|
rounds: []round{
|
||||||
{
|
{
|
||||||
new: []task{
|
new: []task{
|
||||||
@ -444,16 +448,18 @@ func TestDialStateNetRestrict(t *testing.T) {
|
|||||||
|
|
||||||
// This test checks that static dials are launched.
|
// This test checks that static dials are launched.
|
||||||
func TestDialStateStaticDial(t *testing.T) {
|
func TestDialStateStaticDial(t *testing.T) {
|
||||||
wantStatic := []*enode.Node{
|
config := &Config{
|
||||||
|
StaticNodes: []*enode.Node{
|
||||||
newNode(uintID(1), nil),
|
newNode(uintID(1), nil),
|
||||||
newNode(uintID(2), nil),
|
newNode(uintID(2), nil),
|
||||||
newNode(uintID(3), nil),
|
newNode(uintID(3), nil),
|
||||||
newNode(uintID(4), nil),
|
newNode(uintID(4), nil),
|
||||||
newNode(uintID(5), nil),
|
newNode(uintID(5), nil),
|
||||||
|
},
|
||||||
|
Logger: testlog.Logger(t, log.LvlTrace),
|
||||||
}
|
}
|
||||||
|
|
||||||
runDialTest(t, dialtest{
|
runDialTest(t, dialtest{
|
||||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
init: newDialState(enode.ID{}, fakeTable{}, 0, config),
|
||||||
rounds: []round{
|
rounds: []round{
|
||||||
// Static dials are launched for the nodes that
|
// Static dials are launched for the nodes that
|
||||||
// aren't yet connected.
|
// aren't yet connected.
|
||||||
@ -495,7 +501,7 @@ func TestDialStateStaticDial(t *testing.T) {
|
|||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
|
&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
|
||||||
},
|
},
|
||||||
new: []task{
|
new: []task{
|
||||||
&waitExpireTask{Duration: 14 * time.Second},
|
&waitExpireTask{Duration: 19 * time.Second},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// Wait a round for dial history to expire, no new tasks should spawn.
|
// Wait a round for dial history to expire, no new tasks should spawn.
|
||||||
@ -511,6 +517,9 @@ func TestDialStateStaticDial(t *testing.T) {
|
|||||||
// If a static node is dropped, it should be immediately redialed,
|
// If a static node is dropped, it should be immediately redialed,
|
||||||
// irrespective whether it was originally static or dynamic.
|
// irrespective whether it was originally static or dynamic.
|
||||||
{
|
{
|
||||||
|
done: []task{
|
||||||
|
&waitExpireTask{Duration: 19 * time.Second},
|
||||||
|
},
|
||||||
peers: []*Peer{
|
peers: []*Peer{
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
||||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
|
||||||
@ -518,67 +527,24 @@ func TestDialStateStaticDial(t *testing.T) {
|
|||||||
},
|
},
|
||||||
new: []task{
|
new: []task{
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
|
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test checks that static peers will be redialed immediately if they were re-added to a static list.
|
|
||||||
func TestDialStaticAfterReset(t *testing.T) {
|
|
||||||
wantStatic := []*enode.Node{
|
|
||||||
newNode(uintID(1), nil),
|
|
||||||
newNode(uintID(2), nil),
|
|
||||||
}
|
|
||||||
|
|
||||||
rounds := []round{
|
|
||||||
// Static dials are launched for the nodes that aren't yet connected.
|
|
||||||
{
|
|
||||||
peers: nil,
|
|
||||||
new: []task{
|
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
|
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
// No new dial tasks, all peers are connected.
|
|
||||||
{
|
|
||||||
peers: []*Peer{
|
|
||||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
|
||||||
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
|
||||||
},
|
|
||||||
done: []task{
|
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
|
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
|
|
||||||
},
|
|
||||||
new: []task{
|
|
||||||
&waitExpireTask{Duration: 30 * time.Second},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
dTest := dialtest{
|
|
||||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
|
||||||
rounds: rounds,
|
|
||||||
}
|
|
||||||
runDialTest(t, dTest)
|
|
||||||
for _, n := range wantStatic {
|
|
||||||
dTest.init.removeStatic(n)
|
|
||||||
dTest.init.addStatic(n)
|
|
||||||
}
|
|
||||||
// without removing peers they will be considered recently dialed
|
|
||||||
runDialTest(t, dTest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This test checks that past dials are not retried for some time.
|
// This test checks that past dials are not retried for some time.
|
||||||
func TestDialStateCache(t *testing.T) {
|
func TestDialStateCache(t *testing.T) {
|
||||||
wantStatic := []*enode.Node{
|
config := &Config{
|
||||||
|
StaticNodes: []*enode.Node{
|
||||||
newNode(uintID(1), nil),
|
newNode(uintID(1), nil),
|
||||||
newNode(uintID(2), nil),
|
newNode(uintID(2), nil),
|
||||||
newNode(uintID(3), nil),
|
newNode(uintID(3), nil),
|
||||||
|
},
|
||||||
|
Logger: testlog.Logger(t, log.LvlTrace),
|
||||||
}
|
}
|
||||||
|
|
||||||
runDialTest(t, dialtest{
|
runDialTest(t, dialtest{
|
||||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
init: newDialState(enode.ID{}, fakeTable{}, 0, config),
|
||||||
rounds: []round{
|
rounds: []round{
|
||||||
// Static dials are launched for the nodes that
|
// Static dials are launched for the nodes that
|
||||||
// aren't yet connected.
|
// aren't yet connected.
|
||||||
@ -606,28 +572,37 @@ func TestDialStateCache(t *testing.T) {
|
|||||||
// entry to expire.
|
// entry to expire.
|
||||||
{
|
{
|
||||||
peers: []*Peer{
|
peers: []*Peer{
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||||
},
|
},
|
||||||
done: []task{
|
done: []task{
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
|
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
|
||||||
},
|
},
|
||||||
new: []task{
|
new: []task{
|
||||||
&waitExpireTask{Duration: 14 * time.Second},
|
&waitExpireTask{Duration: 19 * time.Second},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// Still waiting for node 3's entry to expire in the cache.
|
// Still waiting for node 3's entry to expire in the cache.
|
||||||
{
|
{
|
||||||
peers: []*Peer{
|
peers: []*Peer{
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// The cache entry for node 3 has expired and is retried.
|
// The cache entry for node 3 has expired and is retried.
|
||||||
{
|
{
|
||||||
|
done: []task{
|
||||||
|
&waitExpireTask{Duration: 19 * time.Second},
|
||||||
|
},
|
||||||
peers: []*Peer{
|
peers: []*Peer{
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
|
||||||
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
|
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
|
||||||
},
|
},
|
||||||
new: []task{
|
new: []task{
|
||||||
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
|
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
|
||||||
@ -638,9 +613,13 @@ func TestDialStateCache(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDialResolve(t *testing.T) {
|
func TestDialResolve(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
Logger: testlog.Logger(t, log.LvlTrace),
|
||||||
|
Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
|
||||||
|
}
|
||||||
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
|
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
|
||||||
table := &resolveMock{answer: resolved}
|
table := &resolveMock{answer: resolved}
|
||||||
state := newDialState(enode.ID{}, nil, nil, table, 0, nil)
|
state := newDialState(enode.ID{}, table, 0, config)
|
||||||
|
|
||||||
// Check that the task is generated with an incomplete ID.
|
// Check that the task is generated with an incomplete ID.
|
||||||
dest := newNode(uintID(1), nil)
|
dest := newNode(uintID(1), nil)
|
||||||
@ -651,8 +630,7 @@ func TestDialResolve(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now run the task, it should resolve the ID once.
|
// Now run the task, it should resolve the ID once.
|
||||||
config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}}
|
srv := &Server{ntab: table, log: config.Logger, Config: *config}
|
||||||
srv := &Server{ntab: table, Config: config}
|
|
||||||
tasks[0].Do(srv)
|
tasks[0].Do(srv)
|
||||||
if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
|
if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
|
||||||
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)
|
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)
|
||||||
|
33
p2p/netutil/addrutil.go
Normal file
33
p2p/netutil/addrutil.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
// Copyright 2016 The go-ethereum Authors
|
||||||
|
// This file is part of the go-ethereum library.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Lesser General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Lesser General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Lesser General Public License
|
||||||
|
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package netutil
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
// AddrIP gets the IP address contained in addr. It returns nil if no address is present.
|
||||||
|
func AddrIP(addr net.Addr) net.IP {
|
||||||
|
switch a := addr.(type) {
|
||||||
|
case *net.IPAddr:
|
||||||
|
return a.IP
|
||||||
|
case *net.TCPAddr:
|
||||||
|
return a.IP
|
||||||
|
case *net.UDPAddr:
|
||||||
|
return a.IP
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
|
|||||||
pipe, _ := net.Pipe()
|
pipe, _ := net.Pipe()
|
||||||
node := enode.SignNull(new(enr.Record), id)
|
node := enode.SignNull(new(enr.Record), id)
|
||||||
conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
|
conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
|
||||||
peer := newPeer(conn, nil)
|
peer := newPeer(log.Root(), conn, nil)
|
||||||
close(peer.closed) // ensures Disconnect doesn't block
|
close(peer.closed) // ensures Disconnect doesn't block
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool {
|
|||||||
return p.rw.is(inboundConn)
|
return p.rw.is(inboundConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer {
|
||||||
protomap := matchProtocols(protocols, conn.caps, conn)
|
protomap := matchProtocols(protocols, conn.caps, conn)
|
||||||
p := &Peer{
|
p := &Peer{
|
||||||
rw: conn,
|
rw: conn,
|
||||||
|
@ -24,6 +24,8 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
var discard = Protocol{
|
var discard = Protocol{
|
||||||
@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
|
|||||||
c2.caps = append(c2.caps, p.cap())
|
c2.caps = append(c2.caps, p.cap())
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := newPeer(c1, protos)
|
peer := newPeer(log.Root(), c1, protos)
|
||||||
errc := make(chan error, 1)
|
errc := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := peer.run()
|
_, err := peer.run()
|
||||||
|
174
p2p/server.go
174
p2p/server.go
@ -22,6 +22,7 @@ import (
|
|||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
@ -49,6 +50,9 @@ const (
|
|||||||
defaultMaxPendingPeers = 50
|
defaultMaxPendingPeers = 50
|
||||||
defaultDialRatio = 3
|
defaultDialRatio = 3
|
||||||
|
|
||||||
|
// This time limits inbound connection attempts per source IP.
|
||||||
|
inboundThrottleTime = 30 * time.Second
|
||||||
|
|
||||||
// Maximum time allowed for reading a complete message.
|
// Maximum time allowed for reading a complete message.
|
||||||
// This is effectively the amount of time a connection can be idle.
|
// This is effectively the amount of time a connection can be idle.
|
||||||
frameReadTimeout = 30 * time.Second
|
frameReadTimeout = 30 * time.Second
|
||||||
@ -158,6 +162,7 @@ type Server struct {
|
|||||||
// the whole protocol stack.
|
// the whole protocol stack.
|
||||||
newTransport func(net.Conn) transport
|
newTransport func(net.Conn) transport
|
||||||
newPeerHook func(*Peer)
|
newPeerHook func(*Peer)
|
||||||
|
listenFunc func(network, addr string) (net.Listener, error)
|
||||||
|
|
||||||
lock sync.Mutex // protects running
|
lock sync.Mutex // protects running
|
||||||
running bool
|
running bool
|
||||||
@ -167,24 +172,26 @@ type Server struct {
|
|||||||
ntab discoverTable
|
ntab discoverTable
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
ourHandshake *protoHandshake
|
ourHandshake *protoHandshake
|
||||||
lastLookup time.Time
|
|
||||||
DiscV5 *discv5.Network
|
DiscV5 *discv5.Network
|
||||||
|
loopWG sync.WaitGroup // loop, listenLoop
|
||||||
|
peerFeed event.Feed
|
||||||
|
log log.Logger
|
||||||
|
|
||||||
// These are for Peers, PeerCount (and nothing else).
|
// Channels into the run loop.
|
||||||
peerOp chan peerOpFunc
|
|
||||||
peerOpDone chan struct{}
|
|
||||||
|
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
addstatic chan *enode.Node
|
addstatic chan *enode.Node
|
||||||
removestatic chan *enode.Node
|
removestatic chan *enode.Node
|
||||||
addtrusted chan *enode.Node
|
addtrusted chan *enode.Node
|
||||||
removetrusted chan *enode.Node
|
removetrusted chan *enode.Node
|
||||||
posthandshake chan *conn
|
peerOp chan peerOpFunc
|
||||||
addpeer chan *conn
|
peerOpDone chan struct{}
|
||||||
delpeer chan peerDrop
|
delpeer chan peerDrop
|
||||||
loopWG sync.WaitGroup // loop, listenLoop
|
checkpointPostHandshake chan *conn
|
||||||
peerFeed event.Feed
|
checkpointAddPeer chan *conn
|
||||||
log log.Logger
|
|
||||||
|
// State of run loop and listenLoop.
|
||||||
|
lastLookup time.Time
|
||||||
|
inboundHistory expHeap
|
||||||
}
|
}
|
||||||
|
|
||||||
type peerOpFunc func(map[enode.ID]*Peer)
|
type peerOpFunc func(map[enode.ID]*Peer)
|
||||||
@ -415,7 +422,7 @@ func (srv *Server) Start() (err error) {
|
|||||||
srv.running = true
|
srv.running = true
|
||||||
srv.log = srv.Config.Logger
|
srv.log = srv.Config.Logger
|
||||||
if srv.log == nil {
|
if srv.log == nil {
|
||||||
srv.log = log.New()
|
srv.log = log.Root()
|
||||||
}
|
}
|
||||||
if srv.NoDial && srv.ListenAddr == "" {
|
if srv.NoDial && srv.ListenAddr == "" {
|
||||||
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
|
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
|
||||||
@ -428,13 +435,16 @@ func (srv *Server) Start() (err error) {
|
|||||||
if srv.newTransport == nil {
|
if srv.newTransport == nil {
|
||||||
srv.newTransport = newRLPX
|
srv.newTransport = newRLPX
|
||||||
}
|
}
|
||||||
|
if srv.listenFunc == nil {
|
||||||
|
srv.listenFunc = net.Listen
|
||||||
|
}
|
||||||
if srv.Dialer == nil {
|
if srv.Dialer == nil {
|
||||||
srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
|
srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
|
||||||
}
|
}
|
||||||
srv.quit = make(chan struct{})
|
srv.quit = make(chan struct{})
|
||||||
srv.addpeer = make(chan *conn)
|
|
||||||
srv.delpeer = make(chan peerDrop)
|
srv.delpeer = make(chan peerDrop)
|
||||||
srv.posthandshake = make(chan *conn)
|
srv.checkpointPostHandshake = make(chan *conn)
|
||||||
|
srv.checkpointAddPeer = make(chan *conn)
|
||||||
srv.addstatic = make(chan *enode.Node)
|
srv.addstatic = make(chan *enode.Node)
|
||||||
srv.removestatic = make(chan *enode.Node)
|
srv.removestatic = make(chan *enode.Node)
|
||||||
srv.addtrusted = make(chan *enode.Node)
|
srv.addtrusted = make(chan *enode.Node)
|
||||||
@ -455,7 +465,7 @@ func (srv *Server) Start() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dynPeers := srv.maxDialedConns()
|
dynPeers := srv.maxDialedConns()
|
||||||
dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
|
dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config)
|
||||||
srv.loopWG.Add(1)
|
srv.loopWG.Add(1)
|
||||||
go srv.run(dialer)
|
go srv.run(dialer)
|
||||||
return nil
|
return nil
|
||||||
@ -541,6 +551,7 @@ func (srv *Server) setupDiscovery() error {
|
|||||||
NetRestrict: srv.NetRestrict,
|
NetRestrict: srv.NetRestrict,
|
||||||
Bootnodes: srv.BootstrapNodes,
|
Bootnodes: srv.BootstrapNodes,
|
||||||
Unhandled: unhandled,
|
Unhandled: unhandled,
|
||||||
|
Log: srv.log,
|
||||||
}
|
}
|
||||||
ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
|
ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -569,27 +580,28 @@ func (srv *Server) setupDiscovery() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) setupListening() error {
|
func (srv *Server) setupListening() error {
|
||||||
// Launch the TCP listener.
|
// Launch the listener.
|
||||||
listener, err := net.Listen("tcp", srv.ListenAddr)
|
listener, err := srv.listenFunc("tcp", srv.ListenAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
laddr := listener.Addr().(*net.TCPAddr)
|
|
||||||
srv.ListenAddr = laddr.String()
|
|
||||||
srv.listener = listener
|
srv.listener = listener
|
||||||
srv.localnode.Set(enr.TCP(laddr.Port))
|
srv.ListenAddr = listener.Addr().String()
|
||||||
|
|
||||||
srv.loopWG.Add(1)
|
// Update the local node record and map the TCP listening port if NAT is configured.
|
||||||
go srv.listenLoop()
|
if tcp, ok := listener.Addr().(*net.TCPAddr); ok {
|
||||||
|
srv.localnode.Set(enr.TCP(tcp.Port))
|
||||||
// Map the TCP listening port if NAT is configured.
|
if !tcp.IP.IsLoopback() && srv.NAT != nil {
|
||||||
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
|
||||||
srv.loopWG.Add(1)
|
srv.loopWG.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
|
nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p")
|
||||||
srv.loopWG.Done()
|
srv.loopWG.Done()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.loopWG.Add(1)
|
||||||
|
go srv.listenLoop()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -657,12 +669,14 @@ running:
|
|||||||
case <-srv.quit:
|
case <-srv.quit:
|
||||||
// The server was stopped. Run the cleanup logic.
|
// The server was stopped. Run the cleanup logic.
|
||||||
break running
|
break running
|
||||||
|
|
||||||
case n := <-srv.addstatic:
|
case n := <-srv.addstatic:
|
||||||
// This channel is used by AddPeer to add to the
|
// This channel is used by AddPeer to add to the
|
||||||
// ephemeral static peer list. Add it to the dialer,
|
// ephemeral static peer list. Add it to the dialer,
|
||||||
// it will keep the node connected.
|
// it will keep the node connected.
|
||||||
srv.log.Trace("Adding static node", "node", n)
|
srv.log.Trace("Adding static node", "node", n)
|
||||||
dialstate.addStatic(n)
|
dialstate.addStatic(n)
|
||||||
|
|
||||||
case n := <-srv.removestatic:
|
case n := <-srv.removestatic:
|
||||||
// This channel is used by RemovePeer to send a
|
// This channel is used by RemovePeer to send a
|
||||||
// disconnect request to a peer and begin the
|
// disconnect request to a peer and begin the
|
||||||
@ -672,6 +686,7 @@ running:
|
|||||||
if p, ok := peers[n.ID()]; ok {
|
if p, ok := peers[n.ID()]; ok {
|
||||||
p.Disconnect(DiscRequested)
|
p.Disconnect(DiscRequested)
|
||||||
}
|
}
|
||||||
|
|
||||||
case n := <-srv.addtrusted:
|
case n := <-srv.addtrusted:
|
||||||
// This channel is used by AddTrustedPeer to add an enode
|
// This channel is used by AddTrustedPeer to add an enode
|
||||||
// to the trusted node set.
|
// to the trusted node set.
|
||||||
@ -681,6 +696,7 @@ running:
|
|||||||
if p, ok := peers[n.ID()]; ok {
|
if p, ok := peers[n.ID()]; ok {
|
||||||
p.rw.set(trustedConn, true)
|
p.rw.set(trustedConn, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
case n := <-srv.removetrusted:
|
case n := <-srv.removetrusted:
|
||||||
// This channel is used by RemoveTrustedPeer to remove an enode
|
// This channel is used by RemoveTrustedPeer to remove an enode
|
||||||
// from the trusted node set.
|
// from the trusted node set.
|
||||||
@ -691,10 +707,12 @@ running:
|
|||||||
if p, ok := peers[n.ID()]; ok {
|
if p, ok := peers[n.ID()]; ok {
|
||||||
p.rw.set(trustedConn, false)
|
p.rw.set(trustedConn, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
case op := <-srv.peerOp:
|
case op := <-srv.peerOp:
|
||||||
// This channel is used by Peers and PeerCount.
|
// This channel is used by Peers and PeerCount.
|
||||||
op(peers)
|
op(peers)
|
||||||
srv.peerOpDone <- struct{}{}
|
srv.peerOpDone <- struct{}{}
|
||||||
|
|
||||||
case t := <-taskdone:
|
case t := <-taskdone:
|
||||||
// A task got done. Tell dialstate about it so it
|
// A task got done. Tell dialstate about it so it
|
||||||
// can update its state and remove it from the active
|
// can update its state and remove it from the active
|
||||||
@ -702,7 +720,8 @@ running:
|
|||||||
srv.log.Trace("Dial task done", "task", t)
|
srv.log.Trace("Dial task done", "task", t)
|
||||||
dialstate.taskDone(t, time.Now())
|
dialstate.taskDone(t, time.Now())
|
||||||
delTask(t)
|
delTask(t)
|
||||||
case c := <-srv.posthandshake:
|
|
||||||
|
case c := <-srv.checkpointPostHandshake:
|
||||||
// A connection has passed the encryption handshake so
|
// A connection has passed the encryption handshake so
|
||||||
// the remote identity is known (but hasn't been verified yet).
|
// the remote identity is known (but hasn't been verified yet).
|
||||||
if trusted[c.node.ID()] {
|
if trusted[c.node.ID()] {
|
||||||
@ -710,18 +729,15 @@ running:
|
|||||||
c.flags |= trustedConn
|
c.flags |= trustedConn
|
||||||
}
|
}
|
||||||
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
|
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
|
||||||
select {
|
c.cont <- srv.postHandshakeChecks(peers, inboundCount, c)
|
||||||
case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
|
|
||||||
case <-srv.quit:
|
case c := <-srv.checkpointAddPeer:
|
||||||
break running
|
|
||||||
}
|
|
||||||
case c := <-srv.addpeer:
|
|
||||||
// At this point the connection is past the protocol handshake.
|
// At this point the connection is past the protocol handshake.
|
||||||
// Its capabilities are known and the remote identity is verified.
|
// Its capabilities are known and the remote identity is verified.
|
||||||
err := srv.protoHandshakeChecks(peers, inboundCount, c)
|
err := srv.addPeerChecks(peers, inboundCount, c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// The handshakes are done and it passed all checks.
|
// The handshakes are done and it passed all checks.
|
||||||
p := newPeer(c, srv.Protocols)
|
p := newPeer(srv.log, c, srv.Protocols)
|
||||||
// If message events are enabled, pass the peerFeed
|
// If message events are enabled, pass the peerFeed
|
||||||
// to the peer
|
// to the peer
|
||||||
if srv.EnableMsgEvents {
|
if srv.EnableMsgEvents {
|
||||||
@ -738,11 +754,8 @@ running:
|
|||||||
// The dialer logic relies on the assumption that
|
// The dialer logic relies on the assumption that
|
||||||
// dial tasks complete after the peer has been added or
|
// dial tasks complete after the peer has been added or
|
||||||
// discarded. Unblock the task last.
|
// discarded. Unblock the task last.
|
||||||
select {
|
c.cont <- err
|
||||||
case c.cont <- err:
|
|
||||||
case <-srv.quit:
|
|
||||||
break running
|
|
||||||
}
|
|
||||||
case pd := <-srv.delpeer:
|
case pd := <-srv.delpeer:
|
||||||
// A peer disconnected.
|
// A peer disconnected.
|
||||||
d := common.PrettyDuration(mclock.Now() - pd.created)
|
d := common.PrettyDuration(mclock.Now() - pd.created)
|
||||||
@ -777,17 +790,7 @@ running:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
||||||
// Drop connections with no matching protocols.
|
|
||||||
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
|
|
||||||
return DiscUselessPeer
|
|
||||||
}
|
|
||||||
// Repeat the encryption handshake checks because the
|
|
||||||
// peer set might have changed between the handshakes.
|
|
||||||
return srv.encHandshakeChecks(peers, inboundCount, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
|
||||||
switch {
|
switch {
|
||||||
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
|
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
|
||||||
return DiscTooManyPeers
|
return DiscTooManyPeers
|
||||||
@ -802,9 +805,20 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
|
||||||
|
// Drop connections with no matching protocols.
|
||||||
|
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
|
||||||
|
return DiscUselessPeer
|
||||||
|
}
|
||||||
|
// Repeat the post-handshake checks because the
|
||||||
|
// peer set might have changed since those checks were performed.
|
||||||
|
return srv.postHandshakeChecks(peers, inboundCount, c)
|
||||||
|
}
|
||||||
|
|
||||||
func (srv *Server) maxInboundConns() int {
|
func (srv *Server) maxInboundConns() int {
|
||||||
return srv.MaxPeers - srv.maxDialedConns()
|
return srv.MaxPeers - srv.maxDialedConns()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) maxDialedConns() int {
|
func (srv *Server) maxDialedConns() int {
|
||||||
if srv.NoDiscovery || srv.NoDial {
|
if srv.NoDiscovery || srv.NoDial {
|
||||||
return 0
|
return 0
|
||||||
@ -832,7 +846,7 @@ func (srv *Server) listenLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Wait for a handshake slot before accepting.
|
// Wait for a free slot before accepting.
|
||||||
<-slots
|
<-slots
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -851,21 +865,16 @@ func (srv *Server) listenLoop() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject connections that do not match NetRestrict.
|
remoteIP := netutil.AddrIP(fd.RemoteAddr())
|
||||||
if srv.NetRestrict != nil {
|
if err := srv.checkInboundConn(fd, remoteIP); err != nil {
|
||||||
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
|
srv.log.Debug("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err)
|
||||||
srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr())
|
|
||||||
fd.Close()
|
fd.Close()
|
||||||
slots <- struct{}{}
|
slots <- struct{}{}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if remoteIP != nil {
|
||||||
|
fd = newMeteredConn(fd, true, remoteIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ip net.IP
|
|
||||||
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok {
|
|
||||||
ip = tcp.IP
|
|
||||||
}
|
|
||||||
fd = newMeteredConn(fd, true, ip)
|
|
||||||
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
|
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
|
||||||
go func() {
|
go func() {
|
||||||
srv.SetupConn(fd, inboundConn, nil)
|
srv.SetupConn(fd, inboundConn, nil)
|
||||||
@ -874,6 +883,22 @@ func (srv *Server) listenLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error {
|
||||||
|
if remoteIP != nil {
|
||||||
|
// Reject connections that do not match NetRestrict.
|
||||||
|
if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
|
||||||
|
return fmt.Errorf("not whitelisted in NetRestrict")
|
||||||
|
}
|
||||||
|
// Reject Internet peers that try too often.
|
||||||
|
srv.inboundHistory.expire(time.Now())
|
||||||
|
if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
|
||||||
|
return fmt.Errorf("too many attempts")
|
||||||
|
}
|
||||||
|
srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetupConn runs the handshakes and attempts to add the connection
|
// SetupConn runs the handshakes and attempts to add the connection
|
||||||
// as a peer. It returns when the connection has been added as a peer
|
// as a peer. It returns when the connection has been added as a peer
|
||||||
// or the handshakes have failed.
|
// or the handshakes have failed.
|
||||||
@ -895,6 +920,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
|||||||
if !running {
|
if !running {
|
||||||
return errServerStopped
|
return errServerStopped
|
||||||
}
|
}
|
||||||
|
|
||||||
// If dialing, figure out the remote public key.
|
// If dialing, figure out the remote public key.
|
||||||
var dialPubkey *ecdsa.PublicKey
|
var dialPubkey *ecdsa.PublicKey
|
||||||
if dialDest != nil {
|
if dialDest != nil {
|
||||||
@ -903,7 +929,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
|||||||
return errors.New("dial destination doesn't have a secp256k1 public key")
|
return errors.New("dial destination doesn't have a secp256k1 public key")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Run the encryption handshake.
|
|
||||||
|
// Run the RLPx handshake.
|
||||||
remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
|
remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
|
srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
|
||||||
@ -922,12 +949,13 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
|||||||
conn.handshakeDone(c.node.ID())
|
conn.handshakeDone(c.node.ID())
|
||||||
}
|
}
|
||||||
clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
|
clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
|
||||||
err = srv.checkpoint(c, srv.posthandshake)
|
err = srv.checkpoint(c, srv.checkpointPostHandshake)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
clog.Trace("Rejected peer before protocol handshake", "err", err)
|
clog.Trace("Rejected peer", "err", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Run the protocol handshake
|
|
||||||
|
// Run the capability negotiation handshake.
|
||||||
phs, err := c.doProtoHandshake(srv.ourHandshake)
|
phs, err := c.doProtoHandshake(srv.ourHandshake)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
clog.Trace("Failed proto handshake", "err", err)
|
clog.Trace("Failed proto handshake", "err", err)
|
||||||
@ -938,14 +966,15 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
|
|||||||
return DiscUnexpectedIdentity
|
return DiscUnexpectedIdentity
|
||||||
}
|
}
|
||||||
c.caps, c.name = phs.Caps, phs.Name
|
c.caps, c.name = phs.Caps, phs.Name
|
||||||
err = srv.checkpoint(c, srv.addpeer)
|
err = srv.checkpoint(c, srv.checkpointAddPeer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
clog.Trace("Rejected peer", "err", err)
|
clog.Trace("Rejected peer", "err", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// If the checks completed successfully, runPeer has now been
|
|
||||||
// launched by run.
|
// If the checks completed successfully, the connection has been added as a peer and
|
||||||
clog.Trace("connection set up", "inbound", dialDest == nil)
|
// runPeer has been launched.
|
||||||
|
clog.Trace("Connection set up", "inbound", dialDest == nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -974,12 +1003,7 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
|
|||||||
case <-srv.quit:
|
case <-srv.quit:
|
||||||
return errServerStopped
|
return errServerStopped
|
||||||
}
|
}
|
||||||
select {
|
return <-c.cont
|
||||||
case err := <-c.cont:
|
|
||||||
return err
|
|
||||||
case <-srv.quit:
|
|
||||||
return errServerStopped
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// runPeer runs in its own goroutine for each peer.
|
// runPeer runs in its own goroutine for each peer.
|
||||||
|
@ -19,6 +19,7 @@ package p2p
|
|||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -26,6 +27,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/internal/testlog"
|
||||||
"github.com/ethereum/go-ethereum/log"
|
"github.com/ethereum/go-ethereum/log"
|
||||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||||
@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
|
|||||||
MaxPeers: 10,
|
MaxPeers: 10,
|
||||||
ListenAddr: "127.0.0.1:0",
|
ListenAddr: "127.0.0.1:0",
|
||||||
PrivateKey: newkey(),
|
PrivateKey: newkey(),
|
||||||
|
Logger: testlog.Logger(t, log.LvlTrace),
|
||||||
}
|
}
|
||||||
server := &Server{
|
server := &Server{
|
||||||
Config: config,
|
Config: config,
|
||||||
@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) {
|
|||||||
PrivateKey: newkey(),
|
PrivateKey: newkey(),
|
||||||
MaxPeers: 10,
|
MaxPeers: 10,
|
||||||
NoDial: true,
|
NoDial: true,
|
||||||
|
NoDiscovery: true,
|
||||||
TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
|
TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) {
|
|||||||
// Inject a few connections to fill up the peer set.
|
// Inject a few connections to fill up the peer set.
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
c := newconn(randomID())
|
c := newconn(randomID())
|
||||||
if err := srv.checkpoint(c, srv.addpeer); err != nil {
|
if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
|
||||||
t.Fatalf("could not add conn %d: %v", i, err)
|
t.Fatalf("could not add conn %d: %v", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Try inserting a non-trusted connection.
|
// Try inserting a non-trusted connection.
|
||||||
anotherID := randomID()
|
anotherID := randomID()
|
||||||
c := newconn(anotherID)
|
c := newconn(anotherID)
|
||||||
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
|
||||||
t.Error("wrong error for insert:", err)
|
t.Error("wrong error for insert:", err)
|
||||||
}
|
}
|
||||||
// Try inserting a trusted connection.
|
// Try inserting a trusted connection.
|
||||||
c = newconn(trustedID)
|
c = newconn(trustedID)
|
||||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
|
||||||
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||||
}
|
}
|
||||||
if !c.is(trustedConn) {
|
if !c.is(trustedConn) {
|
||||||
@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) {
|
|||||||
// Remove from trusted set and try again
|
// Remove from trusted set and try again
|
||||||
srv.RemoveTrustedPeer(newNode(trustedID, nil))
|
srv.RemoveTrustedPeer(newNode(trustedID, nil))
|
||||||
c = newconn(trustedID)
|
c = newconn(trustedID)
|
||||||
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
|
||||||
t.Error("wrong error for insert:", err)
|
t.Error("wrong error for insert:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add anotherID to trusted set and try again
|
// Add anotherID to trusted set and try again
|
||||||
srv.AddTrustedPeer(newNode(anotherID, nil))
|
srv.AddTrustedPeer(newNode(anotherID, nil))
|
||||||
c = newconn(anotherID)
|
c = newconn(anotherID)
|
||||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
|
||||||
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||||
}
|
}
|
||||||
if !c.is(trustedConn) {
|
if !c.is(trustedConn) {
|
||||||
@ -433,6 +437,7 @@ func TestServerPeerLimits(t *testing.T) {
|
|||||||
PrivateKey: srvkey,
|
PrivateKey: srvkey,
|
||||||
MaxPeers: 0,
|
MaxPeers: 0,
|
||||||
NoDial: true,
|
NoDial: true,
|
||||||
|
NoDiscovery: true,
|
||||||
Protocols: []Protocol{discard},
|
Protocols: []Protocol{discard},
|
||||||
},
|
},
|
||||||
newTransport: func(fd net.Conn) transport { return tp },
|
newTransport: func(fd net.Conn) transport { return tp },
|
||||||
@ -541,20 +546,25 @@ func TestServerSetupConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
srv := &Server{
|
t.Run(test.wantCalls, func(t *testing.T) {
|
||||||
Config: Config{
|
cfg := Config{
|
||||||
PrivateKey: srvkey,
|
PrivateKey: srvkey,
|
||||||
MaxPeers: 10,
|
MaxPeers: 10,
|
||||||
NoDial: true,
|
NoDial: true,
|
||||||
|
NoDiscovery: true,
|
||||||
Protocols: []Protocol{discard},
|
Protocols: []Protocol{discard},
|
||||||
},
|
Logger: testlog.Logger(t, log.LvlTrace),
|
||||||
|
}
|
||||||
|
srv := &Server{
|
||||||
|
Config: cfg,
|
||||||
newTransport: func(fd net.Conn) transport { return test.tt },
|
newTransport: func(fd net.Conn) transport { return test.tt },
|
||||||
log: log.New(),
|
log: cfg.Logger,
|
||||||
}
|
}
|
||||||
if !test.dontstart {
|
if !test.dontstart {
|
||||||
if err := srv.Start(); err != nil {
|
if err := srv.Start(); err != nil {
|
||||||
t.Fatalf("couldn't start server: %v", err)
|
t.Fatalf("couldn't start server: %v", err)
|
||||||
}
|
}
|
||||||
|
defer srv.Stop()
|
||||||
}
|
}
|
||||||
p1, _ := net.Pipe()
|
p1, _ := net.Pipe()
|
||||||
srv.SetupConn(p1, test.flags, test.dialDest)
|
srv.SetupConn(p1, test.flags, test.dialDest)
|
||||||
@ -564,6 +574,7 @@ func TestServerSetupConn(t *testing.T) {
|
|||||||
if test.tt.calls != test.wantCalls {
|
if test.tt.calls != test.wantCalls {
|
||||||
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
|
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -616,3 +627,100 @@ func randomID() (id enode.ID) {
|
|||||||
}
|
}
|
||||||
return id
|
return id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test checks that inbound connections are throttled by IP.
|
||||||
|
func TestServerInboundThrottle(t *testing.T) {
|
||||||
|
const timeout = 5 * time.Second
|
||||||
|
newTransportCalled := make(chan struct{})
|
||||||
|
srv := &Server{
|
||||||
|
Config: Config{
|
||||||
|
PrivateKey: newkey(),
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
MaxPeers: 10,
|
||||||
|
NoDial: true,
|
||||||
|
NoDiscovery: true,
|
||||||
|
Protocols: []Protocol{discard},
|
||||||
|
Logger: testlog.Logger(t, log.LvlTrace),
|
||||||
|
},
|
||||||
|
newTransport: func(fd net.Conn) transport {
|
||||||
|
newTransportCalled <- struct{}{}
|
||||||
|
return newRLPX(fd)
|
||||||
|
},
|
||||||
|
listenFunc: func(network, laddr string) (net.Listener, error) {
|
||||||
|
fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
|
||||||
|
return listenFakeAddr(network, laddr, fakeAddr)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := srv.Start(); err != nil {
|
||||||
|
t.Fatal("can't start: ", err)
|
||||||
|
}
|
||||||
|
defer srv.Stop()
|
||||||
|
|
||||||
|
// Dial the test server.
|
||||||
|
conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not dial: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-newTransportCalled:
|
||||||
|
// OK
|
||||||
|
case <-time.After(timeout):
|
||||||
|
t.Error("newTransport not called")
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
|
||||||
|
// Dial again. This time the server should close the connection immediately.
|
||||||
|
connClosed := make(chan struct{})
|
||||||
|
conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("could not dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
go func() {
|
||||||
|
conn.SetDeadline(time.Now().Add(timeout))
|
||||||
|
buf := make([]byte, 10)
|
||||||
|
if n, err := conn.Read(buf); err != io.EOF || n != 0 {
|
||||||
|
t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
|
||||||
|
}
|
||||||
|
connClosed <- struct{}{}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-connClosed:
|
||||||
|
// OK
|
||||||
|
case <-newTransportCalled:
|
||||||
|
t.Error("newTransport called for second attempt")
|
||||||
|
case <-time.After(timeout):
|
||||||
|
t.Error("connection not closed within timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
|
||||||
|
l, err := net.Listen(network, laddr)
|
||||||
|
if err == nil {
|
||||||
|
l = &fakeAddrListener{l, remoteAddr}
|
||||||
|
}
|
||||||
|
return l, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeAddrListener is a listener that creates connections with a mocked remote address.
|
||||||
|
type fakeAddrListener struct {
|
||||||
|
net.Listener
|
||||||
|
remoteAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeAddrConn struct {
|
||||||
|
net.Conn
|
||||||
|
remoteAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeAddrListener) Accept() (net.Conn, error) {
|
||||||
|
c, err := l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &fakeAddrConn{c, l.remoteAddr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeAddrConn) RemoteAddr() net.Addr {
|
||||||
|
return c.remoteAddr
|
||||||
|
}
|
||||||
|
82
p2p/util.go
Normal file
82
p2p/util.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
// Copyright 2019 The go-ethereum Authors
|
||||||
|
// This file is part of the go-ethereum library.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Lesser General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Lesser General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Lesser General Public License
|
||||||
|
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/heap"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// expHeap tracks strings and their expiry time.
|
||||||
|
type expHeap []expItem
|
||||||
|
|
||||||
|
// expItem is an entry in addrHistory.
|
||||||
|
type expItem struct {
|
||||||
|
item string
|
||||||
|
exp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextExpiry returns the next expiry time.
|
||||||
|
func (h *expHeap) nextExpiry() time.Time {
|
||||||
|
return (*h)[0].exp
|
||||||
|
}
|
||||||
|
|
||||||
|
// add adds an item and sets its expiry time.
|
||||||
|
func (h *expHeap) add(item string, exp time.Time) {
|
||||||
|
heap.Push(h, expItem{item, exp})
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove removes an item.
|
||||||
|
func (h *expHeap) remove(item string) bool {
|
||||||
|
for i, v := range *h {
|
||||||
|
if v.item == item {
|
||||||
|
heap.Remove(h, i)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// contains checks whether an item is present.
|
||||||
|
func (h expHeap) contains(item string) bool {
|
||||||
|
for _, v := range h {
|
||||||
|
if v.item == item {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// expire removes items with expiry time before 'now'.
|
||||||
|
func (h *expHeap) expire(now time.Time) {
|
||||||
|
for h.Len() > 0 && h.nextExpiry().Before(now) {
|
||||||
|
heap.Pop(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// heap.Interface boilerplate
|
||||||
|
func (h expHeap) Len() int { return len(h) }
|
||||||
|
func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
|
||||||
|
func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||||
|
func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) }
|
||||||
|
func (h *expHeap) Pop() interface{} {
|
||||||
|
old := *h
|
||||||
|
n := len(old)
|
||||||
|
x := old[n-1]
|
||||||
|
*h = old[0 : n-1]
|
||||||
|
return x
|
||||||
|
}
|
54
p2p/util_test.go
Normal file
54
p2p/util_test.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
// Copyright 2019 The go-ethereum Authors
|
||||||
|
// This file is part of the go-ethereum library.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Lesser General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Lesser General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Lesser General Public License
|
||||||
|
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExpHeap(t *testing.T) {
|
||||||
|
var h expHeap
|
||||||
|
|
||||||
|
var (
|
||||||
|
basetime = time.Unix(4000, 0)
|
||||||
|
exptimeA = basetime.Add(2 * time.Second)
|
||||||
|
exptimeB = basetime.Add(3 * time.Second)
|
||||||
|
exptimeC = basetime.Add(4 * time.Second)
|
||||||
|
)
|
||||||
|
h.add("a", exptimeA)
|
||||||
|
h.add("b", exptimeB)
|
||||||
|
h.add("c", exptimeC)
|
||||||
|
|
||||||
|
if !h.nextExpiry().Equal(exptimeA) {
|
||||||
|
t.Fatal("wrong nextExpiry")
|
||||||
|
}
|
||||||
|
if !h.contains("a") || !h.contains("b") || !h.contains("c") {
|
||||||
|
t.Fatal("heap doesn't contain all live items")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.expire(exptimeA.Add(1))
|
||||||
|
if !h.nextExpiry().Equal(exptimeB) {
|
||||||
|
t.Fatal("wrong nextExpiry")
|
||||||
|
}
|
||||||
|
if h.contains("a") {
|
||||||
|
t.Fatal("heap contains a even though it has already expired")
|
||||||
|
}
|
||||||
|
if !h.contains("b") || !h.contains("c") {
|
||||||
|
t.Fatal("heap doesn't contain all live items")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user