p2p, p2p/discover: misc connectivity improvements (#16069)

* p2p: add DialRatio for configuration of inbound vs. dialed connections

* p2p: add connection flags to PeerInfo

* p2p/netutil: add SameNet, DistinctNetSet

* p2p/discover: improve revalidation and seeding

This changes node revalidation to be periodic instead of on-demand. This
should prevent issues where dead nodes get stuck in closer buckets
because no other node will ever come along to replace them.

Every 5 seconds (on average), the last node in a random bucket is
checked and moved to the front of the bucket if it is still responding.
If revalidation fails, the last node is replaced by an entry of the
'replacement list' containing recently-seen nodes.

Most close buckets are removed because it's very unlikely we'll ever
encounter a node that would fall into any of those buckets.

Table seeding is also improved: we now require a few minutes of table
membership before considering a node as a potential seed node. This
should make it less likely to store short-lived nodes as potential
seeds.

* p2p/discover: fix nits in UDP transport

We would skip sending neighbors replies if there were fewer than
maxNeighbors results and CheckRelayIP returned an error for the last
one. While here, also resolve a TODO about pong reply tokens.
This commit is contained in:
Felix Lange 2018-02-12 13:36:09 +01:00 committed by Péter Szilágyi
parent 1d39912a9b
commit 9123eceb0f
10 changed files with 806 additions and 282 deletions

View File

@ -122,7 +122,12 @@ func main() {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} else { } else {
if _, err := discover.ListenUDP(nodeKey, conn, realaddr, nil, "", restrictList); err != nil { cfg := discover.Config{
PrivateKey: nodeKey,
AnnounceAddr: realaddr,
NetRestrict: restrictList,
}
if _, err := discover.ListenUDP(conn, cfg); err != nil {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} }

View File

@ -29,6 +29,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
@ -51,9 +52,8 @@ type Node struct {
// with ID. // with ID.
sha common.Hash sha common.Hash
// whether this node is currently being pinged in order to replace // Time when the node was added to the table.
// it in a bucket addedAt time.Time
contested bool
} }
// NewNode creates a new node. It is mostly meant to be used for // NewNode creates a new node. It is mostly meant to be used for

View File

@ -23,10 +23,11 @@
package discover package discover
import ( import (
"crypto/rand" crand "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
mrand "math/rand"
"net" "net"
"sort" "sort"
"sync" "sync"
@ -35,29 +36,45 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
const ( const (
alpha = 3 // Kademlia concurrency factor alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size bucketSize = 16 // Kademlia bucket size
hashBits = len(common.Hash{}) * 8 maxReplacements = 10 // Size of per-bucket replacement list
nBuckets = hashBits + 1 // Number of buckets
maxBondingPingPongs = 16 // We keep buckets for the upper 1/15 of distances because
maxFindnodeFailures = 5 // it's very unlikely we'll ever encounter a node that's closer.
hashBits = len(common.Hash{}) * 8
nBuckets = hashBits / 15 // Number of buckets
bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket
autoRefreshInterval = 1 * time.Hour // IP address limits.
seedCount = 30 bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
seedMaxAge = 5 * 24 * time.Hour tableIPLimit, tableSubnet = 10, 24
maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
refreshInterval = 30 * time.Minute
revalidateInterval = 10 * time.Second
copyNodesInterval = 30 * time.Second
seedMinTableTime = 5 * time.Minute
seedCount = 30
seedMaxAge = 5 * 24 * time.Hour
) )
type Table struct { type Table struct {
mutex sync.Mutex // protects buckets, their content, and nursery mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes nursery []*Node // bootstrap nodes
db *nodeDB // database of known nodes rand *mrand.Rand // source of randomness, periodically reseeded
ips netutil.DistinctNetSet
db *nodeDB // database of known nodes
refreshReq chan chan struct{} refreshReq chan chan struct{}
initDone chan struct{}
closeReq chan struct{} closeReq chan struct{}
closed chan struct{} closed chan struct{}
@ -89,9 +106,13 @@ type transport interface {
// bucket contains nodes, ordered by their last activity. the entry // bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries. // that was most recently active is the first element in entries.
type bucket struct{ entries []*Node } type bucket struct {
entries []*Node // live entries, sorted by time of last contact
replacements []*Node // recently seen nodes to be used if revalidation fails
ips netutil.DistinctNetSet
}
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) { func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
// If no node database was given, use an in-memory one // If no node database was given, use an in-memory one
db, err := newNodeDB(nodeDBPath, Version, ourID) db, err := newNodeDB(nodeDBPath, Version, ourID)
if err != nil { if err != nil {
@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
bonding: make(map[NodeID]*bondproc), bonding: make(map[NodeID]*bondproc),
bondslots: make(chan struct{}, maxBondingPingPongs), bondslots: make(chan struct{}, maxBondingPingPongs),
refreshReq: make(chan chan struct{}), refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}),
closeReq: make(chan struct{}), closeReq: make(chan struct{}),
closed: make(chan struct{}), closed: make(chan struct{}),
rand: mrand.New(mrand.NewSource(0)),
ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
}
if err := tab.setFallbackNodes(bootnodes); err != nil {
return nil, err
} }
for i := 0; i < cap(tab.bondslots); i++ { for i := 0; i < cap(tab.bondslots); i++ {
tab.bondslots <- struct{}{} tab.bondslots <- struct{}{}
} }
for i := range tab.buckets { for i := range tab.buckets {
tab.buckets[i] = new(bucket) tab.buckets[i] = &bucket{
ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
}
} }
go tab.refreshLoop() tab.seedRand()
tab.loadSeedNodes(false)
// Start the background expiration goroutine after loading seeds so that the search for
// seed nodes also considers older nodes that would otherwise be removed by the
// expiration.
tab.db.ensureExpirer()
go tab.loop()
return tab, nil return tab, nil
} }
// seedRand reseeds the table's math/rand source with 8 bytes drawn from
// crypto/rand. The bytes are read before taking tab.mutex; only the Seed
// call itself runs under the lock, since the mutex guards tab.rand.
// The error result of the crypto/rand read is not checked.
func (tab *Table) seedRand() {
	var buf [8]byte
	crand.Read(buf[:])
	seed := int64(binary.BigEndian.Uint64(buf[:]))

	tab.mutex.Lock()
	defer tab.mutex.Unlock()
	tab.rand.Seed(seed)
}
// Self returns the local node. // Self returns the local node.
// The returned node should not be modified by the caller. // The returned node should not be modified by the caller.
func (tab *Table) Self() *Node { func (tab *Table) Self() *Node {
@ -127,9 +171,12 @@ func (tab *Table) Self() *Node {
// table. It will not write the same node more than once. The nodes in // table. It will not write the same node more than once. The nodes in
// the slice are copies and can be modified by the caller. // the slice are copies and can be modified by the caller.
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
if !tab.isInitDone() {
return 0
}
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
// TODO: tree-based buckets would help here
// Find all non-empty buckets and get a fresh slice of their entries. // Find all non-empty buckets and get a fresh slice of their entries.
var buckets [][]*Node var buckets [][]*Node
for _, b := range tab.buckets { for _, b := range tab.buckets {
@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return 0 return 0
} }
// Shuffle the buckets. // Shuffle the buckets.
for i := uint32(len(buckets)) - 1; i > 0; i-- { for i := len(buckets) - 1; i > 0; i-- {
j := randUint(i) j := tab.rand.Intn(len(buckets))
buckets[i], buckets[j] = buckets[j], buckets[i] buckets[i], buckets[j] = buckets[j], buckets[i]
} }
// Move head of each bucket into buf, removing buckets that become empty. // Move head of each bucket into buf, removing buckets that become empty.
@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return i + 1 return i + 1
} }
func randUint(max uint32) uint32 {
if max == 0 {
return 0
}
var b [4]byte
rand.Read(b[:])
return binary.BigEndian.Uint32(b[:]) % max
}
// Close terminates the network listener and flushes the node database. // Close terminates the network listener and flushes the node database.
func (tab *Table) Close() { func (tab *Table) Close() {
select { select {
@ -180,16 +218,15 @@ func (tab *Table) Close() {
} }
} }
// SetFallbackNodes sets the initial points of contact. These nodes // setFallbackNodes sets the initial points of contact. These nodes
// are used to connect to the network if the table is empty and there // are used to connect to the network if the table is empty and there
// are no known nodes in the database. // are no known nodes in the database.
func (tab *Table) SetFallbackNodes(nodes []*Node) error { func (tab *Table) setFallbackNodes(nodes []*Node) error {
for _, n := range nodes { for _, n := range nodes {
if err := n.validateComplete(); err != nil { if err := n.validateComplete(); err != nil {
return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err) return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
} }
} }
tab.mutex.Lock()
tab.nursery = make([]*Node, 0, len(nodes)) tab.nursery = make([]*Node, 0, len(nodes))
for _, n := range nodes { for _, n := range nodes {
cpy := *n cpy := *n
@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error {
cpy.sha = crypto.Keccak256Hash(n.ID[:]) cpy.sha = crypto.Keccak256Hash(n.ID[:])
tab.nursery = append(tab.nursery, &cpy) tab.nursery = append(tab.nursery, &cpy)
} }
tab.mutex.Unlock()
tab.refresh()
return nil return nil
} }
// isInitDone reports whether the table's initial seeding procedure has
// completed. It never blocks: it performs a non-blocking receive on
// tab.initDone, which succeeds only once that channel has been closed.
func (tab *Table) isInitDone() bool {
	finished := false
	select {
	case <-tab.initDone:
		finished = true
	default:
		// initDone is still open: seeding has not finished yet.
	}
	return finished
}
// Resolve searches for a specific node with the given ID. // Resolve searches for a specific node with the given ID.
// It returns nil if the node could not be found. // It returns nil if the node could not be found.
func (tab *Table) Resolve(targetID NodeID) *Node { func (tab *Table) Resolve(targetID NodeID) *Node {
@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} {
return done return done
} }
// refreshLoop schedules doRefresh runs and coordinates shutdown. // loop schedules refresh, revalidate runs and coordinates shutdown.
func (tab *Table) refreshLoop() { func (tab *Table) loop() {
var ( var (
timer = time.NewTicker(autoRefreshInterval) revalidate = time.NewTimer(tab.nextRevalidateTime())
waiting []chan struct{} // accumulates waiting callers while doRefresh runs refresh = time.NewTicker(refreshInterval)
done chan struct{} // where doRefresh reports completion copyNodes = time.NewTicker(copyNodesInterval)
revalidateDone = make(chan struct{})
refreshDone = make(chan struct{}) // where doRefresh reports completion
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
) )
defer refresh.Stop()
defer revalidate.Stop()
defer copyNodes.Stop()
// Start initial refresh.
go tab.doRefresh(refreshDone)
loop: loop:
for { for {
select { select {
case <-timer.C: case <-refresh.C:
if done == nil { tab.seedRand()
done = make(chan struct{}) if refreshDone == nil {
go tab.doRefresh(done) refreshDone = make(chan struct{})
go tab.doRefresh(refreshDone)
} }
case req := <-tab.refreshReq: case req := <-tab.refreshReq:
waiting = append(waiting, req) waiting = append(waiting, req)
if done == nil { if refreshDone == nil {
done = make(chan struct{}) refreshDone = make(chan struct{})
go tab.doRefresh(done) go tab.doRefresh(refreshDone)
} }
case <-done: case <-refreshDone:
for _, ch := range waiting { for _, ch := range waiting {
close(ch) close(ch)
} }
waiting = nil waiting, refreshDone = nil, nil
done = nil case <-revalidate.C:
go tab.doRevalidate(revalidateDone)
case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime())
case <-copyNodes.C:
go tab.copyBondedNodes()
case <-tab.closeReq: case <-tab.closeReq:
break loop break loop
} }
@ -349,8 +410,8 @@ loop:
if tab.net != nil { if tab.net != nil {
tab.net.close() tab.net.close()
} }
if done != nil { if refreshDone != nil {
<-done <-refreshDone
} }
for _, ch := range waiting { for _, ch := range waiting {
close(ch) close(ch)
@ -365,38 +426,109 @@ loop:
func (tab *Table) doRefresh(done chan struct{}) { func (tab *Table) doRefresh(done chan struct{}) {
defer close(done) defer close(done)
// Load nodes from the database and insert
// them. This should yield a few previously seen nodes that are
// (hopefully) still alive.
tab.loadSeedNodes(true)
// Run self lookup to discover new neighbor nodes.
tab.lookup(tab.self.ID, false)
// The Kademlia paper specifies that the bucket refresh should // The Kademlia paper specifies that the bucket refresh should
// perform a lookup in the least recently used bucket. We cannot // perform a lookup in the least recently used bucket. We cannot
// adhere to this because the findnode target is a 512bit value // adhere to this because the findnode target is a 512bit value
// (not hash-sized) and it is not easily possible to generate a // (not hash-sized) and it is not easily possible to generate a
// sha3 preimage that falls into a chosen bucket. // sha3 preimage that falls into a chosen bucket.
// We perform a lookup with a random target instead. // We perform a few lookups with a random target instead.
var target NodeID for i := 0; i < 3; i++ {
rand.Read(target[:]) var target NodeID
result := tab.lookup(target, false) crand.Read(target[:])
if len(result) > 0 { tab.lookup(target, false)
}
}
// loadSeedNodes collects up to seedCount seed nodes from the node database
// (no older than seedMaxAge), appends the fallback nodes held in tab.nursery
// and adds every resulting node to the table. When bond is true the
// candidates are first passed through tab.bondall and only the nodes it
// returns are added.
func (tab *Table) loadSeedNodes(bond bool) {
	candidates := append(tab.db.querySeeds(seedCount, seedMaxAge), tab.nursery...)
	if bond {
		candidates = tab.bondall(candidates)
	}
	for _, candidate := range candidates {
		// Pin a per-iteration copy for the lazily evaluated closure below.
		seed := candidate
		// The age is only computed if the log record is actually emitted.
		age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }}
		log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
		tab.add(seed)
	}
}
// doRevalidate checks that the last node in a random bucket is still live
// and replaces or deletes the node if it isn't.
func (tab *Table) doRevalidate(done chan<- struct{}) {
defer func() { done <- struct{}{} }()
last, bi := tab.nodeToRevalidate()
if last == nil {
// No non-empty bucket found.
return return
} }
// The table is empty. Load nodes from the database and insert // Ping the selected node and wait for a pong.
// them. This should yield a few previously seen nodes that are err := tab.ping(last.ID, last.addr())
// (hopefully) still alive.
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
seeds = tab.bondall(append(seeds, tab.nursery...))
if len(seeds) == 0 {
log.Debug("No discv4 seed nodes found")
}
for _, n := range seeds {
age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }}
log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age)
}
tab.mutex.Lock() tab.mutex.Lock()
tab.stuff(seeds) defer tab.mutex.Unlock()
tab.mutex.Unlock() b := tab.buckets[bi]
if err == nil {
// The node responded, move it to the front.
log.Debug("Revalidated node", "b", bi, "id", last.ID)
b.bump(last)
return
}
// No reply received, pick a replacement or delete the node if there aren't
// any replacements.
if r := tab.replace(b, last); r != nil {
log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
} else {
log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
}
}
// Finally, do a self lookup to fill up the buckets. // nodeToRevalidate returns the last node in a random, non-empty bucket.
tab.lookup(tab.self.ID, false) func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, bi = range tab.rand.Perm(len(tab.buckets)) {
b := tab.buckets[bi]
if len(b.entries) > 0 {
last := b.entries[len(b.entries)-1]
return last, bi
}
}
return nil, 0
}
// nextRevalidateTime returns the delay until the next revalidation run.
// The delay is drawn from tab.rand in the half-open range
// [0, revalidateInterval). tab.mutex is held while the random source is
// used because the mutex guards tab.rand.
func (tab *Table) nextRevalidateTime() time.Duration {
	tab.mutex.Lock()
	delay := tab.rand.Int63n(int64(revalidateInterval))
	tab.mutex.Unlock()

	return time.Duration(delay)
}
// copyBondedNodes adds nodes from the table to the database if they have been in the table
// longer than seedMinTableTime.
func (tab *Table) copyBondedNodes() {
tab.mutex.Lock()
defer tab.mutex.Unlock()
now := time.Now()
for _, b := range tab.buckets {
for _, n := range b.entries {
if now.Sub(n.addedAt) >= seedMinTableTime {
tab.db.updateNode(n)
}
}
}
} }
// closest returns the n nodes in the table that are closest to the // closest returns the n nodes in the table that are closest to the
@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
if id == tab.self.ID { if id == tab.self.ID {
return nil, errors.New("is self") return nil, errors.New("is self")
} }
// Retrieve a previously known node and any recent findnode failures if pinged && !tab.isInitDone() {
node, fails := tab.db.node(id), 0 return nil, errors.New("still initializing")
if node != nil {
fails = tab.db.findFails(id)
} }
// If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch // Start bonding if we haven't seen this node for a while or if it failed findnode too often.
var result error node, fails := tab.db.node(id), tab.db.findFails(id)
age := time.Since(tab.db.lastPong(id)) age := time.Since(tab.db.lastPong(id))
if node == nil || fails > 0 || age > nodeDBNodeExpiration { var result error
if fails > 0 || age > nodeDBNodeExpiration {
log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
tab.bondmu.Lock() tab.bondmu.Lock()
@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
node = w.n node = w.n
} }
} }
// Add the node to the table even if the bonding ping/pong
// fails. It will be relaced quickly if it continues to be
// unresponsive.
if node != nil { if node != nil {
// Add the node to the table even if the bonding ping/pong
// fails. It will be replaced quickly if it continues to be
// unresponsive.
tab.add(node) tab.add(node)
tab.db.updateFindFails(id, 0) tab.db.updateFindFails(id, 0)
} }
@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
} }
// Bonding succeeded, update the node database. // Bonding succeeded, update the node database.
w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort) w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
tab.db.updateNode(w.n)
close(w.done) close(w.done)
} }
@ -534,16 +664,18 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
return err return err
} }
tab.db.updateLastPong(id, time.Now()) tab.db.updateLastPong(id, time.Now())
// Start the background expiration goroutine after the first
// successful communication. Subsequent calls have no effect if it
// is already running. We do this here instead of somewhere else
// so that the search for seed nodes also considers older nodes
// that would otherwise be removed by the expiration.
tab.db.ensureExpirer()
return nil return nil
} }
// bucket returns the bucket responsible for the given node ID hash.
//
// The index is derived from the log distance between the local node's hash
// and sha. Every distance up to and including bucketMinDistance is folded
// into bucket 0; larger distances map to index d-bucketMinDistance-1.
func (tab *Table) bucket(sha common.Hash) *bucket {
	idx := 0
	if d := logdist(tab.self.sha, sha); d > bucketMinDistance {
		idx = d - bucketMinDistance - 1
	}
	return tab.buckets[idx]
}
// add attempts to add the given node its corresponding bucket. If the // add attempts to add the given node its corresponding bucket. If the
// bucket has space available, adding the node succeeds immediately. // bucket has space available, adding the node succeeds immediately.
// Otherwise, the node is added if the least recently active node in // Otherwise, the node is added if the least recently active node in
@ -551,57 +683,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
// //
// The caller must not hold tab.mutex. // The caller must not hold tab.mutex.
func (tab *Table) add(new *Node) { func (tab *Table) add(new *Node) {
b := tab.buckets[logdist(tab.self.sha, new.sha)]
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
if b.bump(new) {
return b := tab.bucket(new.sha)
} if !tab.bumpOrAdd(b, new) {
var oldest *Node // Node is not in table. Add it to the replacement list.
if len(b.entries) == bucketSize { tab.addReplacement(b, new)
oldest = b.entries[bucketSize-1]
if oldest.contested {
// The node is already being replaced, don't attempt
// to replace it.
return
}
oldest.contested = true
// Let go of the mutex so other goroutines can access
// the table while we ping the least recently active node.
tab.mutex.Unlock()
err := tab.ping(oldest.ID, oldest.addr())
tab.mutex.Lock()
oldest.contested = false
if err == nil {
// The node responded, don't replace it.
return
}
}
added := b.replace(new, oldest)
if added && tab.nodeAddedHook != nil {
tab.nodeAddedHook(new)
} }
} }
// stuff adds nodes the table to the end of their corresponding bucket // stuff adds nodes the table to the end of their corresponding bucket
// if the bucket is not full. The caller must hold tab.mutex. // if the bucket is not full. The caller must not hold tab.mutex.
func (tab *Table) stuff(nodes []*Node) { func (tab *Table) stuff(nodes []*Node) {
outer: tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, n := range nodes { for _, n := range nodes {
if n.ID == tab.self.ID { if n.ID == tab.self.ID {
continue // don't add self continue // don't add self
} }
bucket := tab.buckets[logdist(tab.self.sha, n.sha)] b := tab.bucket(n.sha)
for i := range bucket.entries { if len(b.entries) < bucketSize {
if bucket.entries[i].ID == n.ID { tab.bumpOrAdd(b, n)
continue outer // already in bucket
}
}
if len(bucket.entries) < bucketSize {
bucket.entries = append(bucket.entries, n)
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(n)
}
} }
} }
} }
@ -611,36 +715,72 @@ outer:
func (tab *Table) delete(node *Node) { func (tab *Table) delete(node *Node) {
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
bucket := tab.buckets[logdist(tab.self.sha, node.sha)]
for i := range bucket.entries { tab.deleteInBucket(tab.bucket(node.sha), node)
if bucket.entries[i].ID == node.ID {
bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...)
return
}
}
} }
func (b *bucket) replace(n *Node, last *Node) bool { func (tab *Table) addIP(b *bucket, ip net.IP) bool {
// Don't add if b already contains n. if netutil.IsLAN(ip) {
for i := range b.entries { return true
if b.entries[i].ID == n.ID {
return false
}
} }
// Replace last if it is still the last entry or just add n if b if !tab.ips.Add(ip) {
// isn't full. If is no longer the last entry, it has either been log.Debug("IP exceeds table limit", "ip", ip)
// replaced with someone else or became active.
if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) {
return false return false
} }
if len(b.entries) < bucketSize { if !b.ips.Add(ip) {
b.entries = append(b.entries, nil) log.Debug("IP exceeds bucket limit", "ip", ip)
tab.ips.Remove(ip)
return false
} }
copy(b.entries[1:], b.entries)
b.entries[0] = n
return true return true
} }
// removeIP drops ip from both the table-wide and the per-bucket IP sets.
// Addresses for which netutil.IsLAN reports true are skipped and leave
// both sets untouched.
func (tab *Table) removeIP(b *bucket, ip net.IP) {
	if !netutil.IsLAN(ip) {
		tab.ips.Remove(ip)
		b.ips.Remove(ip)
	}
}
// addReplacement puts n at the front of b's replacement list unless a node
// with the same ID is already listed or addIP rejects n's address. The list
// holds at most maxReplacements nodes; when pushing n evicts the oldest
// entry, that entry's IP is released again via removeIP.
func (tab *Table) addReplacement(b *bucket, n *Node) {
	for i := range b.replacements {
		if b.replacements[i].ID == n.ID {
			return // already in list
		}
	}
	if !tab.addIP(b, n.IP) {
		return
	}
	updated, evicted := pushNode(b.replacements, n, maxReplacements)
	b.replacements = updated
	if evicted != nil {
		tab.removeIP(b, evicted.IP)
	}
}
// replace swaps 'last' for a randomly chosen node from b's replacement list,
// provided 'last' is still the final entry of the bucket.
//
// It returns the node that took the place of 'last', or nil when nothing was
// swapped in. That happens in three cases:
//   - the bucket has no entries at all,
//   - 'last' is no longer the final entry (it has either been replaced by
//     someone else or became active and moved), or
//   - the replacement list is empty, in which case 'last' is deleted from
//     the bucket.
//
// The function reads tab.rand and mutates bucket contents, both of which are
// guarded by tab.mutex, so it is meant to be called with that mutex held.
func (tab *Table) replace(b *bucket, last *Node) *Node {
	// Check for emptiness before indexing: the bucket may have been emptied
	// between the moment 'last' was selected and this call, and indexing
	// entries[len-1] on an empty slice would panic.
	if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID {
		// Entry has moved, don't replace it.
		return nil
	}
	// Still the last entry.
	if len(b.replacements) == 0 {
		tab.deleteInBucket(b, last)
		return nil
	}
	r := b.replacements[tab.rand.Intn(len(b.replacements))]
	b.replacements = deleteNode(b.replacements, r)
	b.entries[len(b.entries)-1] = r
	// The replacement's IP was already counted when it joined the
	// replacement list; only the departing node's IP must be released.
	tab.removeIP(b, last.IP)
	return r
}
// bump moves the given node to the front of the bucket entry list
// if it is contained in that list.
func (b *bucket) bump(n *Node) bool { func (b *bucket) bump(n *Node) bool {
for i := range b.entries { for i := range b.entries {
if b.entries[i].ID == n.ID { if b.entries[i].ID == n.ID {
@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool {
return false return false
} }
// bumpOrAdd makes sure n sits at the front of b's entry list. A node that is
// already an entry is simply bumped. Otherwise n is inserted, but only if the
// bucket holds fewer than bucketSize entries and addIP accepts n's address.
// On insertion n is removed from the replacement list, its addedAt timestamp
// is set to the current time and tab.nodeAddedHook, if set, is invoked.
//
// The result reports whether n is in the bucket when the call returns.
func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
	switch {
	case b.bump(n):
		return true
	case len(b.entries) >= bucketSize:
		return false
	case !tab.addIP(b, n.IP):
		return false
	}
	b.entries, _ = pushNode(b.entries, n, bucketSize)
	b.replacements = deleteNode(b.replacements, n)
	n.addedAt = time.Now()
	if hook := tab.nodeAddedHook; hook != nil {
		hook(n)
	}
	return true
}
// deleteInBucket removes n from b's entry list and releases n's IP from the
// table-wide and per-bucket IP sets. The IP is released regardless of
// whether n was actually found among the entries.
func (tab *Table) deleteInBucket(b *bucket, n *Node) {
	tab.removeIP(b, n.IP)
	b.entries = deleteNode(b.entries, n)
}
// pushNode inserts n at the front of list, keeping at most max items.
// While the list is shorter than max it grows by one slot; otherwise the
// final element is pushed out. The second result is that displaced element,
// or nil when the list grew instead.
func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
	if len(list) < max {
		list = append(list, nil)
	}
	tail := len(list) - 1
	evicted := list[tail]
	// Shift everything except the tail one position towards the back.
	copy(list[1:], list[:tail])
	list[0] = n
	return list, evicted
}
// deleteNode removes the first element whose ID equals n.ID from list and
// returns the shortened slice. The removal happens in place, reusing the
// backing array of list. If no element matches, list is returned unchanged.
func deleteNode(list []*Node, n *Node) []*Node {
	for idx, entry := range list {
		if entry.ID != n.ID {
			continue
		}
		return append(list[:idx], list[idx+1:]...)
	}
	return list
}
// nodesByDistance is a list of nodes, ordered by // nodesByDistance is a list of nodes, ordered by
// distance to target. // distance to target.
type nodesByDistance struct { type nodesByDistance struct {

View File

@ -20,6 +20,7 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"math/rand" "math/rand"
"sync"
"net" "net"
"reflect" "reflect"
@ -32,60 +33,65 @@ import (
) )
func TestTable_pingReplace(t *testing.T) { func TestTable_pingReplace(t *testing.T) {
doit := func(newNodeIsResponding, lastInBucketIsResponding bool) { run := func(newNodeResponding, lastInBucketResponding bool) {
transport := newPingRecorder() name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "") t.Run(name, func(t *testing.T) {
defer tab.Close() t.Parallel()
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) testPingReplace(t, newNodeResponding, lastInBucketResponding)
})
// fill up the sender's bucket.
last := fillBucket(tab, 253)
// this call to bond should replace the last node
// in its bucket if the node is not responding.
transport.responding[last.ID] = lastInBucketIsResponding
transport.responding[pingSender.ID] = newNodeIsResponding
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
// first ping goes to sender (bonding pingback)
if !transport.pinged[pingSender.ID] {
t.Error("table did not ping back sender")
}
if newNodeIsResponding {
// second ping goes to oldest node in bucket
// to see whether it is still alive.
if !transport.pinged[last.ID] {
t.Error("table did not ping last node in bucket")
}
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
if l := len(tab.buckets[253].entries); l != bucketSize {
t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize)
}
if lastInBucketIsResponding || !newNodeIsResponding {
if !contains(tab.buckets[253].entries, last.ID) {
t.Error("last entry was removed")
}
if contains(tab.buckets[253].entries, pingSender.ID) {
t.Error("new entry was added")
}
} else {
if contains(tab.buckets[253].entries, last.ID) {
t.Error("last entry was not removed")
}
if !contains(tab.buckets[253].entries, pingSender.ID) {
t.Error("new entry was not added")
}
}
} }
doit(true, true) run(true, true)
doit(false, true) run(false, true)
doit(true, false) run(true, false)
doit(false, false) run(false, false)
}
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
// Wait for init so bond is accepted.
<-tab.initDone
// fill up the sender's bucket.
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
last := fillBucket(tab, pingSender)
// this call to bond should replace the last node
// in its bucket if the node is not responding.
transport.dead[last.ID] = !lastInBucketIsResponding
transport.dead[pingSender.ID] = !newNodeIsResponding
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
tab.doRevalidate(make(chan struct{}, 1))
// first ping goes to sender (bonding pingback)
if !transport.pinged[pingSender.ID] {
t.Error("table did not ping back sender")
}
if !transport.pinged[last.ID] {
// second ping goes to oldest node in bucket
// to see whether it is still alive.
t.Error("table did not ping last node in bucket")
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
wantSize := bucketSize
if !lastInBucketIsResponding && !newNodeIsResponding {
wantSize--
}
if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
}
if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
}
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
}
} }
func TestBucket_bumpNoDuplicates(t *testing.T) { func TestBucket_bumpNoDuplicates(t *testing.T) {
@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
} }
} }
// This checks that the table-wide IP limit is applied correctly.
func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self.sha, i)
n.IP = net.IP{172, 0, 1, byte(i)}
tab.add(n)
}
if tab.len() > tableIPLimit {
t.Errorf("too many nodes in table")
}
}
// This checks that the table-wide IP limit is applied correctly.
func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self.sha, d)
n.IP = net.IP{172, 0, 1, byte(i)}
tab.add(n)
}
if tab.len() > bucketIPLimit {
t.Errorf("too many nodes in table")
}
}
// fillBucket inserts nodes into the given bucket until // fillBucket inserts nodes into the given bucket until
// it is full. The node's IDs dont correspond to their // it is full. The node's IDs dont correspond to their
// hashes. // hashes.
func fillBucket(tab *Table, ld int) (last *Node) { func fillBucket(tab *Table, n *Node) (last *Node) {
b := tab.buckets[ld] ld := logdist(tab.self.sha, n.sha)
b := tab.bucket(n.sha)
for len(b.entries) < bucketSize { for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld)) b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
} }
@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) { func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node) n = new(Node)
n.sha = hashAtDistance(base, ld) n.sha = hashAtDistance(base, ld)
n.IP = net.IP{10, 0, 2, byte(ld)} n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n return n
} }
type pingRecorder struct{ responding, pinged map[NodeID]bool } type pingRecorder struct {
mu sync.Mutex
dead, pinged map[NodeID]bool
}
func newPingRecorder() *pingRecorder { func newPingRecorder() *pingRecorder {
return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)} return &pingRecorder{
dead: make(map[NodeID]bool),
pinged: make(map[NodeID]bool),
}
} }
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
panic("findnode called on pingRecorder") return nil, nil
} }
func (t *pingRecorder) close() {} func (t *pingRecorder) close() {}
func (t *pingRecorder) waitping(from NodeID) error { func (t *pingRecorder) waitping(from NodeID) error {
return nil // remote always pings return nil // remote always pings
} }
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
t.mu.Lock()
defer t.mu.Unlock()
t.pinged[toid] = true t.pinged[toid] = true
if t.responding[toid] { if t.dead[toid] {
return nil
} else {
return errTimeout return errTimeout
} else {
return nil
} }
} }
@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) {
test := func(test *closeTest) bool { test := func(test *closeTest) bool {
// for any node table, Target and N // for any node table, Target and N
tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "") transport := newPingRecorder()
tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
defer tab.Close() defer tab.Close()
tab.stuff(test.All) tab.stuff(test.All)
@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
}, },
} }
test := func(buf []*Node) bool { test := func(buf []*Node) bool {
tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "") transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close() defer tab.Close()
<-tab.initDone
for i := 0; i < len(buf); i++ { for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets)) ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)}) tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
func TestTable_Lookup(t *testing.T) { func TestTable_Lookup(t *testing.T) {
self := nodeAtDistance(common.Hash{}, 0) self := nodeAtDistance(common.Hash{}, 0)
tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "") tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
defer tab.Close() defer tab.Close()
// lookup on empty table returns no nodes // lookup on empty table returns no nodes

View File

@ -216,9 +216,22 @@ type ReadPacket struct {
Addr *net.UDPAddr Addr *net.UDPAddr
} }
// Config holds Table-related settings. It is passed to ListenUDP, which uses
// it to set up both the UDP transport and the node table.
type Config struct {
// These settings are required and configure the UDP listener:
// PrivateKey signs outgoing packets; its public key is the local node ID.
PrivateKey *ecdsa.PrivateKey
// These settings are optional:
// When AnnounceAddr is nil, the local address of the connection is announced.
AnnounceAddr *net.UDPAddr // local address announced in the DHT
NodeDBPath string // if set, the node database is stored at this filesystem location
NetRestrict *netutil.Netlist // network whitelist
Bootnodes []*Node // list of bootstrap nodes
// A non-nil Unhandled channel is closed when the transport's read loop exits.
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
}
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { func ListenUDP(c conn, cfg Config) (*Table, error) {
tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict) tab, _, err := newUDP(c, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl
return tab, nil return tab, nil
} }
func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { func newUDP(c conn, cfg Config) (*Table, *udp, error) {
udp := &udp{ udp := &udp{
conn: c, conn: c,
priv: priv, priv: cfg.PrivateKey,
netrestrict: netrestrict, netrestrict: cfg.NetRestrict,
closing: make(chan struct{}), closing: make(chan struct{}),
gotreply: make(chan reply), gotreply: make(chan reply),
addpending: make(chan *pending), addpending: make(chan *pending),
} }
realaddr := c.LocalAddr().(*net.UDPAddr)
if cfg.AnnounceAddr != nil {
realaddr = cfg.AnnounceAddr
}
// TODO: separate TCP port // TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
udp.Table = tab udp.Table = tab
go udp.loop() go udp.loop()
go udp.readLoop(unhandled) go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil return udp.Table, udp, nil
} }
@ -256,14 +273,20 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply. // ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT req := &ping{
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
t.send(toaddr, pingPacket, &ping{
Version: Version, Version: Version,
From: t.ourEndpoint, From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()), Expiration: uint64(time.Now().Add(expiration).Unix()),
}
packet, hash, err := encodePacket(t.priv, pingPacket, req)
if err != nil {
return err
}
errc := t.pending(toid, pongPacket, func(p interface{}) bool {
return bytes.Equal(p.(*pong).ReplyTok, hash)
}) })
t.write(toaddr, req.name(), packet)
return <-errc return <-errc
} }
@ -447,40 +470,45 @@ func init() {
} }
} }
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error { func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
packet, err := encodePacket(t.priv, ptype, req) packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil { if err != nil {
return err return hash, err
} }
_, err = t.conn.WriteToUDP(packet, toaddr) return hash, t.write(toaddr, req.name(), packet)
log.Trace(">> "+req.name(), "addr", toaddr, "err", err) }
func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
log.Trace(">> "+what, "addr", toaddr, "err", err)
return err return err
} }
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
b := new(bytes.Buffer) b := new(bytes.Buffer)
b.Write(headSpace) b.Write(headSpace)
b.WriteByte(ptype) b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil { if err := rlp.Encode(b, req); err != nil {
log.Error("Can't encode discv4 packet", "err", err) log.Error("Can't encode discv4 packet", "err", err)
return nil, err return nil, nil, err
} }
packet := b.Bytes() packet = b.Bytes()
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
if err != nil { if err != nil {
log.Error("Can't sign discv4 packet", "err", err) log.Error("Can't sign discv4 packet", "err", err)
return nil, err return nil, nil, err
} }
copy(packet[macSize:], sig) copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the // add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in // packet in any way. Our public key will be part of this hash in
// The future. // The future.
copy(packet, crypto.Keccak256(packet[macSize:])) hash = crypto.Keccak256(packet[macSize:])
return packet, nil copy(packet, hash)
return packet, hash, nil
} }
// readLoop runs in its own goroutine. it handles incoming UDP packets. // readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop(unhandled chan ReadPacket) { func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close() defer t.conn.Close()
if unhandled != nil { if unhandled != nil {
defer close(unhandled) defer close(unhandled)
@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
t.mutex.Unlock() t.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet // Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit. // to stay below the 1280 byte limit.
for i, n := range closest { for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP) != nil { if netutil.CheckRelayIP(from.IP, n.IP) == nil {
continue p.Nodes = append(p.Nodes, nodeToRPC(n))
} }
p.Nodes = append(p.Nodes, nodeToRPC(n)) if len(p.Nodes) == maxNeighbors {
if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
t.send(from, neighborsPacket, &p) t.send(from, neighborsPacket, &p)
p.Nodes = p.Nodes[:0] p.Nodes = p.Nodes[:0]
sent = true
} }
} }
if len(p.Nodes) > 0 || !sent {
t.send(from, neighborsPacket, &p)
}
return nil return nil
} }

