p2p, p2p/discover: misc connectivity improvements (#16069)
* p2p: add DialRatio for configuration of inbound vs. dialed connections * p2p: add connection flags to PeerInfo * p2p/netutil: add SameNet, DistinctNetSet * p2p/discover: improve revalidation and seeding This changes node revalidation to be periodic instead of on-demand. This should prevent issues where dead nodes get stuck in closer buckets because no other node will ever come along to replace them. Every 5 seconds (on average), the last node in a random bucket is checked and moved to the front of the bucket if it is still responding. If revalidation fails, the last node is replaced by an entry of the 'replacement list' containing recently-seen nodes. Most close buckets are removed because it's very unlikely we'll ever encounter a node that would fall into any of those buckets. Table seeding is also improved: we now require a few minutes of table membership before considering a node as a potential seed node. This should make it less likely to store short-lived nodes as potential seeds. * p2p/discover: fix nits in UDP transport We would skip sending neighbors replies if there were fewer than maxNeighbors results and CheckRelayIP returned an error for the last one. While here, also resolve a TODO about pong reply tokens.
This commit is contained in:
parent
1d39912a9b
commit
9123eceb0f
@ -122,7 +122,12 @@ func main() {
|
||||
utils.Fatalf("%v", err)
|
||||
}
|
||||
} else {
|
||||
if _, err := discover.ListenUDP(nodeKey, conn, realaddr, nil, "", restrictList); err != nil {
|
||||
cfg := discover.Config{
|
||||
PrivateKey: nodeKey,
|
||||
AnnounceAddr: realaddr,
|
||||
NetRestrict: restrictList,
|
||||
}
|
||||
if _, err := discover.ListenUDP(conn, cfg); err != nil {
|
||||
utils.Fatalf("%v", err)
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
@ -51,9 +52,8 @@ type Node struct {
|
||||
// with ID.
|
||||
sha common.Hash
|
||||
|
||||
// whether this node is currently being pinged in order to replace
|
||||
// it in a bucket
|
||||
contested bool
|
||||
// Time when the node was added to the table.
|
||||
addedAt time.Time
|
||||
}
|
||||
|
||||
// NewNode creates a new node. It is mostly meant to be used for
|
||||
|
@ -23,10 +23,11 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
@ -35,29 +36,45 @@ import (
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/ethereum/go-ethereum/p2p/netutil"
|
||||
)
|
||||
|
||||
const (
|
||||
alpha = 3 // Kademlia concurrency factor
|
||||
bucketSize = 16 // Kademlia bucket size
|
||||
hashBits = len(common.Hash{}) * 8
|
||||
nBuckets = hashBits + 1 // Number of buckets
|
||||
alpha = 3 // Kademlia concurrency factor
|
||||
bucketSize = 16 // Kademlia bucket size
|
||||
maxReplacements = 10 // Size of per-bucket replacement list
|
||||
|
||||
maxBondingPingPongs = 16
|
||||
maxFindnodeFailures = 5
|
||||
// We keep buckets for the upper 1/15 of distances because
|
||||
// it's very unlikely we'll ever encounter a node that's closer.
|
||||
hashBits = len(common.Hash{}) * 8
|
||||
nBuckets = hashBits / 15 // Number of buckets
|
||||
bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket
|
||||
|
||||
autoRefreshInterval = 1 * time.Hour
|
||||
seedCount = 30
|
||||
seedMaxAge = 5 * 24 * time.Hour
|
||||
// IP address limits.
|
||||
bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
|
||||
tableIPLimit, tableSubnet = 10, 24
|
||||
|
||||
maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
|
||||
maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
|
||||
|
||||
refreshInterval = 30 * time.Minute
|
||||
revalidateInterval = 10 * time.Second
|
||||
copyNodesInterval = 30 * time.Second
|
||||
seedMinTableTime = 5 * time.Minute
|
||||
seedCount = 30
|
||||
seedMaxAge = 5 * 24 * time.Hour
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
mutex sync.Mutex // protects buckets, their content, and nursery
|
||||
mutex sync.Mutex // protects buckets, bucket content, nursery, rand
|
||||
buckets [nBuckets]*bucket // index of known nodes by distance
|
||||
nursery []*Node // bootstrap nodes
|
||||
db *nodeDB // database of known nodes
|
||||
rand *mrand.Rand // source of randomness, periodically reseeded
|
||||
ips netutil.DistinctNetSet
|
||||
|
||||
db *nodeDB // database of known nodes
|
||||
refreshReq chan chan struct{}
|
||||
initDone chan struct{}
|
||||
closeReq chan struct{}
|
||||
closed chan struct{}
|
||||
|
||||
@ -89,9 +106,13 @@ type transport interface {
|
||||
|
||||
// bucket contains nodes, ordered by their last activity. the entry
|
||||
// that was most recently active is the first element in entries.
|
||||
type bucket struct{ entries []*Node }
|
||||
type bucket struct {
|
||||
entries []*Node // live entries, sorted by time of last contact
|
||||
replacements []*Node // recently seen nodes to be used if revalidation fails
|
||||
ips netutil.DistinctNetSet
|
||||
}
|
||||
|
||||
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) {
|
||||
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
|
||||
// If no node database was given, use an in-memory one
|
||||
db, err := newNodeDB(nodeDBPath, Version, ourID)
|
||||
if err != nil {
|
||||
@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
|
||||
bonding: make(map[NodeID]*bondproc),
|
||||
bondslots: make(chan struct{}, maxBondingPingPongs),
|
||||
refreshReq: make(chan chan struct{}),
|
||||
initDone: make(chan struct{}),
|
||||
closeReq: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
rand: mrand.New(mrand.NewSource(0)),
|
||||
ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
|
||||
}
|
||||
if err := tab.setFallbackNodes(bootnodes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := 0; i < cap(tab.bondslots); i++ {
|
||||
tab.bondslots <- struct{}{}
|
||||
}
|
||||
for i := range tab.buckets {
|
||||
tab.buckets[i] = new(bucket)
|
||||
tab.buckets[i] = &bucket{
|
||||
ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
|
||||
}
|
||||
}
|
||||
go tab.refreshLoop()
|
||||
tab.seedRand()
|
||||
tab.loadSeedNodes(false)
|
||||
// Start the background expiration goroutine after loading seeds so that the search for
|
||||
// seed nodes also considers older nodes that would otherwise be removed by the
|
||||
// expiration.
|
||||
tab.db.ensureExpirer()
|
||||
go tab.loop()
|
||||
return tab, nil
|
||||
}
|
||||
|
||||
func (tab *Table) seedRand() {
|
||||
var b [8]byte
|
||||
crand.Read(b[:])
|
||||
|
||||
tab.mutex.Lock()
|
||||
tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:])))
|
||||
tab.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Self returns the local node.
|
||||
// The returned node should not be modified by the caller.
|
||||
func (tab *Table) Self() *Node {
|
||||
@ -127,9 +171,12 @@ func (tab *Table) Self() *Node {
|
||||
// table. It will not write the same node more than once. The nodes in
|
||||
// the slice are copies and can be modified by the caller.
|
||||
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
|
||||
if !tab.isInitDone() {
|
||||
return 0
|
||||
}
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
// TODO: tree-based buckets would help here
|
||||
|
||||
// Find all non-empty buckets and get a fresh slice of their entries.
|
||||
var buckets [][]*Node
|
||||
for _, b := range tab.buckets {
|
||||
@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
|
||||
return 0
|
||||
}
|
||||
// Shuffle the buckets.
|
||||
for i := uint32(len(buckets)) - 1; i > 0; i-- {
|
||||
j := randUint(i)
|
||||
for i := len(buckets) - 1; i > 0; i-- {
|
||||
j := tab.rand.Intn(len(buckets))
|
||||
buckets[i], buckets[j] = buckets[j], buckets[i]
|
||||
}
|
||||
// Move head of each bucket into buf, removing buckets that become empty.
|
||||
@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
|
||||
return i + 1
|
||||
}
|
||||
|
||||
func randUint(max uint32) uint32 {
|
||||
if max == 0 {
|
||||
return 0
|
||||
}
|
||||
var b [4]byte
|
||||
rand.Read(b[:])
|
||||
return binary.BigEndian.Uint32(b[:]) % max
|
||||
}
|
||||
|
||||
// Close terminates the network listener and flushes the node database.
|
||||
func (tab *Table) Close() {
|
||||
select {
|
||||
@ -180,16 +218,15 @@ func (tab *Table) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetFallbackNodes sets the initial points of contact. These nodes
|
||||
// setFallbackNodes sets the initial points of contact. These nodes
|
||||
// are used to connect to the network if the table is empty and there
|
||||
// are no known nodes in the database.
|
||||
func (tab *Table) SetFallbackNodes(nodes []*Node) error {
|
||||
func (tab *Table) setFallbackNodes(nodes []*Node) error {
|
||||
for _, n := range nodes {
|
||||
if err := n.validateComplete(); err != nil {
|
||||
return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
|
||||
}
|
||||
}
|
||||
tab.mutex.Lock()
|
||||
tab.nursery = make([]*Node, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
cpy := *n
|
||||
@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error {
|
||||
cpy.sha = crypto.Keccak256Hash(n.ID[:])
|
||||
tab.nursery = append(tab.nursery, &cpy)
|
||||
}
|
||||
tab.mutex.Unlock()
|
||||
tab.refresh()
|
||||
return nil
|
||||
}
|
||||
|
||||
// isInitDone returns whether the table's initial seeding procedure has completed.
|
||||
func (tab *Table) isInitDone() bool {
|
||||
select {
|
||||
case <-tab.initDone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve searches for a specific node with the given ID.
|
||||
// It returns nil if the node could not be found.
|
||||
func (tab *Table) Resolve(targetID NodeID) *Node {
|
||||
@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} {
|
||||
return done
|
||||
}
|
||||
|
||||
// refreshLoop schedules doRefresh runs and coordinates shutdown.
|
||||
func (tab *Table) refreshLoop() {
|
||||
// loop schedules refresh, revalidate runs and coordinates shutdown.
|
||||
func (tab *Table) loop() {
|
||||
var (
|
||||
timer = time.NewTicker(autoRefreshInterval)
|
||||
waiting []chan struct{} // accumulates waiting callers while doRefresh runs
|
||||
done chan struct{} // where doRefresh reports completion
|
||||
revalidate = time.NewTimer(tab.nextRevalidateTime())
|
||||
refresh = time.NewTicker(refreshInterval)
|
||||
copyNodes = time.NewTicker(copyNodesInterval)
|
||||
revalidateDone = make(chan struct{})
|
||||
refreshDone = make(chan struct{}) // where doRefresh reports completion
|
||||
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
|
||||
)
|
||||
defer refresh.Stop()
|
||||
defer revalidate.Stop()
|
||||
defer copyNodes.Stop()
|
||||
|
||||
// Start initial refresh.
|
||||
go tab.doRefresh(refreshDone)
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
if done == nil {
|
||||
done = make(chan struct{})
|
||||
go tab.doRefresh(done)
|
||||
case <-refresh.C:
|
||||
tab.seedRand()
|
||||
if refreshDone == nil {
|
||||
refreshDone = make(chan struct{})
|
||||
go tab.doRefresh(refreshDone)
|
||||
}
|
||||
case req := <-tab.refreshReq:
|
||||
waiting = append(waiting, req)
|
||||
if done == nil {
|
||||
done = make(chan struct{})
|
||||
go tab.doRefresh(done)
|
||||
if refreshDone == nil {
|
||||
refreshDone = make(chan struct{})
|
||||
go tab.doRefresh(refreshDone)
|
||||
}
|
||||
case <-done:
|
||||
case <-refreshDone:
|
||||
for _, ch := range waiting {
|
||||
close(ch)
|
||||
}
|
||||
waiting = nil
|
||||
done = nil
|
||||
waiting, refreshDone = nil, nil
|
||||
case <-revalidate.C:
|
||||
go tab.doRevalidate(revalidateDone)
|
||||
case <-revalidateDone:
|
||||
revalidate.Reset(tab.nextRevalidateTime())
|
||||
case <-copyNodes.C:
|
||||
go tab.copyBondedNodes()
|
||||
case <-tab.closeReq:
|
||||
break loop
|
||||
}
|
||||
@ -349,8 +410,8 @@ loop:
|
||||
if tab.net != nil {
|
||||
tab.net.close()
|
||||
}
|
||||
if done != nil {
|
||||
<-done
|
||||
if refreshDone != nil {
|
||||
<-refreshDone
|
||||
}
|
||||
for _, ch := range waiting {
|
||||
close(ch)
|
||||
@ -365,38 +426,109 @@ loop:
|
||||
func (tab *Table) doRefresh(done chan struct{}) {
|
||||
defer close(done)
|
||||
|
||||
// Load nodes from the database and insert
|
||||
// them. This should yield a few previously seen nodes that are
|
||||
// (hopefully) still alive.
|
||||
tab.loadSeedNodes(true)
|
||||
|
||||
// Run self lookup to discover new neighbor nodes.
|
||||
tab.lookup(tab.self.ID, false)
|
||||
|
||||
// The Kademlia paper specifies that the bucket refresh should
|
||||
// perform a lookup in the least recently used bucket. We cannot
|
||||
// adhere to this because the findnode target is a 512bit value
|
||||
// (not hash-sized) and it is not easily possible to generate a
|
||||
// sha3 preimage that falls into a chosen bucket.
|
||||
// We perform a lookup with a random target instead.
|
||||
var target NodeID
|
||||
rand.Read(target[:])
|
||||
result := tab.lookup(target, false)
|
||||
if len(result) > 0 {
|
||||
// We perform a few lookups with a random target instead.
|
||||
for i := 0; i < 3; i++ {
|
||||
var target NodeID
|
||||
crand.Read(target[:])
|
||||
tab.lookup(target, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (tab *Table) loadSeedNodes(bond bool) {
|
||||
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
|
||||
seeds = append(seeds, tab.nursery...)
|
||||
if bond {
|
||||
seeds = tab.bondall(seeds)
|
||||
}
|
||||
for i := range seeds {
|
||||
seed := seeds[i]
|
||||
age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }}
|
||||
log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
|
||||
tab.add(seed)
|
||||
}
|
||||
}
|
||||
|
||||
// doRevalidate checks that the last node in a random bucket is still live
|
||||
// and replaces or deletes the node if it isn't.
|
||||
func (tab *Table) doRevalidate(done chan<- struct{}) {
|
||||
defer func() { done <- struct{}{} }()
|
||||
|
||||
last, bi := tab.nodeToRevalidate()
|
||||
if last == nil {
|
||||
// No non-empty bucket found.
|
||||
return
|
||||
}
|
||||
|
||||
// The table is empty. Load nodes from the database and insert
|
||||
// them. This should yield a few previously seen nodes that are
|
||||
// (hopefully) still alive.
|
||||
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
|
||||
seeds = tab.bondall(append(seeds, tab.nursery...))
|
||||
// Ping the selected node and wait for a pong.
|
||||
err := tab.ping(last.ID, last.addr())
|
||||
|
||||
if len(seeds) == 0 {
|
||||
log.Debug("No discv4 seed nodes found")
|
||||
}
|
||||
for _, n := range seeds {
|
||||
age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }}
|
||||
log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age)
|
||||
}
|
||||
tab.mutex.Lock()
|
||||
tab.stuff(seeds)
|
||||
tab.mutex.Unlock()
|
||||
defer tab.mutex.Unlock()
|
||||
b := tab.buckets[bi]
|
||||
if err == nil {
|
||||
// The node responded, move it to the front.
|
||||
log.Debug("Revalidated node", "b", bi, "id", last.ID)
|
||||
b.bump(last)
|
||||
return
|
||||
}
|
||||
// No reply received, pick a replacement or delete the node if there aren't
|
||||
// any replacements.
|
||||
if r := tab.replace(b, last); r != nil {
|
||||
log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
|
||||
} else {
|
||||
log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, do a self lookup to fill up the buckets.
|
||||
tab.lookup(tab.self.ID, false)
|
||||
// nodeToRevalidate returns the last node in a random, non-empty bucket.
|
||||
func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
|
||||
for _, bi = range tab.rand.Perm(len(tab.buckets)) {
|
||||
b := tab.buckets[bi]
|
||||
if len(b.entries) > 0 {
|
||||
last := b.entries[len(b.entries)-1]
|
||||
return last, bi
|
||||
}
|
||||
}
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
func (tab *Table) nextRevalidateTime() time.Duration {
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
|
||||
return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
|
||||
}
|
||||
|
||||
// copyBondedNodes adds nodes from the table to the database if they have been in the table
|
||||
// longer then minTableTime.
|
||||
func (tab *Table) copyBondedNodes() {
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for _, b := range tab.buckets {
|
||||
for _, n := range b.entries {
|
||||
if now.Sub(n.addedAt) >= seedMinTableTime {
|
||||
tab.db.updateNode(n)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closest returns the n nodes in the table that are closest to the
|
||||
@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
|
||||
if id == tab.self.ID {
|
||||
return nil, errors.New("is self")
|
||||
}
|
||||
// Retrieve a previously known node and any recent findnode failures
|
||||
node, fails := tab.db.node(id), 0
|
||||
if node != nil {
|
||||
fails = tab.db.findFails(id)
|
||||
if pinged && !tab.isInitDone() {
|
||||
return nil, errors.New("still initializing")
|
||||
}
|
||||
// If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch
|
||||
var result error
|
||||
// Start bonding if we haven't seen this node for a while or if it failed findnode too often.
|
||||
node, fails := tab.db.node(id), tab.db.findFails(id)
|
||||
age := time.Since(tab.db.lastPong(id))
|
||||
if node == nil || fails > 0 || age > nodeDBNodeExpiration {
|
||||
var result error
|
||||
if fails > 0 || age > nodeDBNodeExpiration {
|
||||
log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
|
||||
|
||||
tab.bondmu.Lock()
|
||||
@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
|
||||
node = w.n
|
||||
}
|
||||
}
|
||||
// Add the node to the table even if the bonding ping/pong
|
||||
// fails. It will be relaced quickly if it continues to be
|
||||
// unresponsive.
|
||||
if node != nil {
|
||||
// Add the node to the table even if the bonding ping/pong
|
||||
// fails. It will be relaced quickly if it continues to be
|
||||
// unresponsive.
|
||||
tab.add(node)
|
||||
tab.db.updateFindFails(id, 0)
|
||||
}
|
||||
@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
|
||||
}
|
||||
// Bonding succeeded, update the node database.
|
||||
w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
|
||||
tab.db.updateNode(w.n)
|
||||
close(w.done)
|
||||
}
|
||||
|
||||
@ -534,16 +664,18 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
|
||||
return err
|
||||
}
|
||||
tab.db.updateLastPong(id, time.Now())
|
||||
|
||||
// Start the background expiration goroutine after the first
|
||||
// successful communication. Subsequent calls have no effect if it
|
||||
// is already running. We do this here instead of somewhere else
|
||||
// so that the search for seed nodes also considers older nodes
|
||||
// that would otherwise be removed by the expiration.
|
||||
tab.db.ensureExpirer()
|
||||
return nil
|
||||
}
|
||||
|
||||
// bucket returns the bucket for the given node ID hash.
|
||||
func (tab *Table) bucket(sha common.Hash) *bucket {
|
||||
d := logdist(tab.self.sha, sha)
|
||||
if d <= bucketMinDistance {
|
||||
return tab.buckets[0]
|
||||
}
|
||||
return tab.buckets[d-bucketMinDistance-1]
|
||||
}
|
||||
|
||||
// add attempts to add the given node its corresponding bucket. If the
|
||||
// bucket has space available, adding the node succeeds immediately.
|
||||
// Otherwise, the node is added if the least recently active node in
|
||||
@ -551,57 +683,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
|
||||
//
|
||||
// The caller must not hold tab.mutex.
|
||||
func (tab *Table) add(new *Node) {
|
||||
b := tab.buckets[logdist(tab.self.sha, new.sha)]
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if b.bump(new) {
|
||||
return
|
||||
}
|
||||
var oldest *Node
|
||||
if len(b.entries) == bucketSize {
|
||||
oldest = b.entries[bucketSize-1]
|
||||
if oldest.contested {
|
||||
// The node is already being replaced, don't attempt
|
||||
// to replace it.
|
||||
return
|
||||
}
|
||||
oldest.contested = true
|
||||
// Let go of the mutex so other goroutines can access
|
||||
// the table while we ping the least recently active node.
|
||||
tab.mutex.Unlock()
|
||||
err := tab.ping(oldest.ID, oldest.addr())
|
||||
tab.mutex.Lock()
|
||||
oldest.contested = false
|
||||
if err == nil {
|
||||
// The node responded, don't replace it.
|
||||
return
|
||||
}
|
||||
}
|
||||
added := b.replace(new, oldest)
|
||||
if added && tab.nodeAddedHook != nil {
|
||||
tab.nodeAddedHook(new)
|
||||
|
||||
b := tab.bucket(new.sha)
|
||||
if !tab.bumpOrAdd(b, new) {
|
||||
// Node is not in table. Add it to the replacement list.
|
||||
tab.addReplacement(b, new)
|
||||
}
|
||||
}
|
||||
|
||||
// stuff adds nodes the table to the end of their corresponding bucket
|
||||
// if the bucket is not full. The caller must hold tab.mutex.
|
||||
// if the bucket is not full. The caller must not hold tab.mutex.
|
||||
func (tab *Table) stuff(nodes []*Node) {
|
||||
outer:
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
|
||||
for _, n := range nodes {
|
||||
if n.ID == tab.self.ID {
|
||||
continue // don't add self
|
||||
}
|
||||
bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
|
||||
for i := range bucket.entries {
|
||||
if bucket.entries[i].ID == n.ID {
|
||||
continue outer // already in bucket
|
||||
}
|
||||
}
|
||||
if len(bucket.entries) < bucketSize {
|
||||
bucket.entries = append(bucket.entries, n)
|
||||
if tab.nodeAddedHook != nil {
|
||||
tab.nodeAddedHook(n)
|
||||
}
|
||||
b := tab.bucket(n.sha)
|
||||
if len(b.entries) < bucketSize {
|
||||
tab.bumpOrAdd(b, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -611,36 +715,72 @@ outer:
|
||||
func (tab *Table) delete(node *Node) {
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
bucket := tab.buckets[logdist(tab.self.sha, node.sha)]
|
||||
for i := range bucket.entries {
|
||||
if bucket.entries[i].ID == node.ID {
|
||||
bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tab.deleteInBucket(tab.bucket(node.sha), node)
|
||||
}
|
||||
|
||||
func (b *bucket) replace(n *Node, last *Node) bool {
|
||||
// Don't add if b already contains n.
|
||||
for i := range b.entries {
|
||||
if b.entries[i].ID == n.ID {
|
||||
return false
|
||||
}
|
||||
func (tab *Table) addIP(b *bucket, ip net.IP) bool {
|
||||
if netutil.IsLAN(ip) {
|
||||
return true
|
||||
}
|
||||
// Replace last if it is still the last entry or just add n if b
|
||||
// isn't full. If is no longer the last entry, it has either been
|
||||
// replaced with someone else or became active.
|
||||
if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) {
|
||||
if !tab.ips.Add(ip) {
|
||||
log.Debug("IP exceeds table limit", "ip", ip)
|
||||
return false
|
||||
}
|
||||
if len(b.entries) < bucketSize {
|
||||
b.entries = append(b.entries, nil)
|
||||
if !b.ips.Add(ip) {
|
||||
log.Debug("IP exceeds bucket limit", "ip", ip)
|
||||
tab.ips.Remove(ip)
|
||||
return false
|
||||
}
|
||||
copy(b.entries[1:], b.entries)
|
||||
b.entries[0] = n
|
||||
return true
|
||||
}
|
||||
|
||||
func (tab *Table) removeIP(b *bucket, ip net.IP) {
|
||||
if netutil.IsLAN(ip) {
|
||||
return
|
||||
}
|
||||
tab.ips.Remove(ip)
|
||||
b.ips.Remove(ip)
|
||||
}
|
||||
|
||||
func (tab *Table) addReplacement(b *bucket, n *Node) {
|
||||
for _, e := range b.replacements {
|
||||
if e.ID == n.ID {
|
||||
return // already in list
|
||||
}
|
||||
}
|
||||
if !tab.addIP(b, n.IP) {
|
||||
return
|
||||
}
|
||||
var removed *Node
|
||||
b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
|
||||
if removed != nil {
|
||||
tab.removeIP(b, removed.IP)
|
||||
}
|
||||
}
|
||||
|
||||
// replace removes n from the replacement list and replaces 'last' with it if it is the
|
||||
// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced
|
||||
// with someone else or became active.
|
||||
func (tab *Table) replace(b *bucket, last *Node) *Node {
|
||||
if len(b.entries) >= 0 && b.entries[len(b.entries)-1].ID != last.ID {
|
||||
// Entry has moved, don't replace it.
|
||||
return nil
|
||||
}
|
||||
// Still the last entry.
|
||||
if len(b.replacements) == 0 {
|
||||
tab.deleteInBucket(b, last)
|
||||
return nil
|
||||
}
|
||||
r := b.replacements[tab.rand.Intn(len(b.replacements))]
|
||||
b.replacements = deleteNode(b.replacements, r)
|
||||
b.entries[len(b.entries)-1] = r
|
||||
tab.removeIP(b, last.IP)
|
||||
return r
|
||||
}
|
||||
|
||||
// bump moves the given node to the front of the bucket entry list
|
||||
// if it is contained in that list.
|
||||
func (b *bucket) bump(n *Node) bool {
|
||||
for i := range b.entries {
|
||||
if b.entries[i].ID == n.ID {
|
||||
@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't
|
||||
// full. The return value is true if n is in the bucket.
|
||||
func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
|
||||
if b.bump(n) {
|
||||
return true
|
||||
}
|
||||
if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) {
|
||||
return false
|
||||
}
|
||||
b.entries, _ = pushNode(b.entries, n, bucketSize)
|
||||
b.replacements = deleteNode(b.replacements, n)
|
||||
n.addedAt = time.Now()
|
||||
if tab.nodeAddedHook != nil {
|
||||
tab.nodeAddedHook(n)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (tab *Table) deleteInBucket(b *bucket, n *Node) {
|
||||
b.entries = deleteNode(b.entries, n)
|
||||
tab.removeIP(b, n.IP)
|
||||
}
|
||||
|
||||
// pushNode adds n to the front of list, keeping at most max items.
|
||||
func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
|
||||
if len(list) < max {
|
||||
list = append(list, nil)
|
||||
}
|
||||
removed := list[len(list)-1]
|
||||
copy(list[1:], list)
|
||||
list[0] = n
|
||||
return list, removed
|
||||
}
|
||||
|
||||
// deleteNode removes n from list.
|
||||
func deleteNode(list []*Node, n *Node) []*Node {
|
||||
for i := range list {
|
||||
if list[i].ID == n.ID {
|
||||
return append(list[:i], list[i+1:]...)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// nodesByDistance is a list of nodes, ordered by
|
||||
// distance to target.
|
||||
type nodesByDistance struct {
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
|
||||
"net"
|
||||
"reflect"
|
||||
@ -32,60 +33,65 @@ import (
|
||||
)
|
||||
|
||||
func TestTable_pingReplace(t *testing.T) {
|
||||
doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
|
||||
transport := newPingRecorder()
|
||||
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "")
|
||||
defer tab.Close()
|
||||
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
|
||||
|
||||
// fill up the sender's bucket.
|
||||
last := fillBucket(tab, 253)
|
||||
|
||||
// this call to bond should replace the last node
|
||||
// in its bucket if the node is not responding.
|
||||
transport.responding[last.ID] = lastInBucketIsResponding
|
||||
transport.responding[pingSender.ID] = newNodeIsResponding
|
||||
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
|
||||
|
||||
// first ping goes to sender (bonding pingback)
|
||||
if !transport.pinged[pingSender.ID] {
|
||||
t.Error("table did not ping back sender")
|
||||
}
|
||||
if newNodeIsResponding {
|
||||
// second ping goes to oldest node in bucket
|
||||
// to see whether it is still alive.
|
||||
if !transport.pinged[last.ID] {
|
||||
t.Error("table did not ping last node in bucket")
|
||||
}
|
||||
}
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[253].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize)
|
||||
}
|
||||
|
||||
if lastInBucketIsResponding || !newNodeIsResponding {
|
||||
if !contains(tab.buckets[253].entries, last.ID) {
|
||||
t.Error("last entry was removed")
|
||||
}
|
||||
if contains(tab.buckets[253].entries, pingSender.ID) {
|
||||
t.Error("new entry was added")
|
||||
}
|
||||
} else {
|
||||
if contains(tab.buckets[253].entries, last.ID) {
|
||||
t.Error("last entry was not removed")
|
||||
}
|
||||
if !contains(tab.buckets[253].entries, pingSender.ID) {
|
||||
t.Error("new entry was not added")
|
||||
}
|
||||
}
|
||||
run := func(newNodeResponding, lastInBucketResponding bool) {
|
||||
name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testPingReplace(t, newNodeResponding, lastInBucketResponding)
|
||||
})
|
||||
}
|
||||
|
||||
doit(true, true)
|
||||
doit(false, true)
|
||||
doit(true, false)
|
||||
doit(false, false)
|
||||
run(true, true)
|
||||
run(false, true)
|
||||
run(true, false)
|
||||
run(false, false)
|
||||
}
|
||||
|
||||
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
|
||||
transport := newPingRecorder()
|
||||
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
|
||||
defer tab.Close()
|
||||
|
||||
// Wait for init so bond is accepted.
|
||||
<-tab.initDone
|
||||
|
||||
// fill up the sender's bucket.
|
||||
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
|
||||
last := fillBucket(tab, pingSender)
|
||||
|
||||
// this call to bond should replace the last node
|
||||
// in its bucket if the node is not responding.
|
||||
transport.dead[last.ID] = !lastInBucketIsResponding
|
||||
transport.dead[pingSender.ID] = !newNodeIsResponding
|
||||
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
|
||||
tab.doRevalidate(make(chan struct{}, 1))
|
||||
|
||||
// first ping goes to sender (bonding pingback)
|
||||
if !transport.pinged[pingSender.ID] {
|
||||
t.Error("table did not ping back sender")
|
||||
}
|
||||
if !transport.pinged[last.ID] {
|
||||
// second ping goes to oldest node in bucket
|
||||
// to see whether it is still alive.
|
||||
t.Error("table did not ping last node in bucket")
|
||||
}
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
wantSize := bucketSize
|
||||
if !lastInBucketIsResponding && !newNodeIsResponding {
|
||||
wantSize--
|
||||
}
|
||||
if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
|
||||
t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
|
||||
}
|
||||
if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
|
||||
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
|
||||
}
|
||||
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
|
||||
if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
|
||||
t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBucket_bumpNoDuplicates(t *testing.T) {
|
||||
@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// This checks that the table-wide IP limit is applied correctly.
|
||||
func TestTable_IPLimit(t *testing.T) {
|
||||
transport := newPingRecorder()
|
||||
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
|
||||
defer tab.Close()
|
||||
|
||||
for i := 0; i < tableIPLimit+1; i++ {
|
||||
n := nodeAtDistance(tab.self.sha, i)
|
||||
n.IP = net.IP{172, 0, 1, byte(i)}
|
||||
tab.add(n)
|
||||
}
|
||||
if tab.len() > tableIPLimit {
|
||||
t.Errorf("too many nodes in table")
|
||||
}
|
||||
}
|
||||
|
||||
// This checks that the table-wide IP limit is applied correctly.
|
||||
func TestTable_BucketIPLimit(t *testing.T) {
|
||||
transport := newPingRecorder()
|
||||
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
|
||||
defer tab.Close()
|
||||
|
||||
d := 3
|
||||
for i := 0; i < bucketIPLimit+1; i++ {
|
||||
n := nodeAtDistance(tab.self.sha, d)
|
||||
n.IP = net.IP{172, 0, 1, byte(i)}
|
||||
tab.add(n)
|
||||
}
|
||||
if tab.len() > bucketIPLimit {
|
||||
t.Errorf("too many nodes in table")
|
||||
}
|
||||
}
|
||||
|
||||
// fillBucket inserts nodes into the given bucket until
|
||||
// it is full. The node's IDs dont correspond to their
|
||||
// hashes.
|
||||
func fillBucket(tab *Table, ld int) (last *Node) {
|
||||
b := tab.buckets[ld]
|
||||
func fillBucket(tab *Table, n *Node) (last *Node) {
|
||||
ld := logdist(tab.self.sha, n.sha)
|
||||
b := tab.bucket(n.sha)
|
||||
for len(b.entries) < bucketSize {
|
||||
b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
|
||||
}
|
||||
@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) {
|
||||
func nodeAtDistance(base common.Hash, ld int) (n *Node) {
|
||||
n = new(Node)
|
||||
n.sha = hashAtDistance(base, ld)
|
||||
n.IP = net.IP{10, 0, 2, byte(ld)}
|
||||
n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
|
||||
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
|
||||
return n
|
||||
}
|
||||
|
||||
type pingRecorder struct{ responding, pinged map[NodeID]bool }
|
||||
type pingRecorder struct {
|
||||
mu sync.Mutex
|
||||
dead, pinged map[NodeID]bool
|
||||
}
|
||||
|
||||
func newPingRecorder() *pingRecorder {
|
||||
return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
|
||||
return &pingRecorder{
|
||||
dead: make(map[NodeID]bool),
|
||||
pinged: make(map[NodeID]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
|
||||
panic("findnode called on pingRecorder")
|
||||
return nil, nil
|
||||
}
|
||||
func (t *pingRecorder) close() {}
|
||||
func (t *pingRecorder) waitping(from NodeID) error {
|
||||
return nil // remote always pings
|
||||
}
|
||||
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.pinged[toid] = true
|
||||
if t.responding[toid] {
|
||||
return nil
|
||||
} else {
|
||||
if t.dead[toid] {
|
||||
return errTimeout
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) {
|
||||
|
||||
test := func(test *closeTest) bool {
|
||||
// for any node table, Target and N
|
||||
tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "")
|
||||
transport := newPingRecorder()
|
||||
tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
|
||||
defer tab.Close()
|
||||
tab.stuff(test.All)
|
||||
|
||||
@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
|
||||
},
|
||||
}
|
||||
test := func(buf []*Node) bool {
|
||||
tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
|
||||
transport := newPingRecorder()
|
||||
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
|
||||
defer tab.Close()
|
||||
<-tab.initDone
|
||||
|
||||
for i := 0; i < len(buf); i++ {
|
||||
ld := cfg.Rand.Intn(len(tab.buckets))
|
||||
tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
|
||||
@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
|
||||
func TestTable_Lookup(t *testing.T) {
|
||||
self := nodeAtDistance(common.Hash{}, 0)
|
||||
tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "")
|
||||
tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
|
||||
defer tab.Close()
|
||||
|
||||
// lookup on empty table returns no nodes
|
||||
|
@ -216,9 +216,22 @@ type ReadPacket struct {
|
||||
Addr *net.UDPAddr
|
||||
}
|
||||
|
||||
// Config holds Table-related settings.
|
||||
type Config struct {
|
||||
// These settings are required and configure the UDP listener:
|
||||
PrivateKey *ecdsa.PrivateKey
|
||||
|
||||
// These settings are optional:
|
||||
AnnounceAddr *net.UDPAddr // local address announced in the DHT
|
||||
NodeDBPath string // if set, the node database is stored at this filesystem location
|
||||
NetRestrict *netutil.Netlist // network whitelist
|
||||
Bootnodes []*Node // list of bootstrap nodes
|
||||
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
|
||||
}
|
||||
|
||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||
func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
|
||||
tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict)
|
||||
func ListenUDP(c conn, cfg Config) (*Table, error) {
|
||||
tab, _, err := newUDP(c, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl
|
||||
return tab, nil
|
||||
}
|
||||
|
||||
func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
|
||||
func newUDP(c conn, cfg Config) (*Table, *udp, error) {
|
||||
udp := &udp{
|
||||
conn: c,
|
||||
priv: priv,
|
||||
netrestrict: netrestrict,
|
||||
priv: cfg.PrivateKey,
|
||||
netrestrict: cfg.NetRestrict,
|
||||
closing: make(chan struct{}),
|
||||
gotreply: make(chan reply),
|
||||
addpending: make(chan *pending),
|
||||
}
|
||||
realaddr := c.LocalAddr().(*net.UDPAddr)
|
||||
if cfg.AnnounceAddr != nil {
|
||||
realaddr = cfg.AnnounceAddr
|
||||
}
|
||||
// TODO: separate TCP port
|
||||
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
|
||||
tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
|
||||
tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
udp.Table = tab
|
||||
|
||||
go udp.loop()
|
||||
go udp.readLoop(unhandled)
|
||||
go udp.readLoop(cfg.Unhandled)
|
||||
return udp.Table, udp, nil
|
||||
}
|
||||
|
||||
@ -256,14 +273,20 @@ func (t *udp) close() {
|
||||
|
||||
// ping sends a ping message to the given node and waits for a reply.
|
||||
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
|
||||
// TODO: maybe check for ReplyTo field in callback to measure RTT
|
||||
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
|
||||
t.send(toaddr, pingPacket, &ping{
|
||||
req := &ping{
|
||||
Version: Version,
|
||||
From: t.ourEndpoint,
|
||||
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
}
|
||||
packet, hash, err := encodePacket(t.priv, pingPacket, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
errc := t.pending(toid, pongPacket, func(p interface{}) bool {
|
||||
return bytes.Equal(p.(*pong).ReplyTok, hash)
|
||||
})
|
||||
t.write(toaddr, req.name(), packet)
|
||||
return <-errc
|
||||
}
|
||||
|
||||
@ -447,40 +470,45 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error {
|
||||
packet, err := encodePacket(t.priv, ptype, req)
|
||||
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
|
||||
packet, hash, err := encodePacket(t.priv, ptype, req)
|
||||
if err != nil {
|
||||
return err
|
||||
return hash, err
|
||||
}
|
||||
_, err = t.conn.WriteToUDP(packet, toaddr)
|
||||
log.Trace(">> "+req.name(), "addr", toaddr, "err", err)
|
||||
return hash, t.write(toaddr, req.name(), packet)
|
||||
}
|
||||
|
||||
func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
|
||||
_, err := t.conn.WriteToUDP(packet, toaddr)
|
||||
log.Trace(">> "+what, "addr", toaddr, "err", err)
|
||||
return err
|
||||
}
|
||||
|
||||
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
|
||||
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
|
||||
b := new(bytes.Buffer)
|
||||
b.Write(headSpace)
|
||||
b.WriteByte(ptype)
|
||||
if err := rlp.Encode(b, req); err != nil {
|
||||
log.Error("Can't encode discv4 packet", "err", err)
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
packet := b.Bytes()
|
||||
packet = b.Bytes()
|
||||
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
|
||||
if err != nil {
|
||||
log.Error("Can't sign discv4 packet", "err", err)
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
copy(packet[macSize:], sig)
|
||||
// add the hash to the front. Note: this doesn't protect the
|
||||
// packet in any way. Our public key will be part of this hash in
|
||||
// The future.
|
||||
copy(packet, crypto.Keccak256(packet[macSize:]))
|
||||
return packet, nil
|
||||
hash = crypto.Keccak256(packet[macSize:])
|
||||
copy(packet, hash)
|
||||
return packet, hash, nil
|
||||
}
|
||||
|
||||
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
||||
func (t *udp) readLoop(unhandled chan ReadPacket) {
|
||||
func (t *udp) readLoop(unhandled chan<- ReadPacket) {
|
||||
defer t.conn.Close()
|
||||
if unhandled != nil {
|
||||
defer close(unhandled)
|
||||
@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
|
||||
t.mutex.Unlock()
|
||||
|
||||
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
|
||||
var sent bool
|
||||
// Send neighbors in chunks with at most maxNeighbors per packet
|
||||
// to stay below the 1280 byte limit.
|
||||
for i, n := range closest {
|
||||
if netutil.CheckRelayIP(from.IP, n.IP) != nil {
|
||||
continue
|
||||
for _, n := range closest {
|
||||
if netutil.CheckRelayIP(from.IP, n.IP) == nil {
|
||||
p.Nodes = append(p.Nodes, nodeToRPC(n))
|
||||
}
|
||||
p.Nodes = append(p.Nodes, nodeToRPC(n))
|
||||
if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
|
||||
if len(p.Nodes) == maxNeighbors {
|
||||
t.send(from, neighborsPacket, &p)
|
||||
p.Nodes = p.Nodes[:0]
|
||||
sent = true
|
||||
}
|
||||
}
|
||||
if len(p.Nodes) > 0 || !sent {
|
||||
t.send(from, neighborsPacket, &p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest {
|
||||
remotekey: newkey(),
|
||||
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
|
||||
}
|
||||
realaddr := test.pipe.LocalAddr().(*net.UDPAddr)
|
||||
test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil)
|
||||
test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
|
||||
// Wait for initial refresh so the table doesn't send unexpected findnode.
|
||||
<-test.table.initDone
|
||||
return test
|
||||
}
|
||||
|
||||
// handles a packet as if it had been sent to the transport.
|
||||
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
|
||||
enc, err := encodePacket(test.remotekey, ptype, data)
|
||||
enc, _, err := encodePacket(test.remotekey, ptype, data)
|
||||
if err != nil {
|
||||
return test.errorf("packet (%d) encode error: %v", ptype, err)
|
||||
}
|
||||
@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
|
||||
|
||||
// waits for a packet to be sent by the transport.
|
||||
// validate should have type func(*udpTest, X) error, where X is a packet type.
|
||||
func (test *udpTest) waitPacketOut(validate interface{}) error {
|
||||
func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
|
||||
dgram := test.pipe.waitPacketOut()
|
||||
p, _, _, err := decodePacket(dgram)
|
||||
p, _, hash, err := decodePacket(dgram)
|
||||
if err != nil {
|
||||
return test.errorf("sent packet decode error: %v", err)
|
||||
return hash, test.errorf("sent packet decode error: %v", err)
|
||||
}
|
||||
fn := reflect.ValueOf(validate)
|
||||
exptype := fn.Type().In(0)
|
||||
if reflect.TypeOf(p) != exptype {
|
||||
return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
|
||||
return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
|
||||
}
|
||||
fn.Call([]reflect.Value{reflect.ValueOf(p)})
|
||||
return nil
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
func (test *udpTest) errorf(format string, args ...interface{}) error {
|
||||
@ -351,7 +352,7 @@ func TestUDP_successfulPing(t *testing.T) {
|
||||
})
|
||||
|
||||
// remote is unknown, the table pings back.
|
||||
test.waitPacketOut(func(p *ping) error {
|
||||
hash, _ := test.waitPacketOut(func(p *ping) error {
|
||||
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
|
||||
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
|
||||
}
|
||||
@ -365,7 +366,7 @@ func TestUDP_successfulPing(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
|
||||
test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp})
|
||||
|
||||
// the node should be added to the table shortly after getting the
|
||||
// pong packet.
|
||||
|
@ -18,8 +18,11 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SameNet reports whether two IP addresses have an equal prefix of the given bit length.
|
||||
func SameNet(bits uint, ip, other net.IP) bool {
|
||||
ip4, other4 := ip.To4(), other.To4()
|
||||
switch {
|
||||
case (ip4 == nil) != (other4 == nil):
|
||||
return false
|
||||
case ip4 != nil:
|
||||
return sameNet(bits, ip4, other4)
|
||||
default:
|
||||
return sameNet(bits, ip.To16(), other.To16())
|
||||
}
|
||||
}
|
||||
|
||||
func sameNet(bits uint, ip, other net.IP) bool {
|
||||
nb := int(bits / 8)
|
||||
mask := ^byte(0xFF >> (bits % 8))
|
||||
if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
|
||||
return false
|
||||
}
|
||||
return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
|
||||
}
|
||||
|
||||
// DistinctNetSet tracks IPs, ensuring that at most N of them
|
||||
// fall into the same network range.
|
||||
type DistinctNetSet struct {
|
||||
Subnet uint // number of common prefix bits
|
||||
Limit uint // maximum number of IPs in each subnet
|
||||
|
||||
members map[string]uint
|
||||
buf net.IP
|
||||
}
|
||||
|
||||
// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
|
||||
// number of existing IPs in the defined range exceeds the limit.
|
||||
func (s *DistinctNetSet) Add(ip net.IP) bool {
|
||||
key := s.key(ip)
|
||||
n := s.members[string(key)]
|
||||
if n < s.Limit {
|
||||
s.members[string(key)] = n + 1
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Remove removes an IP from the set.
|
||||
func (s *DistinctNetSet) Remove(ip net.IP) {
|
||||
key := s.key(ip)
|
||||
if n, ok := s.members[string(key)]; ok {
|
||||
if n == 1 {
|
||||
delete(s.members, string(key))
|
||||
} else {
|
||||
s.members[string(key)] = n - 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Contains whether the given IP is contained in the set.
|
||||
func (s DistinctNetSet) Contains(ip net.IP) bool {
|
||||
key := s.key(ip)
|
||||
_, ok := s.members[string(key)]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Len returns the number of tracked IPs.
|
||||
func (s DistinctNetSet) Len() int {
|
||||
n := uint(0)
|
||||
for _, i := range s.members {
|
||||
n += i
|
||||
}
|
||||
return int(n)
|
||||
}
|
||||
|
||||
// key encodes the map key for an address into a temporary buffer.
|
||||
//
|
||||
// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
|
||||
// The remainder of the key is the IP, truncated to the number of bits.
|
||||
func (s *DistinctNetSet) key(ip net.IP) net.IP {
|
||||
// Lazily initialize storage.
|
||||
if s.members == nil {
|
||||
s.members = make(map[string]uint)
|
||||
s.buf = make(net.IP, 17)
|
||||
}
|
||||
// Canonicalize ip and bits.
|
||||
typ := byte('6')
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
typ, ip = '4', ip4
|
||||
}
|
||||
bits := s.Subnet
|
||||
if bits > uint(len(ip)*8) {
|
||||
bits = uint(len(ip) * 8)
|
||||
}
|
||||
// Encode the prefix into s.buf.
|
||||
nb := int(bits / 8)
|
||||
mask := ^byte(0xFF >> (bits % 8))
|
||||
s.buf[0] = typ
|
||||
buf := append(s.buf[:1], ip[:nb]...)
|
||||
if nb < len(ip) && mask != 0 {
|
||||
buf = append(buf, ip[nb]&mask)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer
|
||||
func (s DistinctNetSet) String() string {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("{")
|
||||
keys := make([]string, 0, len(s.members))
|
||||
for k := range s.members {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for i, k := range keys {
|
||||
var ip net.IP
|
||||
if k[0] == '4' {
|
||||
ip = make(net.IP, 4)
|
||||
} else {
|
||||
ip = make(net.IP, 16)
|
||||
}
|
||||
copy(ip, k[1:])
|
||||
fmt.Fprintf(&buf, "%v×%d", ip, s.members[k])
|
||||
if i != len(keys)-1 {
|
||||
buf.WriteString(" ")
|
||||
}
|
||||
}
|
||||
buf.WriteString("}")
|
||||
return buf.String()
|
||||
}
|
||||
|
@ -17,9 +17,11 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) {
|
||||
CheckRelayIP(sender, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSameNet(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip, other string
|
||||
bits uint
|
||||
want bool
|
||||
}{
|
||||
{"0.0.0.0", "0.0.0.0", 32, true},
|
||||
{"0.0.0.0", "0.0.0.1", 0, true},
|
||||
{"0.0.0.0", "0.0.0.1", 31, true},
|
||||
{"0.0.0.0", "0.0.0.1", 32, false},
|
||||
{"0.33.0.1", "0.34.0.2", 8, true},
|
||||
{"0.33.0.1", "0.34.0.2", 13, true},
|
||||
{"0.33.0.1", "0.34.0.2", 15, false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want {
|
||||
t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleSameNet() {
|
||||
// This returns true because the IPs are in the same /24 network:
|
||||
fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3}))
|
||||
// This call returns false:
|
||||
fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3}))
|
||||
// Output:
|
||||
// true
|
||||
// false
|
||||
}
|
||||
|
||||
func TestDistinctNetSet(t *testing.T) {
|
||||
ops := []struct {
|
||||
add, remove string
|
||||
fails bool
|
||||
}{
|
||||
{add: "127.0.0.1"},
|
||||
{add: "127.0.0.2"},
|
||||
{add: "127.0.0.3", fails: true},
|
||||
{add: "127.32.0.1"},
|
||||
{add: "127.32.0.2"},
|
||||
{add: "127.32.0.3", fails: true},
|
||||
{add: "127.33.0.1", fails: true},
|
||||
{add: "127.34.0.1"},
|
||||
{add: "127.34.0.2"},
|
||||
{add: "127.34.0.3", fails: true},
|
||||
// Make room for an address, then add again.
|
||||
{remove: "127.0.0.1"},
|
||||
{add: "127.0.0.3"},
|
||||
{add: "127.0.0.3", fails: true},
|
||||
}
|
||||
|
||||
set := DistinctNetSet{Subnet: 15, Limit: 2}
|
||||
for _, op := range ops {
|
||||
var desc string
|
||||
if op.add != "" {
|
||||
desc = fmt.Sprintf("Add(%s)", op.add)
|
||||
if ok := set.Add(parseIP(op.add)); ok != !op.fails {
|
||||
t.Errorf("%s == %t, want %t", desc, ok, !op.fails)
|
||||
}
|
||||
} else {
|
||||
desc = fmt.Sprintf("Remove(%s)", op.remove)
|
||||
set.Remove(parseIP(op.remove))
|
||||
}
|
||||
t.Logf("%s: %v", desc, set)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistinctNetSetAddRemove(t *testing.T) {
|
||||
cfg := &quick.Config{}
|
||||
fn := func(ips []net.IP) bool {
|
||||
s := DistinctNetSet{Limit: 3, Subnet: 2}
|
||||
for _, ip := range ips {
|
||||
s.Add(ip)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
s.Remove(ip)
|
||||
}
|
||||
return s.Len() == 0
|
||||
}
|
||||
|
||||
if err := quick.Check(fn, cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -419,6 +419,9 @@ type PeerInfo struct {
|
||||
Network struct {
|
||||
LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection
|
||||
RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection
|
||||
Inbound bool `json:"inbound"`
|
||||
Trusted bool `json:"trusted"`
|
||||
Static bool `json:"static"`
|
||||
} `json:"network"`
|
||||
Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields
|
||||
}
|
||||
@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo {
|
||||
}
|
||||
info.Network.LocalAddress = p.LocalAddr().String()
|
||||
info.Network.RemoteAddress = p.RemoteAddr().String()
|
||||
info.Network.Inbound = p.rw.is(inboundConn)
|
||||
info.Network.Trusted = p.rw.is(trustedConn)
|
||||
info.Network.Static = p.rw.is(staticDialedConn)
|
||||
|
||||
// Gather all the running protocol infos
|
||||
for _, proto := range p.running {
|
||||
|
@ -40,11 +40,10 @@ const (
|
||||
refreshPeersInterval = 30 * time.Second
|
||||
staticPeerCheckInterval = 15 * time.Second
|
||||
|
||||
// Maximum number of concurrently handshaking inbound connections.
|
||||
maxAcceptConns = 50
|
||||
|
||||
// Maximum number of concurrently dialing outbound connections.
|
||||
maxActiveDialTasks = 16
|
||||
// Connectivity defaults.
|
||||
maxActiveDialTasks = 16
|
||||
defaultMaxPendingPeers = 50
|
||||
defaultDialRatio = 3
|
||||
|
||||
// Maximum time allowed for reading a complete message.
|
||||
// This is effectively the amount of time a connection can be idle.
|
||||
@ -70,6 +69,11 @@ type Config struct {
|
||||
// Zero defaults to preset values.
|
||||
MaxPendingPeers int `toml:",omitempty"`
|
||||
|
||||
// DialRatio controls the ratio of inbound to dialed connections.
|
||||
// Example: a DialRatio of 2 allows 1/2 of connections to be dialed.
|
||||
// Setting DialRatio to zero defaults it to 3.
|
||||
DialRatio int `toml:",omitempty"`
|
||||
|
||||
// NoDiscovery can be used to disable the peer discovery mechanism.
|
||||
// Disabling is useful for protocol debugging (manual topology).
|
||||
NoDiscovery bool
|
||||
@ -427,7 +431,6 @@ func (srv *Server) Start() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
realaddr = conn.LocalAddr().(*net.UDPAddr)
|
||||
if srv.NAT != nil {
|
||||
if !realaddr.IP.IsLoopback() {
|
||||
@ -447,11 +450,16 @@ func (srv *Server) Start() (err error) {
|
||||
|
||||
// node table
|
||||
if !srv.NoDiscovery {
|
||||
ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict)
|
||||
if err != nil {
|
||||
return err
|
||||
cfg := discover.Config{
|
||||
PrivateKey: srv.PrivateKey,
|
||||
AnnounceAddr: realaddr,
|
||||
NodeDBPath: srv.NodeDatabase,
|
||||
NetRestrict: srv.NetRestrict,
|
||||
Bootnodes: srv.BootstrapNodes,
|
||||
Unhandled: unhandled,
|
||||
}
|
||||
if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil {
|
||||
ntab, err := discover.ListenUDP(conn, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.ntab = ntab
|
||||
@ -476,10 +484,7 @@ func (srv *Server) Start() (err error) {
|
||||
srv.DiscV5 = ntab
|
||||
}
|
||||
|
||||
dynPeers := (srv.MaxPeers + 1) / 2
|
||||
if srv.NoDiscovery {
|
||||
dynPeers = 0
|
||||
}
|
||||
dynPeers := srv.maxDialedConns()
|
||||
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
|
||||
|
||||
// handshake
|
||||
@ -536,6 +541,7 @@ func (srv *Server) run(dialstate dialer) {
|
||||
defer srv.loopWG.Done()
|
||||
var (
|
||||
peers = make(map[discover.NodeID]*Peer)
|
||||
inboundCount = 0
|
||||
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
|
||||
taskdone = make(chan task, maxActiveDialTasks)
|
||||
runningTasks []task
|
||||
@ -621,14 +627,14 @@ running:
|
||||
}
|
||||
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
|
||||
select {
|
||||
case c.cont <- srv.encHandshakeChecks(peers, c):
|
||||
case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
|
||||
case <-srv.quit:
|
||||
break running
|
||||
}
|
||||
case c := <-srv.addpeer:
|
||||
// At this point the connection is past the protocol handshake.
|
||||
// Its capabilities are known and the remote identity is verified.
|
||||
err := srv.protoHandshakeChecks(peers, c)
|
||||
err := srv.protoHandshakeChecks(peers, inboundCount, c)
|
||||
if err == nil {
|
||||
// The handshakes are done and it passed all checks.
|
||||
p := newPeer(c, srv.Protocols)
|
||||
@ -639,8 +645,11 @@ running:
|
||||
}
|
||||
name := truncateName(c.name)
|
||||
srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
|
||||
peers[c.id] = p
|
||||
go srv.runPeer(p)
|
||||
peers[c.id] = p
|
||||
if p.Inbound() {
|
||||
inboundCount++
|
||||
}
|
||||
}
|
||||
// The dialer logic relies on the assumption that
|
||||
// dial tasks complete after the peer has been added or
|
||||
@ -655,6 +664,9 @@ running:
|
||||
d := common.PrettyDuration(mclock.Now() - pd.created)
|
||||
pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err)
|
||||
delete(peers, pd.ID())
|
||||
if pd.Inbound() {
|
||||
inboundCount--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -681,20 +693,22 @@ running:
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
|
||||
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
|
||||
// Drop connections with no matching protocols.
|
||||
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
|
||||
return DiscUselessPeer
|
||||
}
|
||||
// Repeat the encryption handshake checks because the
|
||||
// peer set might have changed between the handshakes.
|
||||
return srv.encHandshakeChecks(peers, c)
|
||||
return srv.encHandshakeChecks(peers, inboundCount, c)
|
||||
}
|
||||
|
||||
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
|
||||
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
|
||||
switch {
|
||||
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
|
||||
return DiscTooManyPeers
|
||||
case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
|
||||
return DiscTooManyPeers
|
||||
case peers[c.id] != nil:
|
||||
return DiscAlreadyConnected
|
||||
case c.id == srv.Self().ID:
|
||||
@ -704,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) maxInboundConns() int {
|
||||
return srv.MaxPeers - srv.maxDialedConns()
|
||||
}
|
||||
|
||||
func (srv *Server) maxDialedConns() int {
|
||||
if srv.NoDiscovery || srv.NoDial {
|
||||
return 0
|
||||
}
|
||||
r := srv.DialRatio
|
||||
if r == 0 {
|
||||
r = defaultDialRatio
|
||||
}
|
||||
return srv.MaxPeers / r
|
||||
}
|
||||
|
||||
type tempError interface {
|
||||
Temporary() bool
|
||||
}
|
||||
@ -714,10 +743,7 @@ func (srv *Server) listenLoop() {
|
||||
defer srv.loopWG.Done()
|
||||
srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab))
|
||||
|
||||
// This channel acts as a semaphore limiting
|
||||
// active inbound connections that are lingering pre-handshake.
|
||||
// If all slots are taken, no further connections are accepted.
|
||||
tokens := maxAcceptConns
|
||||
tokens := defaultMaxPendingPeers
|
||||
if srv.MaxPendingPeers > 0 {
|
||||
tokens = srv.MaxPendingPeers
|
||||
}
|
||||
@ -758,9 +784,6 @@ func (srv *Server) listenLoop() {
|
||||
|
||||
fd = newMeteredConn(fd, true)
|
||||
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
|
||||
|
||||
// Spawn the handler. It will give the slot back when the connection
|
||||
// has been established.
|
||||
go func() {
|
||||
srv.SetupConn(fd, inboundConn, nil)
|
||||
slots <- struct{}{}
|
||||
|
Loading…
Reference in New Issue
Block a user