p2p: enforce connection retry limit on server side (#19684)

The dialer limits itself to one attempt every 30s. Apply the same limit
in Server and reject peers which try to connect too eagerly. The check
against the limit happens right after accepting the connection.

Further changes in this commit ensure we pass the Server logger
down to Peer instances, discovery, and dialState. Unit test logging
now works in all Server tests.
Felix Lange 2019-06-11 12:45:33 +02:00 committed by GitHub
parent c0a034ec89
commit c420dcb39c
9 changed files with 518 additions and 292 deletions
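
The mechanism in a nutshell: the listener keeps a small expiring history of remote IPs and refuses a second attempt from the same non-local IP within 30 seconds. Below is a minimal standalone sketch of that check; the constant name and error text follow the diff further down, while the map-based history, the checkInbound helper, and the loopback shortcut are simplifications of the real expHeap / netutil.IsLAN code added by this commit.

package main

import (
	"fmt"
	"net"
	"time"
)

// inboundThrottleTime mirrors the constant this commit adds to the server:
// at most one inbound attempt per remote IP every 30 seconds.
const inboundThrottleTime = 30 * time.Second

// history is a simplified stand-in for the expHeap added in p2p/util.go.
// It maps a remote IP to the time its throttle entry expires.
type history map[string]time.Time

// expire drops entries whose expiry time lies before now.
func (h history) expire(now time.Time) {
	for ip, exp := range h {
		if exp.Before(now) {
			delete(h, ip)
		}
	}
}

// checkInbound follows the shape of Server.checkInboundConn: loopback
// addresses are never throttled (the real code uses netutil.IsLAN), and a
// non-local IP seen within the last inboundThrottleTime is rejected.
func checkInbound(h history, remoteIP net.IP, now time.Time) error {
	if remoteIP == nil || remoteIP.IsLoopback() {
		return nil
	}
	h.expire(now)
	if _, recent := h[remoteIP.String()]; recent {
		return fmt.Errorf("too many attempts")
	}
	h[remoteIP.String()] = now.Add(inboundThrottleTime)
	return nil
}

func main() {
	h := make(history)
	ip := net.ParseIP("95.33.21.2")
	now := time.Now()
	fmt.Println(checkInbound(h, ip, now))                                        // <nil>: first attempt accepted
	fmt.Println(checkInbound(h, ip, now.Add(5*time.Second)))                     // too many attempts
	fmt.Println(checkInbound(h, ip, now.Add(inboundThrottleTime+2*time.Second))) // <nil>: entry expired
}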

p2p/dial.go

@ -17,7 +17,6 @@
package p2p package p2p
import ( import (
"container/heap"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -29,9 +28,10 @@ import (
) )
const ( const (
// This is the amount of time spent waiting in between // This is the amount of time spent waiting in between redialing a certain node. The
// redialing a certain node. // limit is a bit higher than inboundThrottleTime to prevent failing dials in small
dialHistoryExpiration = 30 * time.Second // private networks.
dialHistoryExpiration = inboundThrottleTime + 5*time.Second
// Discovery lookups are throttled and can only run // Discovery lookups are throttled and can only run
// once every few seconds. // once every few seconds.
@ -72,16 +72,16 @@ type dialstate struct {
ntab discoverTable ntab discoverTable
netrestrict *netutil.Netlist netrestrict *netutil.Netlist
self enode.ID self enode.ID
bootnodes []*enode.Node // default dials when there are no peers
log log.Logger
start time.Time // time when the dialer was first used
lookupRunning bool lookupRunning bool
dialing map[enode.ID]connFlag dialing map[enode.ID]connFlag
lookupBuf []*enode.Node // current discovery lookup results lookupBuf []*enode.Node // current discovery lookup results
randomNodes []*enode.Node // filled from Table randomNodes []*enode.Node // filled from Table
static map[enode.ID]*dialTask static map[enode.ID]*dialTask
hist *dialHistory hist expHeap
start time.Time // time when the dialer was first used
bootnodes []*enode.Node // default dials when there are no peers
} }
type discoverTable interface { type discoverTable interface {
@ -91,15 +91,6 @@ type discoverTable interface {
ReadRandomNodes([]*enode.Node) int ReadRandomNodes([]*enode.Node) int
} }
// the dial history remembers recent dials.
type dialHistory []pastDial
// pastDial is an entry in the dial history.
type pastDial struct {
id enode.ID
exp time.Time
}
type task interface { type task interface {
Do(*Server) Do(*Server)
} }
@ -126,20 +117,23 @@ type waitExpireTask struct {
time.Duration time.Duration
} }
func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
s := &dialstate{ s := &dialstate{
maxDynDials: maxdyn, maxDynDials: maxdyn,
ntab: ntab, ntab: ntab,
self: self, self: self,
netrestrict: netrestrict, netrestrict: cfg.NetRestrict,
log: cfg.Logger,
static: make(map[enode.ID]*dialTask), static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag), dialing: make(map[enode.ID]connFlag),
bootnodes: make([]*enode.Node, len(bootnodes)), bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
randomNodes: make([]*enode.Node, maxdyn/2), randomNodes: make([]*enode.Node, maxdyn/2),
hist: new(dialHistory),
} }
copy(s.bootnodes, bootnodes) copy(s.bootnodes, cfg.BootstrapNodes)
for _, n := range static { if s.log == nil {
s.log = log.Root()
}
for _, n := range cfg.StaticNodes {
s.addStatic(n) s.addStatic(n)
} }
return s return s
@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) {
func (s *dialstate) removeStatic(n *enode.Node) { func (s *dialstate) removeStatic(n *enode.Node) {
// This removes a task so future attempts to connect will not be made. // This removes a task so future attempts to connect will not be made.
delete(s.static, n.ID()) delete(s.static, n.ID())
// This removes a previous dial timestamp so that application
// can force a server to reconnect with chosen peer immediately.
s.hist.remove(n.ID())
} }
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
var newtasks []task var newtasks []task
addDial := func(flag connFlag, n *enode.Node) bool { addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil { if err := s.checkDial(n, peers); err != nil {
log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
return false return false
} }
s.dialing[n.ID()] = flag s.dialing[n.ID()] = flag
@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
err := s.checkDial(t.dest, peers) err := s.checkDial(t.dest, peers)
switch err { switch err {
case errNotWhitelisted, errSelf: case errNotWhitelisted, errSelf:
log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
delete(s.static, t.dest.ID()) delete(s.static, t.dest.ID())
case nil: case nil:
s.dialing[id] = t.flags s.dialing[id] = t.flags
@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
// This should prevent cases where the dialer logic is not ticked // This should prevent cases where the dialer logic is not ticked
// because there are no pending events. // because there are no pending events.
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
t := &waitExpireTask{s.hist.min().exp.Sub(now)} t := &waitExpireTask{s.hist.nextExpiry().Sub(now)}
newtasks = append(newtasks, t) newtasks = append(newtasks, t)
} }
return newtasks return newtasks
@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
return errSelf return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted return errNotWhitelisted
case s.hist.contains(n.ID()): case s.hist.contains(string(n.ID().Bytes())):
return errRecentlyDialed return errRecentlyDialed
} }
return nil return nil
@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
func (s *dialstate) taskDone(t task, now time.Time) { func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) { switch t := t.(type) {
case *dialTask: case *dialTask:
s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration)) s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration))
delete(s.dialing, t.dest.ID()) delete(s.dialing, t.dest.ID())
case *discoverTask: case *discoverTask:
s.lookupRunning = false s.lookupRunning = false
@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) {
} }
err := t.dial(srv, t.dest) err := t.dial(srv, t.dest)
if err != nil { if err != nil {
log.Trace("Dial error", "task", t, "err", err) srv.log.Trace("Dial error", "task", t, "err", err)
// Try resolving the ID of static nodes if dialing failed. // Try resolving the ID of static nodes if dialing failed.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
if t.resolve(srv) { if t.resolve(srv) {
@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) {
// The backoff delay resets when the node is found. // The backoff delay resets when the node is found.
func (t *dialTask) resolve(srv *Server) bool { func (t *dialTask) resolve(srv *Server) bool {
if srv.ntab == nil { if srv.ntab == nil {
log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
return false return false
} }
if t.resolveDelay == 0 { if t.resolveDelay == 0 {
@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool {
if t.resolveDelay > maxResolveDelay { if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay t.resolveDelay = maxResolveDelay
} }
log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
return false return false
} }
// The node was found. // The node was found.
t.resolveDelay = initialResolveDelay t.resolveDelay = initialResolveDelay
t.dest = resolved t.dest = resolved
log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
return true return true
} }
@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) {
func (t waitExpireTask) String() string { func (t waitExpireTask) String() string {
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
} }
// Use only these methods to access or modify dialHistory.
func (h dialHistory) min() pastDial {
return h[0]
}
func (h *dialHistory) add(id enode.ID, exp time.Time) {
heap.Push(h, pastDial{id, exp})
}
func (h *dialHistory) remove(id enode.ID) bool {
for i, v := range *h {
if v.id == id {
heap.Remove(h, i)
return true
}
}
return false
}
func (h dialHistory) contains(id enode.ID) bool {
for _, v := range h {
if v.id == id {
return true
}
}
return false
}
func (h *dialHistory) expire(now time.Time) {
for h.Len() > 0 && h.min().exp.Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h dialHistory) Len() int { return len(h) }
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *dialHistory) Push(x interface{}) {
*h = append(*h, x.(pastDial))
}
func (h *dialHistory) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

p2p/dial_test.go

@ -20,10 +20,13 @@ import (
"encoding/binary" "encoding/binary"
"net" "net"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/p2p/netutil"
@ -67,10 +70,10 @@ func runDialTest(t *testing.T, test dialtest) {
new := test.init.newTasks(running, pm(round.peers), vtime) new := test.init.newTasks(running, pm(round.peers), vtime)
if !sametasks(new, round.new) { if !sametasks(new, round.new) {
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n", t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
} }
t.Log("tasks:", spew.Sdump(new)) t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new)))
// Time advances by 16 seconds on every round. // Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second) vtime = vtime.Add(16 * time.Second)
@ -88,8 +91,9 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
// This test checks that dynamic dials are launched from discovery results. // This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) { func TestDialStateDynDial(t *testing.T) {
config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil), init: newDialState(enode.ID{}, fakeTable{}, 5, config),
rounds: []round{ rounds: []round{
// A discovery query is launched. // A discovery query is launched.
{ {
@ -153,7 +157,7 @@ func TestDialStateDynDial(t *testing.T) {
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
new: []task{ new: []task{
&waitExpireTask{Duration: 14 * time.Second}, &waitExpireTask{Duration: 19 * time.Second},
}, },
}, },
// In this round, the peer with id 2 drops off. The query // In this round, the peer with id 2 drops off. The query
@ -223,10 +227,13 @@ func TestDialStateDynDial(t *testing.T) {
// Tests that bootnodes are dialed if no peers are connectd, but not otherwise. // Tests that bootnodes are dialed if no peers are connectd, but not otherwise.
func TestDialStateDynDialBootnode(t *testing.T) { func TestDialStateDynDialBootnode(t *testing.T) {
bootnodes := []*enode.Node{ config := &Config{
newNode(uintID(1), nil), BootstrapNodes: []*enode.Node{
newNode(uintID(2), nil), newNode(uintID(1), nil),
newNode(uintID(3), nil), newNode(uintID(2), nil),
newNode(uintID(3), nil),
},
Logger: testlog.Logger(t, log.LvlTrace),
} }
table := fakeTable{ table := fakeTable{
newNode(uintID(4), nil), newNode(uintID(4), nil),
@ -236,7 +243,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
newNode(uintID(8), nil), newNode(uintID(8), nil),
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil), init: newDialState(enode.ID{}, table, 5, config),
rounds: []round{ rounds: []round{
// 2 dynamic dials attempted, bootnodes pending fallback interval // 2 dynamic dials attempted, bootnodes pending fallback interval
{ {
@ -259,25 +266,24 @@ func TestDialStateDynDialBootnode(t *testing.T) {
{ {
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
}, },
// No dials succeed, 2nd bootnode is attempted // No dials succeed, 2nd bootnode is attempted
{ {
done: []task{ done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
}, },
// No dials succeed, 3rd bootnode is attempted // No dials succeed, 3rd bootnode is attempted
{ {
done: []task{ done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
@ -288,21 +294,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
done: []task{ done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
}, },
new: []task{ new: []task{},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
},
}, },
// Random dial succeeds, no more bootnodes are attempted // Random dial succeeds, no more bootnodes are attempted
{ {
new: []task{
&waitExpireTask{3 * time.Second},
},
peers: []*Peer{ peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
}, },
done: []task{ done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
}, },
}, },
@ -324,7 +328,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, nil, table, 10, nil), init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
rounds: []round{ rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{ {
@ -430,7 +434,7 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24") restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, nil, nil, table, 10, restrict), init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}),
rounds: []round{ rounds: []round{
{ {
new: []task{ new: []task{
@ -444,16 +448,18 @@ func TestDialStateNetRestrict(t *testing.T) {
// This test checks that static dials are launched. // This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) { func TestDialStateStaticDial(t *testing.T) {
wantStatic := []*enode.Node{ config := &Config{
newNode(uintID(1), nil), StaticNodes: []*enode.Node{
newNode(uintID(2), nil), newNode(uintID(1), nil),
newNode(uintID(3), nil), newNode(uintID(2), nil),
newNode(uintID(4), nil), newNode(uintID(3), nil),
newNode(uintID(5), nil), newNode(uintID(4), nil),
newNode(uintID(5), nil),
},
Logger: testlog.Logger(t, log.LvlTrace),
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), init: newDialState(enode.ID{}, fakeTable{}, 0, config),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -495,7 +501,7 @@ func TestDialStateStaticDial(t *testing.T) {
&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
}, },
new: []task{ new: []task{
&waitExpireTask{Duration: 14 * time.Second}, &waitExpireTask{Duration: 19 * time.Second},
}, },
}, },
// Wait a round for dial history to expire, no new tasks should spawn. // Wait a round for dial history to expire, no new tasks should spawn.
@ -511,6 +517,9 @@ func TestDialStateStaticDial(t *testing.T) {
// If a static node is dropped, it should be immediately redialed, // If a static node is dropped, it should be immediately redialed,
// irrespective whether it was originally static or dynamic. // irrespective whether it was originally static or dynamic.
{ {
done: []task{
&waitExpireTask{Duration: 19 * time.Second},
},
peers: []*Peer{ peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
@ -518,67 +527,24 @@ func TestDialStateStaticDial(t *testing.T) {
}, },
new: []task{ new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
}, },
}, },
}, },
}) })
} }
// This test checks that static peers will be redialed immediately if they were re-added to a static list.
func TestDialStaticAfterReset(t *testing.T) {
wantStatic := []*enode.Node{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
}
rounds := []round{
// Static dials are launched for the nodes that aren't yet connected.
{
peers: nil,
new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
},
},
// No new dial tasks, all peers are connected.
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
},
new: []task{
&waitExpireTask{Duration: 30 * time.Second},
},
},
}
dTest := dialtest{
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: rounds,
}
runDialTest(t, dTest)
for _, n := range wantStatic {
dTest.init.removeStatic(n)
dTest.init.addStatic(n)
}
// without removing peers they will be considered recently dialed
runDialTest(t, dTest)
}
// This test checks that past dials are not retried for some time. // This test checks that past dials are not retried for some time.
func TestDialStateCache(t *testing.T) { func TestDialStateCache(t *testing.T) {
wantStatic := []*enode.Node{ config := &Config{
newNode(uintID(1), nil), StaticNodes: []*enode.Node{
newNode(uintID(2), nil), newNode(uintID(1), nil),
newNode(uintID(3), nil), newNode(uintID(2), nil),
newNode(uintID(3), nil),
},
Logger: testlog.Logger(t, log.LvlTrace),
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil), init: newDialState(enode.ID{}, fakeTable{}, 0, config),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -606,28 +572,37 @@ func TestDialStateCache(t *testing.T) {
// entry to expire. // entry to expire.
{ {
peers: []*Peer{ peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
}, },
done: []task{ done: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
}, },
new: []task{ new: []task{
&waitExpireTask{Duration: 14 * time.Second}, &waitExpireTask{Duration: 19 * time.Second},
}, },
}, },
// Still waiting for node 3's entry to expire in the cache. // Still waiting for node 3's entry to expire in the cache.
{ {
peers: []*Peer{ peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
},
},
{
peers: []*Peer{
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
}, },
}, },
// The cache entry for node 3 has expired and is retried. // The cache entry for node 3 has expired and is retried.
{ {
done: []task{
&waitExpireTask{Duration: 19 * time.Second},
},
peers: []*Peer{ peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
}, },
new: []task{ new: []task{
&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
@ -638,9 +613,13 @@ func TestDialStateCache(t *testing.T) {
} }
func TestDialResolve(t *testing.T) { func TestDialResolve(t *testing.T) {
config := &Config{
Logger: testlog.Logger(t, log.LvlTrace),
Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
}
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved} table := &resolveMock{answer: resolved}
state := newDialState(enode.ID{}, nil, nil, table, 0, nil) state := newDialState(enode.ID{}, table, 0, config)
// Check that the task is generated with an incomplete ID. // Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil) dest := newNode(uintID(1), nil)
@ -651,8 +630,7 @@ func TestDialResolve(t *testing.T) {
} }
// Now run the task, it should resolve the ID once. // Now run the task, it should resolve the ID once.
config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}} srv := &Server{ntab: table, log: config.Logger, Config: *config}
srv := &Server{ntab: table, Config: config}
tasks[0].Do(srv) tasks[0].Do(srv)
if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) {
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) t.Fatalf("wrong resolve calls, got %v", table.resolveCalls)

p2p/netutil/addrutil.go (new file, 33 lines)

@ -0,0 +1,33 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package netutil
import "net"
// AddrIP gets the IP address contained in addr. It returns nil if no address is present.
func AddrIP(addr net.Addr) net.IP {
switch a := addr.(type) {
case *net.IPAddr:
return a.IP
case *net.TCPAddr:
return a.IP
case *net.UDPAddr:
return a.IP
default:
return nil
}
}
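
The helper simply unwraps the IP from whichever net.Addr implementation the listener produced, which is what the server's accept loop feeds into the throttle check. A standalone usage sketch (assuming the github.com/ethereum/go-ethereum module is on the import path):

package main

import (
	"fmt"
	"net"

	"github.com/ethereum/go-ethereum/p2p/netutil"
)

func main() {
	tcp := &net.TCPAddr{IP: net.IPv4(95, 33, 21, 2), Port: 4444}
	udp := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 30303}
	unix := &net.UnixAddr{Name: "/tmp/geth.ipc", Net: "unix"}

	fmt.Println(netutil.AddrIP(tcp))  // 95.33.21.2
	fmt.Println(netutil.AddrIP(udp))  // 10.0.0.1
	fmt.Println(netutil.AddrIP(unix)) // <nil>: not an IP-based address
}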

p2p/peer.go

@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe() pipe, _ := net.Pipe()
node := enode.SignNull(new(enr.Record), id) node := enode.SignNull(new(enr.Record), id)
conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name} conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
peer := newPeer(conn, nil) peer := newPeer(log.Root(), conn, nil)
close(peer.closed) // ensures Disconnect doesn't block close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool {
return p.rw.is(inboundConn) return p.rw.is(inboundConn)
} }
func newPeer(conn *conn, protocols []Protocol) *Peer { func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer {
protomap := matchProtocols(protocols, conn.caps, conn) protomap := matchProtocols(protocols, conn.caps, conn)
p := &Peer{ p := &Peer{
rw: conn, rw: conn,

p2p/peer_test.go

@ -24,6 +24,8 @@ import (
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/log"
) )
var discard = Protocol{ var discard = Protocol{
@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
c2.caps = append(c2.caps, p.cap()) c2.caps = append(c2.caps, p.cap())
} }
peer := newPeer(c1, protos) peer := newPeer(log.Root(), c1, protos)
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
_, err := peer.run() _, err := peer.run()

p2p/server.go

@ -22,6 +22,7 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"net" "net"
"sort" "sort"
"sync" "sync"
@ -49,6 +50,9 @@ const (
defaultMaxPendingPeers = 50 defaultMaxPendingPeers = 50
defaultDialRatio = 3 defaultDialRatio = 3
// This time limits inbound connection attempts per source IP.
inboundThrottleTime = 30 * time.Second
// Maximum time allowed for reading a complete message. // Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle. // This is effectively the amount of time a connection can be idle.
frameReadTimeout = 30 * time.Second frameReadTimeout = 30 * time.Second
@ -158,6 +162,7 @@ type Server struct {
// the whole protocol stack. // the whole protocol stack.
newTransport func(net.Conn) transport newTransport func(net.Conn) transport
newPeerHook func(*Peer) newPeerHook func(*Peer)
listenFunc func(network, addr string) (net.Listener, error)
lock sync.Mutex // protects running lock sync.Mutex // protects running
running bool running bool
@ -167,24 +172,26 @@ type Server struct {
ntab discoverTable ntab discoverTable
listener net.Listener listener net.Listener
ourHandshake *protoHandshake ourHandshake *protoHandshake
lastLookup time.Time
DiscV5 *discv5.Network DiscV5 *discv5.Network
loopWG sync.WaitGroup // loop, listenLoop
peerFeed event.Feed
log log.Logger
// These are for Peers, PeerCount (and nothing else). // Channels into the run loop.
peerOp chan peerOpFunc quit chan struct{}
peerOpDone chan struct{} addstatic chan *enode.Node
removestatic chan *enode.Node
addtrusted chan *enode.Node
removetrusted chan *enode.Node
peerOp chan peerOpFunc
peerOpDone chan struct{}
delpeer chan peerDrop
checkpointPostHandshake chan *conn
checkpointAddPeer chan *conn
quit chan struct{} // State of run loop and listenLoop.
addstatic chan *enode.Node lastLookup time.Time
removestatic chan *enode.Node inboundHistory expHeap
addtrusted chan *enode.Node
removetrusted chan *enode.Node
posthandshake chan *conn
addpeer chan *conn
delpeer chan peerDrop
loopWG sync.WaitGroup // loop, listenLoop
peerFeed event.Feed
log log.Logger
} }
type peerOpFunc func(map[enode.ID]*Peer) type peerOpFunc func(map[enode.ID]*Peer)
@ -415,7 +422,7 @@ func (srv *Server) Start() (err error) {
srv.running = true srv.running = true
srv.log = srv.Config.Logger srv.log = srv.Config.Logger
if srv.log == nil { if srv.log == nil {
srv.log = log.New() srv.log = log.Root()
} }
if srv.NoDial && srv.ListenAddr == "" { if srv.NoDial && srv.ListenAddr == "" {
srv.log.Warn("P2P server will be useless, neither dialing nor listening") srv.log.Warn("P2P server will be useless, neither dialing nor listening")
@ -428,13 +435,16 @@ func (srv *Server) Start() (err error) {
if srv.newTransport == nil { if srv.newTransport == nil {
srv.newTransport = newRLPX srv.newTransport = newRLPX
} }
if srv.listenFunc == nil {
srv.listenFunc = net.Listen
}
if srv.Dialer == nil { if srv.Dialer == nil {
srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}} srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
} }
srv.quit = make(chan struct{}) srv.quit = make(chan struct{})
srv.addpeer = make(chan *conn)
srv.delpeer = make(chan peerDrop) srv.delpeer = make(chan peerDrop)
srv.posthandshake = make(chan *conn) srv.checkpointPostHandshake = make(chan *conn)
srv.checkpointAddPeer = make(chan *conn)
srv.addstatic = make(chan *enode.Node) srv.addstatic = make(chan *enode.Node)
srv.removestatic = make(chan *enode.Node) srv.removestatic = make(chan *enode.Node)
srv.addtrusted = make(chan *enode.Node) srv.addtrusted = make(chan *enode.Node)
@ -455,7 +465,7 @@ func (srv *Server) Start() (err error) {
} }
dynPeers := srv.maxDialedConns() dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config)
srv.loopWG.Add(1) srv.loopWG.Add(1)
go srv.run(dialer) go srv.run(dialer)
return nil return nil
@ -541,6 +551,7 @@ func (srv *Server) setupDiscovery() error {
NetRestrict: srv.NetRestrict, NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodes, Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled, Unhandled: unhandled,
Log: srv.log,
} }
ntab, err := discover.ListenUDP(conn, srv.localnode, cfg) ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
if err != nil { if err != nil {
@ -569,27 +580,28 @@ func (srv *Server) setupDiscovery() error {
} }
func (srv *Server) setupListening() error { func (srv *Server) setupListening() error {
// Launch the TCP listener. // Launch the listener.
listener, err := net.Listen("tcp", srv.ListenAddr) listener, err := srv.listenFunc("tcp", srv.ListenAddr)
if err != nil { if err != nil {
return err return err
} }
laddr := listener.Addr().(*net.TCPAddr)
srv.ListenAddr = laddr.String()
srv.listener = listener srv.listener = listener
srv.localnode.Set(enr.TCP(laddr.Port)) srv.ListenAddr = listener.Addr().String()
// Update the local node record and map the TCP listening port if NAT is configured.
if tcp, ok := listener.Addr().(*net.TCPAddr); ok {
srv.localnode.Set(enr.TCP(tcp.Port))
if !tcp.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1)
go func() {
nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p")
srv.loopWG.Done()
}()
}
}
srv.loopWG.Add(1) srv.loopWG.Add(1)
go srv.listenLoop() go srv.listenLoop()
// Map the TCP listening port if NAT is configured.
if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1)
go func() {
nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
srv.loopWG.Done()
}()
}
return nil return nil
} }
@ -657,12 +669,14 @@ running:
case <-srv.quit: case <-srv.quit:
// The server was stopped. Run the cleanup logic. // The server was stopped. Run the cleanup logic.
break running break running
case n := <-srv.addstatic: case n := <-srv.addstatic:
// This channel is used by AddPeer to add to the // This channel is used by AddPeer to add to the
// ephemeral static peer list. Add it to the dialer, // ephemeral static peer list. Add it to the dialer,
// it will keep the node connected. // it will keep the node connected.
srv.log.Trace("Adding static node", "node", n) srv.log.Trace("Adding static node", "node", n)
dialstate.addStatic(n) dialstate.addStatic(n)
case n := <-srv.removestatic: case n := <-srv.removestatic:
// This channel is used by RemovePeer to send a // This channel is used by RemovePeer to send a
// disconnect request to a peer and begin the // disconnect request to a peer and begin the
@ -672,6 +686,7 @@ running:
if p, ok := peers[n.ID()]; ok { if p, ok := peers[n.ID()]; ok {
p.Disconnect(DiscRequested) p.Disconnect(DiscRequested)
} }
case n := <-srv.addtrusted: case n := <-srv.addtrusted:
// This channel is used by AddTrustedPeer to add an enode // This channel is used by AddTrustedPeer to add an enode
// to the trusted node set. // to the trusted node set.
@ -681,6 +696,7 @@ running:
if p, ok := peers[n.ID()]; ok { if p, ok := peers[n.ID()]; ok {
p.rw.set(trustedConn, true) p.rw.set(trustedConn, true)
} }
case n := <-srv.removetrusted: case n := <-srv.removetrusted:
// This channel is used by RemoveTrustedPeer to remove an enode // This channel is used by RemoveTrustedPeer to remove an enode
// from the trusted node set. // from the trusted node set.
@ -691,10 +707,12 @@ running:
if p, ok := peers[n.ID()]; ok { if p, ok := peers[n.ID()]; ok {
p.rw.set(trustedConn, false) p.rw.set(trustedConn, false)
} }
case op := <-srv.peerOp: case op := <-srv.peerOp:
// This channel is used by Peers and PeerCount. // This channel is used by Peers and PeerCount.
op(peers) op(peers)
srv.peerOpDone <- struct{}{} srv.peerOpDone <- struct{}{}
case t := <-taskdone: case t := <-taskdone:
// A task got done. Tell dialstate about it so it // A task got done. Tell dialstate about it so it
// can update its state and remove it from the active // can update its state and remove it from the active
@ -702,7 +720,8 @@ running:
srv.log.Trace("Dial task done", "task", t) srv.log.Trace("Dial task done", "task", t)
dialstate.taskDone(t, time.Now()) dialstate.taskDone(t, time.Now())
delTask(t) delTask(t)
case c := <-srv.posthandshake:
case c := <-srv.checkpointPostHandshake:
// A connection has passed the encryption handshake so // A connection has passed the encryption handshake so
// the remote identity is known (but hasn't been verified yet). // the remote identity is known (but hasn't been verified yet).
if trusted[c.node.ID()] { if trusted[c.node.ID()] {
@ -710,18 +729,15 @@ running:
c.flags |= trustedConn c.flags |= trustedConn
} }
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
select { c.cont <- srv.postHandshakeChecks(peers, inboundCount, c)
case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
case <-srv.quit: case c := <-srv.checkpointAddPeer:
break running
}
case c := <-srv.addpeer:
// At this point the connection is past the protocol handshake. // At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified. // Its capabilities are known and the remote identity is verified.
err := srv.protoHandshakeChecks(peers, inboundCount, c) err := srv.addPeerChecks(peers, inboundCount, c)
if err == nil { if err == nil {
// The handshakes are done and it passed all checks. // The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols) p := newPeer(srv.log, c, srv.Protocols)
// If message events are enabled, pass the peerFeed // If message events are enabled, pass the peerFeed
// to the peer // to the peer
if srv.EnableMsgEvents { if srv.EnableMsgEvents {
@ -738,11 +754,8 @@ running:
// The dialer logic relies on the assumption that // The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or // dial tasks complete after the peer has been added or
// discarded. Unblock the task last. // discarded. Unblock the task last.
select { c.cont <- err
case c.cont <- err:
case <-srv.quit:
break running
}
case pd := <-srv.delpeer: case pd := <-srv.delpeer:
// A peer disconnected. // A peer disconnected.
d := common.PrettyDuration(mclock.Now() - pd.created) d := common.PrettyDuration(mclock.Now() - pd.created)
@ -777,17 +790,7 @@ running:
} }
} }
func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
}
// Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes.
return srv.encHandshakeChecks(peers, inboundCount, c)
}
func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
switch { switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers return DiscTooManyPeers
@ -802,9 +805,20 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
} }
} }
func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
}
// Repeat the post-handshake checks because the
// peer set might have changed since those checks were performed.
return srv.postHandshakeChecks(peers, inboundCount, c)
}
func (srv *Server) maxInboundConns() int { func (srv *Server) maxInboundConns() int {
return srv.MaxPeers - srv.maxDialedConns() return srv.MaxPeers - srv.maxDialedConns()
} }
func (srv *Server) maxDialedConns() int { func (srv *Server) maxDialedConns() int {
if srv.NoDiscovery || srv.NoDial { if srv.NoDiscovery || srv.NoDial {
return 0 return 0
@ -832,7 +846,7 @@ func (srv *Server) listenLoop() {
} }
for { for {
// Wait for a handshake slot before accepting. // Wait for a free slot before accepting.
<-slots <-slots
var ( var (
@ -851,21 +865,16 @@ func (srv *Server) listenLoop() {
break break
} }
// Reject connections that do not match NetRestrict. remoteIP := netutil.AddrIP(fd.RemoteAddr())
if srv.NetRestrict != nil { if err := srv.checkInboundConn(fd, remoteIP); err != nil {
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { srv.log.Debug("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err)
srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr()) fd.Close()
fd.Close() slots <- struct{}{}
slots <- struct{}{} continue
continue
}
} }
if remoteIP != nil {
var ip net.IP fd = newMeteredConn(fd, true, remoteIP)
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok {
ip = tcp.IP
} }
fd = newMeteredConn(fd, true, ip)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
go func() { go func() {
srv.SetupConn(fd, inboundConn, nil) srv.SetupConn(fd, inboundConn, nil)
@ -874,6 +883,22 @@ func (srv *Server) listenLoop() {
} }
} }
func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error {
if remoteIP != nil {
// Reject connections that do not match NetRestrict.
if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
return fmt.Errorf("not whitelisted in NetRestrict")
}
// Reject Internet peers that try too often.
srv.inboundHistory.expire(time.Now())
if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
return fmt.Errorf("too many attempts")
}
srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime))
}
return nil
}
// SetupConn runs the handshakes and attempts to add the connection // SetupConn runs the handshakes and attempts to add the connection
// as a peer. It returns when the connection has been added as a peer // as a peer. It returns when the connection has been added as a peer
// or the handshakes have failed. // or the handshakes have failed.
@ -895,6 +920,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
if !running { if !running {
return errServerStopped return errServerStopped
} }
// If dialing, figure out the remote public key. // If dialing, figure out the remote public key.
var dialPubkey *ecdsa.PublicKey var dialPubkey *ecdsa.PublicKey
if dialDest != nil { if dialDest != nil {
@ -903,7 +929,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
return errors.New("dial destination doesn't have a secp256k1 public key") return errors.New("dial destination doesn't have a secp256k1 public key")
} }
} }
// Run the encryption handshake.
// Run the RLPx handshake.
remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey) remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
if err != nil { if err != nil {
srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
@ -922,12 +949,13 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
conn.handshakeDone(c.node.ID()) conn.handshakeDone(c.node.ID())
} }
clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags) clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
err = srv.checkpoint(c, srv.posthandshake) err = srv.checkpoint(c, srv.checkpointPostHandshake)
if err != nil { if err != nil {
clog.Trace("Rejected peer before protocol handshake", "err", err) clog.Trace("Rejected peer", "err", err)
return err return err
} }
// Run the protocol handshake
// Run the capability negotiation handshake.
phs, err := c.doProtoHandshake(srv.ourHandshake) phs, err := c.doProtoHandshake(srv.ourHandshake)
if err != nil { if err != nil {
clog.Trace("Failed proto handshake", "err", err) clog.Trace("Failed proto handshake", "err", err)
@ -938,14 +966,15 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
return DiscUnexpectedIdentity return DiscUnexpectedIdentity
} }
c.caps, c.name = phs.Caps, phs.Name c.caps, c.name = phs.Caps, phs.Name
err = srv.checkpoint(c, srv.addpeer) err = srv.checkpoint(c, srv.checkpointAddPeer)
if err != nil { if err != nil {
clog.Trace("Rejected peer", "err", err) clog.Trace("Rejected peer", "err", err)
return err return err
} }
// If the checks completed successfully, runPeer has now been
// launched by run. // If the checks completed successfully, the connection has been added as a peer and
clog.Trace("connection set up", "inbound", dialDest == nil) // runPeer has been launched.
clog.Trace("Connection set up", "inbound", dialDest == nil)
return nil return nil
} }
@ -974,12 +1003,7 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
case <-srv.quit: case <-srv.quit:
return errServerStopped return errServerStopped
} }
select { return <-c.cont
case err := <-c.cont:
return err
case <-srv.quit:
return errServerStopped
}
} }
// runPeer runs in its own goroutine for each peer. // runPeer runs in its own goroutine for each peer.

p2p/server_test.go

@ -19,6 +19,7 @@ package p2p
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"errors" "errors"
"io"
"math/rand" "math/rand"
"net" "net"
"reflect" "reflect"
@ -26,6 +27,7 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(), PrivateKey: newkey(),
Logger: testlog.Logger(t, log.LvlTrace),
} }
server := &Server{ server := &Server{
Config: config, Config: config,
@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) {
PrivateKey: newkey(), PrivateKey: newkey(),
MaxPeers: 10, MaxPeers: 10,
NoDial: true, NoDial: true,
NoDiscovery: true,
TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
}, },
} }
@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) {
// Inject a few connections to fill up the peer set. // Inject a few connections to fill up the peer set.
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
c := newconn(randomID()) c := newconn(randomID())
if err := srv.checkpoint(c, srv.addpeer); err != nil { if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
t.Fatalf("could not add conn %d: %v", i, err) t.Fatalf("could not add conn %d: %v", i, err)
} }
} }
// Try inserting a non-trusted connection. // Try inserting a non-trusted connection.
anotherID := randomID() anotherID := randomID()
c := newconn(anotherID) c := newconn(anotherID)
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err) t.Error("wrong error for insert:", err)
} }
// Try inserting a trusted connection. // Try inserting a trusted connection.
c = newconn(trustedID) c = newconn(trustedID)
if err := srv.checkpoint(c, srv.posthandshake); err != nil { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err) t.Error("unexpected error for trusted conn @posthandshake:", err)
} }
if !c.is(trustedConn) { if !c.is(trustedConn) {
@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) {
// Remove from trusted set and try again // Remove from trusted set and try again
srv.RemoveTrustedPeer(newNode(trustedID, nil)) srv.RemoveTrustedPeer(newNode(trustedID, nil))
c = newconn(trustedID) c = newconn(trustedID)
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err) t.Error("wrong error for insert:", err)
} }
// Add anotherID to trusted set and try again // Add anotherID to trusted set and try again
srv.AddTrustedPeer(newNode(anotherID, nil)) srv.AddTrustedPeer(newNode(anotherID, nil))
c = newconn(anotherID) c = newconn(anotherID)
if err := srv.checkpoint(c, srv.posthandshake); err != nil { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err) t.Error("unexpected error for trusted conn @posthandshake:", err)
} }
if !c.is(trustedConn) { if !c.is(trustedConn) {
@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) {
srv := &Server{ srv := &Server{
Config: Config{ Config: Config{
PrivateKey: srvkey, PrivateKey: srvkey,
MaxPeers: 0, MaxPeers: 0,
NoDial: true, NoDial: true,
Protocols: []Protocol{discard}, NoDiscovery: true,
Protocols: []Protocol{discard},
}, },
newTransport: func(fd net.Conn) transport { return tp }, newTransport: func(fd net.Conn) transport { return tp },
log: log.New(), log: log.New(),
@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
srv := &Server{ t.Run(test.wantCalls, func(t *testing.T) {
Config: Config{ cfg := Config{
PrivateKey: srvkey, PrivateKey: srvkey,
MaxPeers: 10, MaxPeers: 10,
NoDial: true, NoDial: true,
Protocols: []Protocol{discard}, NoDiscovery: true,
}, Protocols: []Protocol{discard},
newTransport: func(fd net.Conn) transport { return test.tt }, Logger: testlog.Logger(t, log.LvlTrace),
log: log.New(),
}
if !test.dontstart {
if err := srv.Start(); err != nil {
t.Fatalf("couldn't start server: %v", err)
} }
} srv := &Server{
p1, _ := net.Pipe() Config: cfg,
srv.SetupConn(p1, test.flags, test.dialDest) newTransport: func(fd net.Conn) transport { return test.tt },
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { log: cfg.Logger,
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) }
} if !test.dontstart {
if test.tt.calls != test.wantCalls { if err := srv.Start(); err != nil {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) t.Fatalf("couldn't start server: %v", err)
} }
defer srv.Stop()
}
p1, _ := net.Pipe()
srv.SetupConn(p1, test.flags, test.dialDest)
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
}
if test.tt.calls != test.wantCalls {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
}
})
} }
} }
@ -616,3 +627,100 @@ func randomID() (id enode.ID) {
} }
return id return id
} }
// This test checks that inbound connections are throttled by IP.
func TestServerInboundThrottle(t *testing.T) {
const timeout = 5 * time.Second
newTransportCalled := make(chan struct{})
srv := &Server{
Config: Config{
PrivateKey: newkey(),
ListenAddr: "127.0.0.1:0",
MaxPeers: 10,
NoDial: true,
NoDiscovery: true,
Protocols: []Protocol{discard},
Logger: testlog.Logger(t, log.LvlTrace),
},
newTransport: func(fd net.Conn) transport {
newTransportCalled <- struct{}{}
return newRLPX(fd)
},
listenFunc: func(network, laddr string) (net.Listener, error) {
fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
return listenFakeAddr(network, laddr, fakeAddr)
},
}
if err := srv.Start(); err != nil {
t.Fatal("can't start: ", err)
}
defer srv.Stop()
// Dial the test server.
conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
select {
case <-newTransportCalled:
// OK
case <-time.After(timeout):
t.Error("newTransport not called")
}
conn.Close()
// Dial again. This time the server should close the connection immediately.
connClosed := make(chan struct{})
conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
defer conn.Close()
go func() {
conn.SetDeadline(time.Now().Add(timeout))
buf := make([]byte, 10)
if n, err := conn.Read(buf); err != io.EOF || n != 0 {
t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
}
connClosed <- struct{}{}
}()
select {
case <-connClosed:
// OK
case <-newTransportCalled:
t.Error("newTransport called for second attempt")
case <-time.After(timeout):
t.Error("connection not closed within timeout")
}
}
func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
l, err := net.Listen(network, laddr)
if err == nil {
l = &fakeAddrListener{l, remoteAddr}
}
return l, err
}
// fakeAddrListener is a listener that creates connections with a mocked remote address.
type fakeAddrListener struct {
net.Listener
remoteAddr net.Addr
}
type fakeAddrConn struct {
net.Conn
remoteAddr net.Addr
}
func (l *fakeAddrListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &fakeAddrConn{c, l.remoteAddr}, nil
}
func (c *fakeAddrConn) RemoteAddr() net.Addr {
return c.remoteAddr
}

p2p/util.go (new file, 82 lines)

@ -0,0 +1,82 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package p2p
import (
"container/heap"
"time"
)
// expHeap tracks strings and their expiry time.
type expHeap []expItem
// expItem is an entry in addrHistory.
type expItem struct {
item string
exp time.Time
}
// nextExpiry returns the next expiry time.
func (h *expHeap) nextExpiry() time.Time {
return (*h)[0].exp
}
// add adds an item and sets its expiry time.
func (h *expHeap) add(item string, exp time.Time) {
heap.Push(h, expItem{item, exp})
}
// remove removes an item.
func (h *expHeap) remove(item string) bool {
for i, v := range *h {
if v.item == item {
heap.Remove(h, i)
return true
}
}
return false
}
// contains checks whether an item is present.
func (h expHeap) contains(item string) bool {
for _, v := range h {
if v.item == item {
return true
}
}
return false
}
// expire removes items with expiry time before 'now'.
func (h *expHeap) expire(now time.Time) {
for h.Len() > 0 && h.nextExpiry().Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h expHeap) Len() int { return len(h) }
func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) }
func (h *expHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

p2p/util_test.go (new file, 54 lines)

@ -0,0 +1,54 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package p2p
import (
"testing"
"time"
)
func TestExpHeap(t *testing.T) {
var h expHeap
var (
basetime = time.Unix(4000, 0)
exptimeA = basetime.Add(2 * time.Second)
exptimeB = basetime.Add(3 * time.Second)
exptimeC = basetime.Add(4 * time.Second)
)
h.add("a", exptimeA)
h.add("b", exptimeB)
h.add("c", exptimeC)
if !h.nextExpiry().Equal(exptimeA) {
t.Fatal("wrong nextExpiry")
}
if !h.contains("a") || !h.contains("b") || !h.contains("c") {
t.Fatal("heap doesn't contain all live items")
}
h.expire(exptimeA.Add(1))
if !h.nextExpiry().Equal(exptimeB) {
t.Fatal("wrong nextExpiry")
}
if h.contains("a") {
t.Fatal("heap contains a even though it has already expired")
}
if !h.contains("b") || !h.contains("c") {
t.Fatal("heap doesn't contain all live items")
}
}