View File

@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(), remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
} }
realaddr := test.pipe.LocalAddr().(*net.UDPAddr) test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil) // Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone
return test return test
} }
// handles a packet as if it had been sent to the transport. // handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
enc, err := encodePacket(test.remotekey, ptype, data) enc, _, err := encodePacket(test.remotekey, ptype, data)
if err != nil { if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err) return test.errorf("packet (%d) encode error: %v", ptype, err)
} }
@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport. // waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type. // validate should have type func(*udpTest, X) error, where X is a packet type.
func (test *udpTest) waitPacketOut(validate interface{}) error { func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
dgram := test.pipe.waitPacketOut() dgram := test.pipe.waitPacketOut()
p, _, _, err := decodePacket(dgram) p, _, hash, err := decodePacket(dgram)
if err != nil { if err != nil {
return test.errorf("sent packet decode error: %v", err) return hash, test.errorf("sent packet decode error: %v", err)
} }
fn := reflect.ValueOf(validate) fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0) exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype { if reflect.TypeOf(p) != exptype {
return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
} }
fn.Call([]reflect.Value{reflect.ValueOf(p)}) fn.Call([]reflect.Value{reflect.ValueOf(p)})
return nil return hash, nil
} }
func (test *udpTest) errorf(format string, args ...interface{}) error { func (test *udpTest) errorf(format string, args ...interface{}) error {
@ -351,7 +352,7 @@ func TestUDP_successfulPing(t *testing.T) {
}) })
// remote is unknown, the table pings back. // remote is unknown, the table pings back.
test.waitPacketOut(func(p *ping) error { hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint) t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
} }
@ -365,7 +366,7 @@ func TestUDP_successfulPing(t *testing.T) {
} }
return nil return nil
}) })
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp}) test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp})
// the node should be added to the table shortly after getting the // the node should be added to the table shortly after getting the
// pong packet. // pong packet.

View File

@ -18,8 +18,11 @@
package netutil package netutil
import ( import (
"bytes"
"errors" "errors"
"fmt"
"net" "net"
"sort"
"strings" "strings"
) )
@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error {
} }
return nil return nil
} }
// SameNet reports whether two IP addresses have an equal prefix of the given bit length.
//
// The addresses are compared in their shortest form: two IPv4 addresses are
// compared as 4-byte values, anything else as 16-byte values. An IPv4 address
// is never in the same network as a non-IPv4 address.
func SameNet(bits uint, ip, other net.IP) bool {
	a4, b4 := ip.To4(), other.To4()
	if (a4 != nil) != (b4 != nil) {
		// Exactly one of the two is IPv4.
		return false
	}
	if a4 != nil {
		return sameNet(bits, a4, b4)
	}
	return sameNet(bits, ip.To16(), other.To16())
}
// sameNet reports whether the first 'bits' bits of ip and other are equal.
// Both arguments are expected to be in the same form (4 or 16 bytes).
//
// Slices of different length never match. This keeps a nil or malformed
// address on either side from being indexed out of range, and from being
// treated as equal to a valid address for short prefixes. A prefix longer
// than the addresses themselves never matches either.
func sameNet(bits uint, ip, other net.IP) bool {
	if len(ip) != len(other) {
		return false
	}
	nb := int(bits / 8)
	if nb > len(ip) {
		return false
	}
	// Compare the whole bytes covered by the prefix.
	if !bytes.Equal(ip[:nb], other[:nb]) {
		return false
	}
	// Compare the leading bits of the byte that is only partially covered.
	// The mask is zero when the prefix ends on a byte boundary.
	mask := ^byte(0xFF >> (bits % 8))
	if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
		return false
	}
	return true
}
// DistinctNetSet tracks IPs, ensuring that at most N of them
// fall into the same network range.
//
// The zero value with Subnet and Limit filled in is ready to use; internal
// storage is allocated on first use.
type DistinctNetSet struct {
	Subnet uint // number of common prefix bits
	Limit  uint // maximum number of IPs in each subnet

	members map[string]uint // number of tracked IPs per encoded subnet prefix
	buf     net.IP          // scratch space reused by key
}

// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range has already reached the limit.
func (s *DistinctNetSet) Add(ip net.IP) bool {
	k := string(s.key(ip))
	count := s.members[k]
	if count >= s.Limit {
		return false
	}
	s.members[k] = count + 1
	return true
}

// Remove removes an IP from the set. Removing an address whose range is not
// tracked is a no-op.
func (s *DistinctNetSet) Remove(ip net.IP) {
	k := string(s.key(ip))
	count, tracked := s.members[k]
	switch {
	case !tracked:
		return
	case count == 1:
		// Last address of this range, drop the entry entirely.
		delete(s.members, k)
	default:
		s.members[k] = count - 1
	}
}

// Contains reports whether the set tracks at least one IP in the network
// range of the given IP.
func (s DistinctNetSet) Contains(ip net.IP) bool {
	_, tracked := s.members[string(s.key(ip))]
	return tracked
}

// Len returns the number of tracked IPs.
func (s DistinctNetSet) Len() int {
	var total uint
	for _, count := range s.members {
		total += count
	}
	return int(total)
}

// key encodes the map key for an address into a temporary buffer.
//
// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
// The remainder of the key is the IP, truncated to the number of bits.
// The returned slice may alias the internal buffer and is only valid until
// the next call to key.
func (s *DistinctNetSet) key(ip net.IP) net.IP {
	// Lazily initialize storage.
	if s.members == nil {
		s.members = make(map[string]uint)
		s.buf = make(net.IP, 17)
	}
	// Canonicalize the address: IPv4 is always handled in 4-byte form.
	family := byte('6')
	if v4 := ip.To4(); v4 != nil {
		family = '4'
		ip = v4
	}
	// The prefix can't be longer than the address itself.
	prefix := s.Subnet
	if limit := uint(len(ip) * 8); prefix > limit {
		prefix = limit
	}
	// Encode the prefix into s.buf: whole bytes first, then the masked
	// remainder if the prefix doesn't end on a byte boundary.
	whole, rem := int(prefix/8), prefix%8
	s.buf[0] = family
	out := append(s.buf[:1], ip[:whole]...)
	if rem != 0 && whole < len(ip) {
		mask := ^byte(0xFF >> rem)
		out = append(out, ip[whole]&mask)
	}
	return out
}

// String implements fmt.Stringer. Ranges are listed in sorted key order,
// each followed by the number of tracked IPs in it.
func (s DistinctNetSet) String() string {
	keys := make([]string, 0, len(s.members))
	for k := range s.members {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var out bytes.Buffer
	out.WriteString("{")
	for i, k := range keys {
		if i > 0 {
			out.WriteString(" ")
		}
		// Expand the truncated prefix back to a full-length address.
		size := 16
		if k[0] == '4' {
			size = 4
		}
		ip := make(net.IP, size)
		copy(ip, k[1:])
		fmt.Fprintf(&out, "%v×%d", ip, s.members[k])
	}
	out.WriteString("}")
	return out.String()
}

View File

@ -17,9 +17,11 @@
package netutil package netutil
import ( import (
"fmt"
"net" "net"
"reflect" "reflect"
"testing" "testing"
"testing/quick"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
) )
@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) {
CheckRelayIP(sender, addr) CheckRelayIP(sender, addr)
} }
} }
// TestSameNet runs SameNet over a table of address pairs and prefix lengths,
// including prefixes that end in the middle of a byte.
func TestSameNet(t *testing.T) {
	cases := []struct {
		ip, other string
		bits      uint
		want      bool
	}{
		{"0.0.0.0", "0.0.0.0", 32, true},
		{"0.0.0.0", "0.0.0.1", 0, true},
		{"0.0.0.0", "0.0.0.1", 31, true},
		{"0.0.0.0", "0.0.0.1", 32, false},
		{"0.33.0.1", "0.34.0.2", 8, true},
		{"0.33.0.1", "0.34.0.2", 13, true},
		{"0.33.0.1", "0.34.0.2", 15, false},
	}
	for _, tc := range cases {
		got := SameNet(tc.bits, parseIP(tc.ip), parseIP(tc.other))
		if got != tc.want {
			t.Errorf("SameNet(%d, %s, %s) == %t, want %t", tc.bits, tc.ip, tc.other, got, tc.want)
		}
	}
}
func ExampleSameNet() {
	// Both addresses start with 127.0.0, so they share a /24 network:
	shared := SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3})
	fmt.Println(shared)
	// Here the second byte differs, so the /24 prefixes are not equal:
	distinct := SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3})
	fmt.Println(distinct)
	// Output:
	// true
	// false
}
// TestDistinctNetSet replays a fixed sequence of Add and Remove calls against
// a set limited to two addresses per /15 and checks the result of every Add.
func TestDistinctNetSet(t *testing.T) {
	steps := []struct {
		add, remove string
		fails       bool
	}{
		{add: "127.0.0.1"},
		{add: "127.0.0.2"},
		{add: "127.0.0.3", fails: true},
		{add: "127.32.0.1"},
		{add: "127.32.0.2"},
		{add: "127.32.0.3", fails: true},
		{add: "127.33.0.1", fails: true},
		{add: "127.34.0.1"},
		{add: "127.34.0.2"},
		{add: "127.34.0.3", fails: true},
		// Make room for an address, then add again.
		{remove: "127.0.0.1"},
		{add: "127.0.0.3"},
		{add: "127.0.0.3", fails: true},
	}

	set := DistinctNetSet{Subnet: 15, Limit: 2}
	for _, step := range steps {
		if step.add == "" {
			desc := fmt.Sprintf("Remove(%s)", step.remove)
			set.Remove(parseIP(step.remove))
			t.Logf("%s: %v", desc, set)
			continue
		}
		desc := fmt.Sprintf("Add(%s)", step.add)
		wantOK := !step.fails
		if ok := set.Add(parseIP(step.add)); ok != wantOK {
			t.Errorf("%s == %t, want %t", desc, ok, wantOK)
		}
		t.Logf("%s: %v", desc, set)
	}
}
// TestDistinctNetSetAddRemove is a property test: after adding a generated
// list of addresses and then removing each of them again, the set must
// report a length of zero.
func TestDistinctNetSetAddRemove(t *testing.T) {
	emptyAfterRoundTrip := func(ips []net.IP) bool {
		set := DistinctNetSet{Limit: 3, Subnet: 2}
		for i := range ips {
			set.Add(ips[i])
		}
		for i := range ips {
			set.Remove(ips[i])
		}
		return set.Len() == 0
	}
	err := quick.Check(emptyAfterRoundTrip, &quick.Config{})
	if err != nil {
		t.Fatal(err)
	}
}

View File

@ -419,6 +419,9 @@ type PeerInfo struct {
Network struct { Network struct {
LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection
RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection
Inbound bool `json:"inbound"`
Trusted bool `json:"trusted"`
Static bool `json:"static"`
} `json:"network"` } `json:"network"`
Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields
} }
@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo {
} }
info.Network.LocalAddress = p.LocalAddr().String() info.Network.LocalAddress = p.LocalAddr().String()
info.Network.RemoteAddress = p.RemoteAddr().String() info.Network.RemoteAddress = p.RemoteAddr().String()
info.Network.Inbound = p.rw.is(inboundConn)
info.Network.Trusted = p.rw.is(trustedConn)
info.Network.Static = p.rw.is(staticDialedConn)
// Gather all the running protocol infos // Gather all the running protocol infos
for _, proto := range p.running { for _, proto := range p.running {

View File

@ -40,11 +40,10 @@ const (
refreshPeersInterval = 30 * time.Second refreshPeersInterval = 30 * time.Second
staticPeerCheckInterval = 15 * time.Second staticPeerCheckInterval = 15 * time.Second
// Maximum number of concurrently handshaking inbound connections. // Connectivity defaults.
maxAcceptConns = 50 maxActiveDialTasks = 16
defaultMaxPendingPeers = 50
// Maximum number of concurrently dialing outbound connections. defaultDialRatio = 3
maxActiveDialTasks = 16
// Maximum time allowed for reading a complete message. // Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle. // This is effectively the amount of time a connection can be idle.
@ -70,6 +69,11 @@ type Config struct {
// Zero defaults to preset values. // Zero defaults to preset values.
MaxPendingPeers int `toml:",omitempty"` MaxPendingPeers int `toml:",omitempty"`
// DialRatio controls the ratio of inbound to dialed connections.
// Example: a DialRatio of 2 allows 1/2 of connections to be dialed.
// Setting DialRatio to zero defaults it to 3.
DialRatio int `toml:",omitempty"`
// NoDiscovery can be used to disable the peer discovery mechanism. // NoDiscovery can be used to disable the peer discovery mechanism.
// Disabling is useful for protocol debugging (manual topology). // Disabling is useful for protocol debugging (manual topology).
NoDiscovery bool NoDiscovery bool
@ -427,7 +431,6 @@ func (srv *Server) Start() (err error) {
if err != nil { if err != nil {
return err return err
} }
realaddr = conn.LocalAddr().(*net.UDPAddr) realaddr = conn.LocalAddr().(*net.UDPAddr)
if srv.NAT != nil { if srv.NAT != nil {
if !realaddr.IP.IsLoopback() { if !realaddr.IP.IsLoopback() {
@ -447,11 +450,16 @@ func (srv *Server) Start() (err error) {
// node table // node table
if !srv.NoDiscovery { if !srv.NoDiscovery {
ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict) cfg := discover.Config{
if err != nil { PrivateKey: srv.PrivateKey,
return err AnnounceAddr: realaddr,
NodeDBPath: srv.NodeDatabase,
NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled,
} }
if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil { ntab, err := discover.ListenUDP(conn, cfg)
if err != nil {
return err return err
} }
srv.ntab = ntab srv.ntab = ntab
@ -476,10 +484,7 @@ func (srv *Server) Start() (err error) {
srv.DiscV5 = ntab srv.DiscV5 = ntab
} }
dynPeers := (srv.MaxPeers + 1) / 2 dynPeers := srv.maxDialedConns()
if srv.NoDiscovery {
dynPeers = 0
}
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake // handshake
@ -536,6 +541,7 @@ func (srv *Server) run(dialstate dialer) {
defer srv.loopWG.Done() defer srv.loopWG.Done()
var ( var (
peers = make(map[discover.NodeID]*Peer) peers = make(map[discover.NodeID]*Peer)
inboundCount = 0
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
taskdone = make(chan task, maxActiveDialTasks) taskdone = make(chan task, maxActiveDialTasks)
runningTasks []task runningTasks []task
@ -621,14 +627,14 @@ running:
} }
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
select { select {
case c.cont <- srv.encHandshakeChecks(peers, c): case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
case <-srv.quit: case <-srv.quit:
break running break running
} }
case c := <-srv.addpeer: case c := <-srv.addpeer:
// At this point the connection is past the protocol handshake. // At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified. // Its capabilities are known and the remote identity is verified.
err := srv.protoHandshakeChecks(peers, c) err := srv.protoHandshakeChecks(peers, inboundCount, c)
if err == nil { if err == nil {
// The handshakes are done and it passed all checks. // The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols) p := newPeer(c, srv.Protocols)
@ -639,8 +645,11 @@ running:
} }
name := truncateName(c.name) name := truncateName(c.name)
srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
peers[c.id] = p
go srv.runPeer(p) go srv.runPeer(p)
peers[c.id] = p
if p.Inbound() {
inboundCount++
}
} }
// The dialer logic relies on the assumption that // The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or // dial tasks complete after the peer has been added or
@ -655,6 +664,9 @@ running:
d := common.PrettyDuration(mclock.Now() - pd.created) d := common.PrettyDuration(mclock.Now() - pd.created)
pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err) pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err)
delete(peers, pd.ID()) delete(peers, pd.ID())
if pd.Inbound() {
inboundCount--
}
} }
} }
@ -681,20 +693,22 @@ running:
} }
} }
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols. // Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer return DiscUselessPeer
} }
// Repeat the encryption handshake checks because the // Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes. // peer set might have changed between the handshakes.
return srv.encHandshakeChecks(peers, c) return srv.encHandshakeChecks(peers, inboundCount, c)
} }
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
switch { switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers return DiscTooManyPeers
case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
return DiscTooManyPeers
case peers[c.id] != nil: case peers[c.id] != nil:
return DiscAlreadyConnected return DiscAlreadyConnected
case c.id == srv.Self().ID: case c.id == srv.Self().ID:
@ -704,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn)
} }
} }
// maxInboundConns returns how many of the MaxPeers slots are left over for
// inbound connections once the slots for dialed connections are set aside.
func (srv *Server) maxInboundConns() int {
	dialed := srv.maxDialedConns()
	return srv.MaxPeers - dialed
}
// maxDialedConns returns the number of peer slots available for dialed
// connections: MaxPeers divided by DialRatio, where a DialRatio of zero
// selects defaultDialRatio. It is zero when discovery or dialing is disabled.
func (srv *Server) maxDialedConns() int {
	if srv.NoDiscovery || srv.NoDial {
		return 0
	}
	ratio := defaultDialRatio
	if srv.DialRatio != 0 {
		ratio = srv.DialRatio
	}
	return srv.MaxPeers / ratio
}
type tempError interface { type tempError interface {
Temporary() bool Temporary() bool
} }
@ -714,10 +743,7 @@ func (srv *Server) listenLoop() {
defer srv.loopWG.Done() defer srv.loopWG.Done()
srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab))
// This channel acts as a semaphore limiting tokens := defaultMaxPendingPeers
// active inbound connections that are lingering pre-handshake.
// If all slots are taken, no further connections are accepted.
tokens := maxAcceptConns
if srv.MaxPendingPeers > 0 { if srv.MaxPendingPeers > 0 {
tokens = srv.MaxPendingPeers tokens = srv.MaxPendingPeers
} }
@ -758,9 +784,6 @@ func (srv *Server) listenLoop() {
fd = newMeteredConn(fd, true) fd = newMeteredConn(fd, true)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
// Spawn the handler. It will give the slot back when the connection
// has been established.
go func() { go func() {
srv.SetupConn(fd, inboundConn, nil) srv.SetupConn(fd, inboundConn, nil)
slots <- struct{}{} slots <- struct{}{}