cmd/devp2p, p2p: dial using node iterator, discovery crawler (#20132)

* p2p/enode: add Iterator and associated utilities

* p2p/discover: add RandomNodes iterator

* p2p: dial using iterator

* cmd/devp2p: add discv4 crawler

* cmd/devp2p: WIP nodeset filter

* cmd/devp2p: fixup lesFilter

* core/forkid: add NewStaticFilter

* cmd/devp2p: make -eth-network filter actually work

* cmd/devp2p: improve crawl timestamp handling

* cmd/devp2p: fix typo

* p2p/enode: fix comment typos

* p2p/discover: fix comment typos

* p2p/discover: rename lookup.next to 'advance'

* p2p: lower discovery mixer timeout

* p2p/enode: implement dynamic FairMix timeouts

* cmd/devp2p: add ropsten support in -eth-network filter

* cmd/devp2p: tweak crawler log message
This commit is contained in:
Felix Lange 2019-10-29 16:08:57 +01:00 committed by Péter Szilágyi
parent b0b277525c
commit 2c37142d2f
19 changed files with 1559 additions and 414 deletions

152
cmd/devp2p/crawl.go Normal file
View File

@ -0,0 +1,152 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of go-ethereum.
//
// go-ethereum is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// go-ethereum is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
package main
import (
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/enode"
)
type crawler struct {
input nodeSet
output nodeSet
disc *discover.UDPv4
iters []enode.Iterator
inputIter enode.Iterator
ch chan *enode.Node
closed chan struct{}
// settings
revalidateInterval time.Duration
}
func newCrawler(input nodeSet, disc *discover.UDPv4, iters ...enode.Iterator) *crawler {
c := &crawler{
input: input,
output: make(nodeSet, len(input)),
disc: disc,
iters: iters,
inputIter: enode.IterNodes(input.nodes()),
ch: make(chan *enode.Node),
closed: make(chan struct{}),
}
c.iters = append(c.iters, c.inputIter)
// Copy input to output initially. Any nodes that fail validation
// will be dropped from output during the run.
for id, n := range input {
c.output[id] = n
}
return c
}
func (c *crawler) run(timeout time.Duration) nodeSet {
var (
timeoutTimer = time.NewTimer(timeout)
timeoutCh <-chan time.Time
doneCh = make(chan enode.Iterator, len(c.iters))
liveIters = len(c.iters)
)
for _, it := range c.iters {
go c.runIterator(doneCh, it)
}
loop:
for {
select {
case n := <-c.ch:
c.updateNode(n)
case it := <-doneCh:
if it == c.inputIter {
// Enable timeout when we're done revalidating the input nodes.
log.Info("Revalidation of input set is done", "len", len(c.input))
if timeout > 0 {
timeoutCh = timeoutTimer.C
}
}
if liveIters--; liveIters == 0 {
break loop
}
case <-timeoutCh:
break loop
}
}
close(c.closed)
for _, it := range c.iters {
it.Close()
}
for ; liveIters > 0; liveIters-- {
<-doneCh
}
return c.output
}
func (c *crawler) runIterator(done chan<- enode.Iterator, it enode.Iterator) {
defer func() { done <- it }()
for it.Next() {
select {
case c.ch <- it.Node():
case <-c.closed:
return
}
}
}
func (c *crawler) updateNode(n *enode.Node) {
node, ok := c.output[n.ID()]
// Skip validation of recently-seen nodes.
if ok && time.Since(node.LastCheck) < c.revalidateInterval {
return
}
// Request the node record.
nn, err := c.disc.RequestENR(n)
node.LastCheck = truncNow()
if err != nil {
if node.Score == 0 {
// Node doesn't implement EIP-868.
log.Debug("Skipping node", "id", n.ID())
return
}
node.Score /= 2
} else {
node.N = nn
node.Seq = nn.Seq()
node.Score++
if node.FirstResponse.IsZero() {
node.FirstResponse = node.LastCheck
}
node.LastResponse = node.LastCheck
}
// Store/update node in output set.
if node.Score <= 0 {
log.Info("Removing node", "id", n.ID())
delete(c.output, n.ID())
} else {
log.Info("Updating node", "id", n.ID(), "seq", n.Seq(), "score", node.Score)
c.output[n.ID()] = node
}
}
func truncNow() time.Time {
return time.Now().UTC().Truncate(1 * time.Second)
}

View File

@ -39,6 +39,7 @@ var (
discv4RequestRecordCommand, discv4RequestRecordCommand,
discv4ResolveCommand, discv4ResolveCommand,
discv4ResolveJSONCommand, discv4ResolveJSONCommand,
discv4CrawlCommand,
}, },
} }
discv4PingCommand = cli.Command{ discv4PingCommand = cli.Command{
@ -67,12 +68,25 @@ var (
Flags: []cli.Flag{bootnodesFlag}, Flags: []cli.Flag{bootnodesFlag},
ArgsUsage: "<nodes.json file>", ArgsUsage: "<nodes.json file>",
} }
discv4CrawlCommand = cli.Command{
Name: "crawl",
Usage: "Updates a nodes.json file with random nodes found in the DHT",
Action: discv4Crawl,
Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag},
}
) )
var bootnodesFlag = cli.StringFlag{ var (
bootnodesFlag = cli.StringFlag{
Name: "bootnodes", Name: "bootnodes",
Usage: "Comma separated nodes used for bootstrapping", Usage: "Comma separated nodes used for bootstrapping",
} }
crawlTimeoutFlag = cli.DurationFlag{
Name: "timeout",
Usage: "Time limit for the crawl.",
Value: 30 * time.Minute,
}
)
func discv4Ping(ctx *cli.Context) error { func discv4Ping(ctx *cli.Context) error {
n := getNodeArg(ctx) n := getNodeArg(ctx)
@ -113,30 +127,48 @@ func discv4ResolveJSON(ctx *cli.Context) error {
if ctx.NArg() < 1 { if ctx.NArg() < 1 {
return fmt.Errorf("need nodes file as argument") return fmt.Errorf("need nodes file as argument")
} }
disc := startV4(ctx) nodesFile := ctx.Args().Get(0)
defer disc.Close() inputSet := make(nodeSet)
file := ctx.Args().Get(0) if common.FileExist(nodesFile) {
inputSet = loadNodesJSON(nodesFile)
// Load existing nodes in file.
var nodes []*enode.Node
if common.FileExist(file) {
nodes = loadNodesJSON(file).nodes()
} }
// Add nodes from command line arguments.
// Add extra nodes from command line arguments.
var nodeargs []*enode.Node
for i := 1; i < ctx.NArg(); i++ { for i := 1; i < ctx.NArg(); i++ {
n, err := parseNode(ctx.Args().Get(i)) n, err := parseNode(ctx.Args().Get(i))
if err != nil { if err != nil {
exit(err) exit(err)
} }
nodes = append(nodes, n) nodeargs = append(nodeargs, n)
} }
result := make(nodeSet, len(nodes)) // Run the crawler.
for _, n := range nodes { disc := startV4(ctx)
n = disc.Resolve(n) defer disc.Close()
result[n.ID()] = nodeJSON{Seq: n.Seq(), N: n} c := newCrawler(inputSet, disc, enode.IterNodes(nodeargs))
c.revalidateInterval = 0
output := c.run(0)
writeNodesJSON(nodesFile, output)
return nil
}
func discv4Crawl(ctx *cli.Context) error {
if ctx.NArg() < 1 {
return fmt.Errorf("need nodes file as argument")
} }
writeNodesJSON(file, result) nodesFile := ctx.Args().First()
var inputSet nodeSet
if common.FileExist(nodesFile) {
inputSet = loadNodesJSON(nodesFile)
}
disc := startV4(ctx)
defer disc.Close()
c := newCrawler(inputSet, disc, disc.RandomNodes())
c.revalidateInterval = 10 * time.Minute
output := c.run(ctx.Duration(crawlTimeoutFlag.Name))
writeNodesJSON(nodesFile, output)
return nil return nil
} }

View File

@ -109,7 +109,8 @@ func dnsSync(ctx *cli.Context) error {
} }
def := treeToDefinition(url, t) def := treeToDefinition(url, t)
def.Meta.LastModified = time.Now() def.Meta.LastModified = time.Now()
writeTreeDefinition(outdir, def) writeTreeMetadata(outdir, def)
writeTreeNodes(outdir, def)
return nil return nil
} }
@ -151,7 +152,7 @@ func dnsSign(ctx *cli.Context) error {
def = treeToDefinition(url, t) def = treeToDefinition(url, t)
def.Meta.LastModified = time.Now() def.Meta.LastModified = time.Now()
writeTreeDefinition(defdir, def) writeTreeMetadata(defdir, def)
return nil return nil
} }
@ -315,26 +316,28 @@ func ensureValidTreeSignature(t *dnsdisc.Tree, pubkey *ecdsa.PublicKey, sig stri
return nil return nil
} }
// writeTreeDefinition writes a DNS node tree definition to the given directory. // writeTreeMetadata writes a DNS node tree metadata file to the given directory.
func writeTreeDefinition(directory string, def *dnsDefinition) { func writeTreeMetadata(directory string, def *dnsDefinition) {
metaJSON, err := json.MarshalIndent(&def.Meta, "", jsonIndent) metaJSON, err := json.MarshalIndent(&def.Meta, "", jsonIndent)
if err != nil { if err != nil {
exit(err) exit(err)
} }
// Convert nodes.
nodes := make(nodeSet, len(def.Nodes))
nodes.add(def.Nodes...)
// Write.
if err := os.Mkdir(directory, 0744); err != nil && !os.IsExist(err) { if err := os.Mkdir(directory, 0744); err != nil && !os.IsExist(err) {
exit(err) exit(err)
} }
metaFile, nodesFile := treeDefinitionFiles(directory) metaFile, _ := treeDefinitionFiles(directory)
writeNodesJSON(nodesFile, nodes)
if err := ioutil.WriteFile(metaFile, metaJSON, 0644); err != nil { if err := ioutil.WriteFile(metaFile, metaJSON, 0644); err != nil {
exit(err) exit(err)
} }
} }
func writeTreeNodes(directory string, def *dnsDefinition) {
ns := make(nodeSet, len(def.Nodes))
ns.add(def.Nodes...)
_, nodesFile := treeDefinitionFiles(directory)
writeNodesJSON(nodesFile, ns)
}
func treeDefinitionFiles(directory string) (string, string) { func treeDefinitionFiles(directory string) (string, string) {
meta := filepath.Join(directory, "enrtree-info.json") meta := filepath.Join(directory, "enrtree-info.json")
nodes := filepath.Join(directory, "nodes.json") nodes := filepath.Join(directory, "nodes.json")

View File

@ -60,6 +60,7 @@ func init() {
enrdumpCommand, enrdumpCommand,
discv4Command, discv4Command,
dnsCommand, dnsCommand,
nodesetCommand,
} }
} }

View File

@ -21,7 +21,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os"
"sort" "sort"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
@ -36,6 +38,15 @@ type nodeSet map[enode.ID]nodeJSON
type nodeJSON struct { type nodeJSON struct {
Seq uint64 `json:"seq"` Seq uint64 `json:"seq"`
N *enode.Node `json:"record"` N *enode.Node `json:"record"`
// The score tracks how many liveness checks were performed. It is incremented by one
// every time the node passes a check, and halved every time it doesn't.
Score int `json:"score,omitempty"`
// These two track the time of last successful contact.
FirstResponse time.Time `json:"firstResponse,omitempty"`
LastResponse time.Time `json:"lastResponse,omitempty"`
// This one tracks the time of our last attempt to contact the node.
LastCheck time.Time `json:"lastCheck,omitempty"`
} }
func loadNodesJSON(file string) nodeSet { func loadNodesJSON(file string) nodeSet {
@ -51,6 +62,10 @@ func writeNodesJSON(file string, nodes nodeSet) {
if err != nil { if err != nil {
exit(err) exit(err)
} }
if file == "-" {
os.Stdout.Write(nodesJSON)
return
}
if err := ioutil.WriteFile(file, nodesJSON, 0644); err != nil { if err := ioutil.WriteFile(file, nodesJSON, 0644); err != nil {
exit(err) exit(err)
} }

193
cmd/devp2p/nodesetcmd.go Normal file
View File

@ -0,0 +1,193 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of go-ethereum.
//
// go-ethereum is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// go-ethereum is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
package main
import (
"fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
"gopkg.in/urfave/cli.v1"
)
var (
nodesetCommand = cli.Command{
Name: "nodeset",
Usage: "Node set tools",
Subcommands: []cli.Command{
nodesetInfoCommand,
nodesetFilterCommand,
},
}
nodesetInfoCommand = cli.Command{
Name: "info",
Usage: "Shows statistics about a node set",
Action: nodesetInfo,
ArgsUsage: "<nodes.json>",
}
nodesetFilterCommand = cli.Command{
Name: "filter",
Usage: "Filters a node set",
Action: nodesetFilter,
ArgsUsage: "<nodes.json> filters..",
SkipFlagParsing: true,
}
)
func nodesetInfo(ctx *cli.Context) error {
if ctx.NArg() < 1 {
return fmt.Errorf("need nodes file as argument")
}
ns := loadNodesJSON(ctx.Args().First())
fmt.Printf("Set contains %d nodes.\n", len(ns))
return nil
}
func nodesetFilter(ctx *cli.Context) error {
if ctx.NArg() < 1 {
return fmt.Errorf("need nodes file as argument")
}
ns := loadNodesJSON(ctx.Args().First())
filter, err := andFilter(ctx.Args().Tail())
if err != nil {
return err
}
result := make(nodeSet)
for id, n := range ns {
if filter(n) {
result[id] = n
}
}
writeNodesJSON("-", result)
return nil
}
type nodeFilter func(nodeJSON) bool
type nodeFilterC struct {
narg int
fn func([]string) (nodeFilter, error)
}
var filterFlags = map[string]nodeFilterC{
"-ip": {1, ipFilter},
"-min-age": {1, minAgeFilter},
"-eth-network": {1, ethFilter},
"-les-server": {0, lesFilter},
}
func parseFilters(args []string) ([]nodeFilter, error) {
var filters []nodeFilter
for len(args) > 0 {
fc, ok := filterFlags[args[0]]
if !ok {
return nil, fmt.Errorf("invalid filter %q", args[0])
}
if len(args) < fc.narg {
return nil, fmt.Errorf("filter %q wants %d arguments, have %d", args[0], fc.narg, len(args))
}
filter, err := fc.fn(args[1:])
if err != nil {
return nil, fmt.Errorf("%s: %v", args[0], err)
}
filters = append(filters, filter)
args = args[fc.narg+1:]
}
return filters, nil
}
func andFilter(args []string) (nodeFilter, error) {
checks, err := parseFilters(args)
if err != nil {
return nil, err
}
f := func(n nodeJSON) bool {
for _, filter := range checks {
if !filter(n) {
return false
}
}
return true
}
return f, nil
}
func ipFilter(args []string) (nodeFilter, error) {
_, cidr, err := net.ParseCIDR(args[0])
if err != nil {
return nil, err
}
f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) }
return f, nil
}
func minAgeFilter(args []string) (nodeFilter, error) {
minage, err := time.ParseDuration(args[0])
if err != nil {
return nil, err
}
f := func(n nodeJSON) bool {
age := n.LastResponse.Sub(n.FirstResponse)
return age >= minage
}
return f, nil
}
func ethFilter(args []string) (nodeFilter, error) {
var filter func(forkid.ID) error
switch args[0] {
case "mainnet":
filter = forkid.NewStaticFilter(params.MainnetChainConfig, params.MainnetGenesisHash)
case "rinkeby":
filter = forkid.NewStaticFilter(params.RinkebyChainConfig, params.RinkebyGenesisHash)
case "goerli":
filter = forkid.NewStaticFilter(params.GoerliChainConfig, params.GoerliGenesisHash)
case "ropsten":
filter = forkid.NewStaticFilter(params.TestnetChainConfig, params.TestnetGenesisHash)
default:
return nil, fmt.Errorf("unknown network %q", args[0])
}
f := func(n nodeJSON) bool {
var eth struct {
ForkID forkid.ID
_ []rlp.RawValue `rlp:"tail"`
}
if n.N.Load(enr.WithEntry("eth", &eth)) != nil {
return false
}
return filter(eth.ForkID) == nil
}
return f, nil
}
func lesFilter(args []string) (nodeFilter, error) {
f := func(n nodeJSON) bool {
var les struct {
_ []rlp.RawValue `rlp:"tail"`
}
return n.N.Load(enr.WithEntry("les", &les)) == nil
}
return f, nil
}

View File

@ -80,7 +80,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
return ID{Hash: checksumToBytes(hash), Next: next} return ID{Hash: checksumToBytes(hash), Next: next}
} }
// NewFilter creates an filter that returns if a fork ID should be rejected or not // NewFilter creates a filter that returns if a fork ID should be rejected or not
// based on the local chain's status. // based on the local chain's status.
func NewFilter(chain *core.BlockChain) func(id ID) error { func NewFilter(chain *core.BlockChain) func(id ID) error {
return newFilter( return newFilter(
@ -92,6 +92,12 @@ func NewFilter(chain *core.BlockChain) func(id ID) error {
) )
} }
// NewStaticFilter creates a filter at block zero.
func NewStaticFilter(config *params.ChainConfig, genesis common.Hash) func(id ID) error {
head := func() uint64 { return 0 }
return newFilter(config, genesis, head)
}
// newFilter is the internal version of NewFilter, taking closures as its arguments // newFilter is the internal version of NewFilter, taking closures as its arguments
// instead of a chain. The reason is to allow testing it without having to simulate // instead of a chain. The reason is to allow testing it without having to simulate
// an entire blockchain. // an entire blockchain.

View File

@ -33,12 +33,7 @@ const (
// private networks. // private networks.
dialHistoryExpiration = inboundThrottleTime + 5*time.Second dialHistoryExpiration = inboundThrottleTime + 5*time.Second
// Discovery lookups are throttled and can only run // If no peers are found for this amount of time, the initial bootnodes are dialed.
// once every few seconds.
lookupInterval = 4 * time.Second
// If no peers are found for this amount of time, the initial bootnodes are
// attempted to be connected.
fallbackInterval = 20 * time.Second fallbackInterval = 20 * time.Second
// Endpoint resolution is throttled with bounded backoff. // Endpoint resolution is throttled with bounded backoff.
@ -52,6 +47,10 @@ type NodeDialer interface {
Dial(*enode.Node) (net.Conn, error) Dial(*enode.Node) (net.Conn, error)
} }
type nodeResolver interface {
Resolve(*enode.Node) *enode.Node
}
// TCPDialer implements the NodeDialer interface by using a net.Dialer to // TCPDialer implements the NodeDialer interface by using a net.Dialer to
// create TCP connections to nodes in the network // create TCP connections to nodes in the network
type TCPDialer struct { type TCPDialer struct {
@ -69,7 +68,6 @@ func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) {
// of the main loop in Server.run. // of the main loop in Server.run.
type dialstate struct { type dialstate struct {
maxDynDials int maxDynDials int
ntab discoverTable
netrestrict *netutil.Netlist netrestrict *netutil.Netlist
self enode.ID self enode.ID
bootnodes []*enode.Node // default dials when there are no peers bootnodes []*enode.Node // default dials when there are no peers
@ -79,55 +77,23 @@ type dialstate struct {
lookupRunning bool lookupRunning bool
dialing map[enode.ID]connFlag dialing map[enode.ID]connFlag
lookupBuf []*enode.Node // current discovery lookup results lookupBuf []*enode.Node // current discovery lookup results
randomNodes []*enode.Node // filled from Table
static map[enode.ID]*dialTask static map[enode.ID]*dialTask
hist expHeap hist expHeap
} }
type discoverTable interface {
Close()
Resolve(*enode.Node) *enode.Node
LookupRandom() []*enode.Node
ReadRandomNodes([]*enode.Node) int
}
type task interface { type task interface {
Do(*Server) Do(*Server)
} }
// A dialTask is generated for each node that is dialed. Its func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate {
// fields cannot be accessed while the task is running.
type dialTask struct {
flags connFlag
dest *enode.Node
lastResolved time.Time
resolveDelay time.Duration
}
// discoverTask runs discovery table operations.
// Only one discoverTask is active at any time.
// discoverTask.Do performs a random lookup.
type discoverTask struct {
results []*enode.Node
}
// A waitExpireTask is generated if there are no other tasks
// to keep the loop in Server.run ticking.
type waitExpireTask struct {
time.Duration
}
func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
s := &dialstate{ s := &dialstate{
maxDynDials: maxdyn, maxDynDials: maxdyn,
ntab: ntab,
self: self, self: self,
netrestrict: cfg.NetRestrict, netrestrict: cfg.NetRestrict,
log: cfg.Logger, log: cfg.Logger,
static: make(map[enode.ID]*dialTask), static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag), dialing: make(map[enode.ID]connFlag),
bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
randomNodes: make([]*enode.Node, maxdyn/2),
} }
copy(s.bootnodes, cfg.BootstrapNodes) copy(s.bootnodes, cfg.BootstrapNodes)
if s.log == nil { if s.log == nil {
@ -151,10 +117,6 @@ func (s *dialstate) removeStatic(n *enode.Node) {
} }
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
if s.start.IsZero() {
s.start = now
}
var newtasks []task var newtasks []task
addDial := func(flag connFlag, n *enode.Node) bool { addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil { if err := s.checkDial(n, peers); err != nil {
@ -166,20 +128,9 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
return true return true
} }
// Compute number of dynamic dials necessary at this point. if s.start.IsZero() {
needDynDials := s.maxDynDials s.start = now
for _, p := range peers {
if p.rw.is(dynDialedConn) {
needDynDials--
} }
}
for _, flag := range s.dialing {
if flag&dynDialedConn != 0 {
needDynDials--
}
}
// Expire the dial history on every invocation.
s.hist.expire(now) s.hist.expire(now)
// Create dials for static nodes if they are not connected. // Create dials for static nodes if they are not connected.
@ -194,6 +145,20 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
newtasks = append(newtasks, t) newtasks = append(newtasks, t)
} }
} }
// Compute number of dynamic dials needed.
needDynDials := s.maxDynDials
for _, p := range peers {
if p.rw.is(dynDialedConn) {
needDynDials--
}
}
for _, flag := range s.dialing {
if flag&dynDialedConn != 0 {
needDynDials--
}
}
// If we don't have any peers whatsoever, try to dial a random bootnode. This // If we don't have any peers whatsoever, try to dial a random bootnode. This
// scenario is useful for the testnet (and private networks) where the discovery // scenario is useful for the testnet (and private networks) where the discovery
// table might be full of mostly bad peers, making it hard to find good ones. // table might be full of mostly bad peers, making it hard to find good ones.
@ -201,24 +166,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
bootnode := s.bootnodes[0] bootnode := s.bootnodes[0]
s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...) s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...)
s.bootnodes = append(s.bootnodes, bootnode) s.bootnodes = append(s.bootnodes, bootnode)
if addDial(dynDialedConn, bootnode) { if addDial(dynDialedConn, bootnode) {
needDynDials-- needDynDials--
} }
} }
// Use random nodes from the table for half of the necessary
// dynamic dials. // Create dynamic dials from discovery results.
randomCandidates := needDynDials / 2
if randomCandidates > 0 {
n := s.ntab.ReadRandomNodes(s.randomNodes)
for i := 0; i < randomCandidates && i < n; i++ {
if addDial(dynDialedConn, s.randomNodes[i]) {
needDynDials--
}
}
}
// Create dynamic dials from random lookup results, removing tried
// items from the result buffer.
i := 0 i := 0
for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
if addDial(dynDialedConn, s.lookupBuf[i]) { if addDial(dynDialedConn, s.lookupBuf[i]) {
@ -226,10 +179,11 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
} }
} }
s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
// Launch a discovery lookup if more candidates are needed. // Launch a discovery lookup if more candidates are needed.
if len(s.lookupBuf) < needDynDials && !s.lookupRunning { if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
s.lookupRunning = true s.lookupRunning = true
newtasks = append(newtasks, &discoverTask{}) newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)})
} }
// Launch a timer to wait for the next node to expire if all // Launch a timer to wait for the next node to expire if all
@ -279,6 +233,15 @@ func (s *dialstate) taskDone(t task, now time.Time) {
} }
} }
// A dialTask is generated for each node that is dialed. Its
// fields cannot be accessed while the task is running.
type dialTask struct {
flags connFlag
dest *enode.Node
lastResolved time.Time
resolveDelay time.Duration
}
func (t *dialTask) Do(srv *Server) { func (t *dialTask) Do(srv *Server) {
if t.dest.Incomplete() { if t.dest.Incomplete() {
if !t.resolve(srv) { if !t.resolve(srv) {
@ -304,8 +267,8 @@ func (t *dialTask) Do(srv *Server) {
// discovery network with useless queries for nodes that don't exist. // discovery network with useless queries for nodes that don't exist.
// The backoff delay resets when the node is found. // The backoff delay resets when the node is found.
func (t *dialTask) resolve(srv *Server) bool { func (t *dialTask) resolve(srv *Server) bool {
if srv.ntab == nil { if srv.staticNodeResolver == nil {
srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") srv.log.Debug("Can't resolve node", "id", t.dest.ID(), "err", "discovery is disabled")
return false return false
} }
if t.resolveDelay == 0 { if t.resolveDelay == 0 {
@ -314,20 +277,20 @@ func (t *dialTask) resolve(srv *Server) bool {
if time.Since(t.lastResolved) < t.resolveDelay { if time.Since(t.lastResolved) < t.resolveDelay {
return false return false
} }
resolved := srv.ntab.Resolve(t.dest) resolved := srv.staticNodeResolver.Resolve(t.dest)
t.lastResolved = time.Now() t.lastResolved = time.Now()
if resolved == nil { if resolved == nil {
t.resolveDelay *= 2 t.resolveDelay *= 2
if t.resolveDelay > maxResolveDelay { if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay t.resolveDelay = maxResolveDelay
} }
srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) srv.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay)
return false return false
} }
// The node was found. // The node was found.
t.resolveDelay = initialResolveDelay t.resolveDelay = initialResolveDelay
t.dest = resolved t.dest = resolved
srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) srv.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
return true return true
} }
@ -350,26 +313,34 @@ func (t *dialTask) String() string {
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
} }
// discoverTask runs discovery table operations.
// Only one discoverTask is active at any time.
// discoverTask.Do performs a random lookup.
type discoverTask struct {
want int
results []*enode.Node
}
func (t *discoverTask) Do(srv *Server) { func (t *discoverTask) Do(srv *Server) {
// newTasks generates a lookup task whenever dynamic dials are t.results = enode.ReadNodes(srv.discmix, t.want)
// necessary. Lookups need to take some time, otherwise the
// event loop spins too fast.
next := srv.lastLookup.Add(lookupInterval)
if now := time.Now(); now.Before(next) {
time.Sleep(next.Sub(now))
}
srv.lastLookup = time.Now()
t.results = srv.ntab.LookupRandom()
} }
func (t *discoverTask) String() string { func (t *discoverTask) String() string {
s := "discovery lookup" s := "discovery query"
if len(t.results) > 0 { if len(t.results) > 0 {
s += fmt.Sprintf(" (%d results)", len(t.results)) s += fmt.Sprintf(" (%d results)", len(t.results))
} else {
s += fmt.Sprintf(" (want %d)", t.want)
} }
return s return s
} }
// A waitExpireTask is generated if there are no other tasks
// to keep the loop in Server.run ticking.
type waitExpireTask struct {
time.Duration
}
func (t waitExpireTask) Do(*Server) { func (t waitExpireTask) Do(*Server) {
time.Sleep(t.Duration) time.Sleep(t.Duration)
} }

View File

@ -73,7 +73,7 @@ func runDialTest(t *testing.T, test dialtest) {
t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v", t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
} }
t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new))) t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new)))
// Time advances by 16 seconds on every round. // Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second) vtime = vtime.Add(16 * time.Second)
@ -81,19 +81,11 @@ func runDialTest(t *testing.T, test dialtest) {
} }
} }
type fakeTable []*enode.Node
func (t fakeTable) Self() *enode.Node { return new(enode.Node) }
func (t fakeTable) Close() {}
func (t fakeTable) LookupRandom() []*enode.Node { return nil }
func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil }
func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) }
// This test checks that dynamic dials are launched from discovery results. // This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) { func TestDialStateDynDial(t *testing.T) {
config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, fakeTable{}, 5, config), init: newDialState(enode.ID{}, 5, config),
rounds: []round{ rounds: []round{
// A discovery query is launched. // A discovery query is launched.
{ {
@ -102,7 +94,9 @@ func TestDialStateDynDial(t *testing.T) {
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
}, },
new: []task{&discoverTask{}}, new: []task{
&discoverTask{want: 3},
},
}, },
// Dynamic dials are launched when it completes. // Dynamic dials are launched when it completes.
{ {
@ -188,7 +182,7 @@ func TestDialStateDynDial(t *testing.T) {
}, },
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)},
&discoverTask{}, &discoverTask{want: 2},
}, },
}, },
// Peer 7 is connected, but there still aren't enough dynamic peers // Peer 7 is connected, but there still aren't enough dynamic peers
@ -218,7 +212,7 @@ func TestDialStateDynDial(t *testing.T) {
&discoverTask{}, &discoverTask{},
}, },
new: []task{ new: []task{
&discoverTask{}, &discoverTask{want: 2},
}, },
}, },
}, },
@ -235,35 +229,37 @@ func TestDialStateDynDialBootnode(t *testing.T) {
}, },
Logger: testlog.Logger(t, log.LvlTrace), Logger: testlog.Logger(t, log.LvlTrace),
} }
table := fakeTable{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, 5, config),
rounds: []round{
{
new: []task{
&discoverTask{want: 5},
},
},
{
done: []task{
&discoverTask{
results: []*enode.Node{
newNode(uintID(4), nil), newNode(uintID(4), nil),
newNode(uintID(5), nil), newNode(uintID(5), nil),
newNode(uintID(6), nil), },
newNode(uintID(7), nil), },
newNode(uintID(8), nil), },
}
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, table, 5, config),
rounds: []round{
// 2 dynamic dials attempted, bootnodes pending fallback interval
{
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
&discoverTask{}, &discoverTask{want: 3},
}, },
}, },
// No dials succeed, bootnodes still pending fallback interval // No dials succeed, bootnodes still pending fallback interval
{},
// 1 bootnode attempted as fallback interval was reached
{ {
done: []task{ done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
},
// No dials succeed, bootnodes still pending fallback interval
{},
// No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached
{
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
}, },
@ -275,15 +271,12 @@ func TestDialStateDynDialBootnode(t *testing.T) {
}, },
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
}, },
// No dials succeed, 3rd bootnode is attempted // No dials succeed, 3rd bootnode is attempted
{ {
done: []task{ done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
}, },
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
@ -293,115 +286,19 @@ func TestDialStateDynDialBootnode(t *testing.T) {
{ {
done: []task{ done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
},
new: []task{},
},
// Random dial succeeds, no more bootnodes are attempted
{
new: []task{
&waitExpireTask{3 * time.Second},
},
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
},
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
},
},
},
})
}
func TestDialStateDynDialFromTable(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
table := fakeTable{
newNode(uintID(1), nil),
newNode(uintID(2), nil),
newNode(uintID(3), nil),
newNode(uintID(4), nil),
newNode(uintID(5), nil),
newNode(uintID(6), nil),
newNode(uintID(7), nil),
newNode(uintID(8), nil),
}
runDialTest(t, dialtest{
init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}),
rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{
new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
&discoverTask{},
},
},
// Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
},
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
&discoverTask{results: []*enode.Node{ &discoverTask{results: []*enode.Node{
newNode(uintID(10), nil), newNode(uintID(6), nil),
newNode(uintID(11), nil),
newNode(uintID(12), nil),
}}, }},
}, },
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, &discoverTask{want: 4},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
&discoverTask{},
}, },
}, },
// Dialing nodes 3,4,5 fails. The dials from the lookup succeed. // Random dial succeeds, no more bootnodes are attempted
{ {
peers: []*Peer{ peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
},
done: []task{
&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)},
&dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)},
},
},
// Waiting for expiry. No waitExpireTask is launched because the
// discovery query is still running.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
},
},
// Nodes 3,4 are not tried again because only the first two
// returned random nodes (nodes 1,2) are tried and they're
// already connected.
{
peers: []*Peer{
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}},
{rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}},
}, },
}, },
}, },
@ -416,11 +313,11 @@ func newNode(id enode.ID, ip net.IP) *enode.Node {
return enode.SignNull(&r, id) return enode.SignNull(&r, id)
} }
// This test checks that candidates that do not match the netrestrict list are not dialed. // // This test checks that candidates that do not match the netrestrict list are not dialed.
func TestDialStateNetRestrict(t *testing.T) { func TestDialStateNetRestrict(t *testing.T) {
// This table always returns the same random nodes // This table always returns the same random nodes
// in the order given below. // in the order given below.
table := fakeTable{ nodes := []*enode.Node{
newNode(uintID(1), net.ParseIP("127.0.0.1")), newNode(uintID(1), net.ParseIP("127.0.0.1")),
newNode(uintID(2), net.ParseIP("127.0.0.2")), newNode(uintID(2), net.ParseIP("127.0.0.2")),
newNode(uintID(3), net.ParseIP("127.0.0.3")), newNode(uintID(3), net.ParseIP("127.0.0.3")),
@ -434,12 +331,23 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24") restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}), init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}),
rounds: []round{ rounds: []round{
{ {
new: []task{ new: []task{
&dialTask{flags: dynDialedConn, dest: table[4]}, &discoverTask{want: 10},
&discoverTask{}, },
},
{
done: []task{
&discoverTask{results: nodes},
},
new: []task{
&dialTask{flags: dynDialedConn, dest: nodes[4]},
&dialTask{flags: dynDialedConn, dest: nodes[5]},
&dialTask{flags: dynDialedConn, dest: nodes[6]},
&dialTask{flags: dynDialedConn, dest: nodes[7]},
&discoverTask{want: 6},
}, },
}, },
}, },
@ -459,7 +367,7 @@ func TestDialStateStaticDial(t *testing.T) {
Logger: testlog.Logger(t, log.LvlTrace), Logger: testlog.Logger(t, log.LvlTrace),
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, fakeTable{}, 0, config), init: newDialState(enode.ID{}, 0, config),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -544,7 +452,7 @@ func TestDialStateCache(t *testing.T) {
Logger: testlog.Logger(t, log.LvlTrace), Logger: testlog.Logger(t, log.LvlTrace),
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(enode.ID{}, fakeTable{}, 0, config), init: newDialState(enode.ID{}, 0, config),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -618,8 +526,8 @@ func TestDialResolve(t *testing.T) {
Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
} }
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved} resolver := &resolveMock{answer: resolved}
state := newDialState(enode.ID{}, table, 0, config) state := newDialState(enode.ID{}, 0, config)
// Check that the task is generated with an incomplete ID. // Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil) dest := newNode(uintID(1), nil)
@ -630,10 +538,14 @@ func TestDialResolve(t *testing.T) {
} }
// Now run the task, it should resolve the ID once. // Now run the task, it should resolve the ID once.
srv := &Server{ntab: table, log: config.Logger, Config: *config} srv := &Server{
Config: *config,
log: config.Logger,
staticNodeResolver: resolver,
}
tasks[0].Do(srv) tasks[0].Do(srv)
if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) {
t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) t.Fatalf("wrong resolve calls, got %v", resolver.calls)
} }
// Report it as done to the dialer, which should update the static node record. // Report it as done to the dialer, which should update the static node record.
@ -666,18 +578,13 @@ func uintID(i uint32) enode.ID {
return id return id
} }
// implements discoverTable for TestDialResolve // for TestDialResolve
type resolveMock struct { type resolveMock struct {
resolveCalls []*enode.Node calls []*enode.Node
answer *enode.Node answer *enode.Node
} }
func (t *resolveMock) Resolve(n *enode.Node) *enode.Node { func (t *resolveMock) Resolve(n *enode.Node) *enode.Node {
t.resolveCalls = append(t.resolveCalls, n) t.calls = append(t.calls, n)
return t.answer return t.answer
} }
func (t *resolveMock) Self() *enode.Node { return new(enode.Node) }
func (t *resolveMock) Close() {}
func (t *resolveMock) LookupRandom() []*enode.Node { return nil }
func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 }

View File

@ -25,6 +25,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/p2p/netutil"
) )
// UDPConn is a network connection on which discovery can operate.
type UDPConn interface { type UDPConn interface {
ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
@ -32,7 +33,7 @@ type UDPConn interface {
LocalAddr() net.Addr LocalAddr() net.Addr
} }
// Config holds Table-related settings. // Config holds settings for the discovery listener.
type Config struct { type Config struct {
// These settings are required and configure the UDP listener: // These settings are required and configure the UDP listener:
PrivateKey *ecdsa.PrivateKey PrivateKey *ecdsa.PrivateKey
@ -50,7 +51,7 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
} }
// ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled // ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled
// channel if configured. // channel if configured. This is exported for internal use, do not use this type.
type ReadPacket struct { type ReadPacket struct {
Data []byte Data []byte
Addr *net.UDPAddr Addr *net.UDPAddr

209
p2p/discover/lookup.go Normal file
View File

@ -0,0 +1,209 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package discover
import (
"context"
"github.com/ethereum/go-ethereum/p2p/enode"
)
// lookup performs a network search for nodes close to the given target. It approaches the
// target by querying nodes that are closer to it on each iteration. The given target does
// not need to be an actual node identifier.
type lookup struct {
tab *Table
queryfunc func(*node) ([]*node, error)
replyCh chan []*node
cancelCh <-chan struct{}
asked, seen map[enode.ID]bool
result nodesByDistance
replyBuffer []*node
queries int
}
type queryFunc func(*node) ([]*node, error)
func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup {
it := &lookup{
tab: tab,
queryfunc: q,
asked: make(map[enode.ID]bool),
seen: make(map[enode.ID]bool),
result: nodesByDistance{target: target},
replyCh: make(chan []*node, alpha),
cancelCh: ctx.Done(),
queries: -1,
}
// Don't query further if we hit ourself.
// Unlikely to happen often in practice.
it.asked[tab.self().ID()] = true
return it
}
// run runs the lookup to completion and returns the closest nodes found.
func (it *lookup) run() []*enode.Node {
for it.advance() {
}
return unwrapNodes(it.result.entries)
}
// advance advances the lookup until any new nodes have been found.
// It returns false when the lookup has ended.
func (it *lookup) advance() bool {
for it.startQueries() {
select {
case nodes := <-it.replyCh:
it.replyBuffer = it.replyBuffer[:0]
for _, n := range nodes {
if n != nil && !it.seen[n.ID()] {
it.seen[n.ID()] = true
it.result.push(n, bucketSize)
it.replyBuffer = append(it.replyBuffer, n)
}
}
it.queries--
if len(it.replyBuffer) > 0 {
return true
}
case <-it.cancelCh:
it.shutdown()
}
}
return false
}
func (it *lookup) shutdown() {
for it.queries > 0 {
<-it.replyCh
it.queries--
}
it.queryfunc = nil
it.replyBuffer = nil
}
func (it *lookup) startQueries() bool {
if it.queryfunc == nil {
return false
}
// The first query returns nodes from the local table.
if it.queries == -1 {
it.tab.mutex.Lock()
closest := it.tab.closest(it.result.target, bucketSize, false)
it.tab.mutex.Unlock()
it.queries = 1
it.replyCh <- closest.entries
return true
}
// Ask the closest nodes that we haven't asked yet.
for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ {
n := it.result.entries[i]
if !it.asked[n.ID()] {
it.asked[n.ID()] = true
it.queries++
go it.query(n, it.replyCh)
}
}
// The lookup ends when no more nodes can be asked.
return it.queries > 0
}
func (it *lookup) query(n *node, reply chan<- []*node) {
fails := it.tab.db.FindFails(n.ID(), n.IP())
r, err := it.queryfunc(n)
if err == errClosed {
// Avoid recording failures on shutdown.
reply <- nil
return
} else if len(r) == 0 {
fails++
it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
it.tab.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
if fails >= maxFindnodeFailures {
it.tab.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
it.tab.delete(n)
}
} else if fails > 0 {
// Reset failure counter because it counts _consecutive_ failures.
it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0)
}
// Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
// just remove those again during revalidation.
for _, n := range r {
it.tab.addSeenNode(n)
}
reply <- r
}
// lookupIterator performs lookup operations and iterates over all seen nodes.
// When a lookup finishes, a new one is created through nextLookup.
type lookupIterator struct {
buffer []*node
nextLookup lookupFunc
ctx context.Context
cancel func()
lookup *lookup
}
type lookupFunc func(ctx context.Context) *lookup
func newLookupIterator(ctx context.Context, next lookupFunc) *lookupIterator {
ctx, cancel := context.WithCancel(ctx)
return &lookupIterator{ctx: ctx, cancel: cancel, nextLookup: next}
}
// Node returns the current node.
func (it *lookupIterator) Node() *enode.Node {
if len(it.buffer) == 0 {
return nil
}
return unwrapNode(it.buffer[0])
}
// Next moves to the next node.
func (it *lookupIterator) Next() bool {
// Consume next node in buffer.
if len(it.buffer) > 0 {
it.buffer = it.buffer[1:]
}
// Advance the lookup to refill the buffer.
for len(it.buffer) == 0 {
if it.ctx.Err() != nil {
it.lookup = nil
it.buffer = nil
return false
}
if it.lookup == nil {
it.lookup = it.nextLookup(it.ctx)
continue
}
if !it.lookup.advance() {
it.lookup = nil
continue
}
it.buffer = it.lookup.replyBuffer
}
return true
}
// Close ends the iterator.
func (it *lookupIterator) Close() {
it.cancel()
}

View File

@ -17,11 +17,14 @@
package discover package discover
import ( import (
"bytes"
"crypto/ecdsa" "crypto/ecdsa"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"reflect"
"sort" "sort"
"sync" "sync"
@ -169,6 +172,28 @@ func hasDuplicates(slice []*node) bool {
return false return false
} }
func checkNodesEqual(got, want []*enode.Node) error {
if reflect.DeepEqual(got, want) {
return nil
}
output := new(bytes.Buffer)
fmt.Fprintf(output, "got %d nodes:\n", len(got))
for _, n := range got {
fmt.Fprintf(output, " %v %v\n", n.ID(), n)
}
fmt.Fprintf(output, "want %d:\n", len(want))
for _, n := range want {
fmt.Fprintf(output, " %v %v\n", n.ID(), n)
}
return errors.New(output.String())
}
func sortByID(nodes []*enode.Node) {
sort.Slice(nodes, func(i, j int) bool {
return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes())
})
}
func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
return sort.SliceIsSorted(slice, func(i, j int) bool { return sort.SliceIsSorted(slice, func(i, j int) bool {
return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0 return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0

View File

@ -20,7 +20,6 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"net" "net"
"reflect"
"sort" "sort"
"testing" "testing"
@ -49,19 +48,7 @@ func TestUDPv4_Lookup(t *testing.T) {
}() }()
// Answer lookup packets. // Answer lookup packets.
for done := false; !done; { serveTestnet(test, lookupTestnet)
done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) {
n, key := lookupTestnet.nodeByAddr(to)
switch p.(type) {
case *pingV4:
test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash})
case *findnodeV4:
dist := enode.LogDist(n.ID(), lookupTestnet.target.id())
nodes := lookupTestnet.nodesAtDistance(dist - 1)
test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes})
}
})
}
// Verify result nodes. // Verify result nodes.
results := <-resultC results := <-resultC
@ -78,8 +65,94 @@ func TestUDPv4_Lookup(t *testing.T) {
if !sortedByDistanceTo(lookupTestnet.target.id(), wrapNodes(results)) { if !sortedByDistanceTo(lookupTestnet.target.id(), wrapNodes(results)) {
t.Errorf("result set not sorted by distance to target") t.Errorf("result set not sorted by distance to target")
} }
if !reflect.DeepEqual(results, lookupTestnet.closest(bucketSize)) { if err := checkNodesEqual(results, lookupTestnet.closest(bucketSize)); err != nil {
t.Errorf("results aren't the closest %d nodes", bucketSize) t.Errorf("results aren't the closest %d nodes\n%v", bucketSize, err)
}
}
func TestUDPv4_LookupIterator(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
defer test.close()
// Seed table with initial nodes.
bootnodes := make([]*node, len(lookupTestnet.dists[256]))
for i := range lookupTestnet.dists[256] {
bootnodes[i] = wrapNode(lookupTestnet.node(256, i))
}
fillTable(test.table, bootnodes)
go serveTestnet(test, lookupTestnet)
// Create the iterator and collect the nodes it yields.
iter := test.udp.RandomNodes()
seen := make(map[enode.ID]*enode.Node)
for limit := lookupTestnet.len(); iter.Next() && len(seen) < limit; {
seen[iter.Node().ID()] = iter.Node()
}
iter.Close()
// Check that all nodes in lookupTestnet were seen by the iterator.
results := make([]*enode.Node, 0, len(seen))
for _, n := range seen {
results = append(results, n)
}
sortByID(results)
want := lookupTestnet.nodes()
if err := checkNodesEqual(results, want); err != nil {
t.Fatal(err)
}
}
// TestUDPv4_LookupIteratorClose checks that lookupIterator ends when its Close
// method is called.
func TestUDPv4_LookupIteratorClose(t *testing.T) {
t.Parallel()
test := newUDPTest(t)
defer test.close()
// Seed table with initial nodes.
bootnodes := make([]*node, len(lookupTestnet.dists[256]))
for i := range lookupTestnet.dists[256] {
bootnodes[i] = wrapNode(lookupTestnet.node(256, i))
}
fillTable(test.table, bootnodes)
go serveTestnet(test, lookupTestnet)
it := test.udp.RandomNodes()
if ok := it.Next(); !ok || it.Node() == nil {
t.Fatalf("iterator didn't return any node")
}
it.Close()
ncalls := 0
for ; ncalls < 100 && it.Next(); ncalls++ {
if it.Node() == nil {
t.Error("iterator returned Node() == nil node after Next() == true")
}
}
t.Logf("iterator returned %d nodes after close", ncalls)
if it.Next() {
t.Errorf("Next() == true after close and %d more calls", ncalls)
}
if n := it.Node(); n != nil {
t.Errorf("iterator returned non-nil node after close and %d more calls", ncalls)
}
}
func serveTestnet(test *udpTest, testnet *preminedTestnet) {
for done := false; !done; {
done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) {
n, key := testnet.nodeByAddr(to)
switch p.(type) {
case *pingV4:
test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash})
case *findnodeV4:
dist := enode.LogDist(n.ID(), testnet.target.id())
nodes := testnet.nodesAtDistance(dist - 1)
test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes})
}
})
} }
} }
@ -148,6 +221,25 @@ type preminedTestnet struct {
dists [hashBits + 1][]*ecdsa.PrivateKey dists [hashBits + 1][]*ecdsa.PrivateKey
} }
func (tn *preminedTestnet) len() int {
n := 0
for _, keys := range tn.dists {
n += len(keys)
}
return n
}
func (tn *preminedTestnet) nodes() []*enode.Node {
result := make([]*enode.Node, 0, tn.len())
for dist, keys := range tn.dists {
for index := range keys {
result = append(result, tn.node(dist, index))
}
}
sortByID(result)
return result
}
func (tn *preminedTestnet) node(dist, index int) *enode.Node { func (tn *preminedTestnet) node(dist, index int) *enode.Node {
key := tn.dists[dist][index] key := tn.dists[dist][index]
ip := net.IP{127, byte(dist >> 8), byte(dist), byte(index)} ip := net.IP{127, byte(dist >> 8), byte(dist), byte(index)}

View File

@ -19,6 +19,7 @@ package discover
import ( import (
"bytes" "bytes"
"container/list" "container/list"
"context"
"crypto/ecdsa" "crypto/ecdsa"
crand "crypto/rand" crand "crypto/rand"
"errors" "errors"
@ -207,7 +208,8 @@ type UDPv4 struct {
addReplyMatcher chan *replyMatcher addReplyMatcher chan *replyMatcher
gotreply chan reply gotreply chan reply
closing chan struct{} closeCtx context.Context
cancelCloseCtx func()
} }
// replyMatcher represents a pending reply. // replyMatcher represents a pending reply.
@ -256,20 +258,23 @@ type reply struct {
} }
func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
closeCtx, cancel := context.WithCancel(context.Background())
t := &UDPv4{ t := &UDPv4{
conn: c, conn: c,
priv: cfg.PrivateKey, priv: cfg.PrivateKey,
netrestrict: cfg.NetRestrict, netrestrict: cfg.NetRestrict,
localNode: ln, localNode: ln,
db: ln.Database(), db: ln.Database(),
closing: make(chan struct{}),
gotreply: make(chan reply), gotreply: make(chan reply),
addReplyMatcher: make(chan *replyMatcher), addReplyMatcher: make(chan *replyMatcher),
closeCtx: closeCtx,
cancelCloseCtx: cancel,
log: cfg.Log, log: cfg.Log,
} }
if t.log == nil { if t.log == nil {
t.log = log.Root() t.log = log.Root()
} }
tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log) tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log)
if err != nil { if err != nil {
return nil, err return nil, err
@ -291,126 +296,13 @@ func (t *UDPv4) Self() *enode.Node {
// Close shuts down the socket and aborts any running queries. // Close shuts down the socket and aborts any running queries.
func (t *UDPv4) Close() { func (t *UDPv4) Close() {
t.closeOnce.Do(func() { t.closeOnce.Do(func() {
close(t.closing) t.cancelCloseCtx()
t.conn.Close() t.conn.Close()
t.wg.Wait() t.wg.Wait()
t.tab.close() t.tab.close()
}) })
} }
// ReadRandomNodes reads random nodes from the local table.
func (t *UDPv4) ReadRandomNodes(buf []*enode.Node) int {
return t.tab.ReadRandomNodes(buf)
}
// LookupRandom finds random nodes in the network.
func (t *UDPv4) LookupRandom() []*enode.Node {
if t.tab.len() == 0 {
// All nodes were dropped, refresh. The very first query will hit this
// case and run the bootstrapping logic.
<-t.tab.refresh()
}
return t.lookupRandom()
}
func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node {
if t.tab.len() == 0 {
// All nodes were dropped, refresh. The very first query will hit this
// case and run the bootstrapping logic.
<-t.tab.refresh()
}
return unwrapNodes(t.lookup(encodePubkey(key)))
}
func (t *UDPv4) lookupRandom() []*enode.Node {
var target encPubkey
crand.Read(target[:])
return unwrapNodes(t.lookup(target))
}
func (t *UDPv4) lookupSelf() []*enode.Node {
return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey)))
}
// lookup performs a network search for nodes close to the given target. It approaches the
// target by querying nodes that are closer to it on each iteration. The given target does
// not need to be an actual node identifier.
func (t *UDPv4) lookup(targetKey encPubkey) []*node {
var (
target = enode.ID(crypto.Keccak256Hash(targetKey[:]))
asked = make(map[enode.ID]bool)
seen = make(map[enode.ID]bool)
reply = make(chan []*node, alpha)
pendingQueries = 0
result *nodesByDistance
)
// Don't query further if we hit ourself.
// Unlikely to happen often in practice.
asked[t.Self().ID()] = true
// Generate the initial result set.
t.tab.mutex.Lock()
result = t.tab.closest(target, bucketSize, false)
t.tab.mutex.Unlock()
for {
// ask the alpha closest nodes that we haven't asked yet
for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
n := result.entries[i]
if !asked[n.ID()] {
asked[n.ID()] = true
pendingQueries++
go t.lookupWorker(n, targetKey, reply)
}
}
if pendingQueries == 0 {
// we have asked all closest nodes, stop the search
break
}
select {
case nodes := <-reply:
for _, n := range nodes {
if n != nil && !seen[n.ID()] {
seen[n.ID()] = true
result.push(n, bucketSize)
}
}
case <-t.tab.closeReq:
return nil // shutdown, no need to continue.
}
pendingQueries--
}
return result.entries
}
func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node) {
fails := t.db.FindFails(n.ID(), n.IP())
r, err := t.findnode(n.ID(), n.addr(), targetKey)
if err == errClosed {
// Avoid recording failures on shutdown.
reply <- nil
return
} else if len(r) == 0 {
fails++
t.db.UpdateFindFails(n.ID(), n.IP(), fails)
t.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
if fails >= maxFindnodeFailures {
t.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
t.tab.delete(n)
}
} else if fails > 0 {
// Reset failure counter because it counts _consecutive_ failures.
t.db.UpdateFindFails(n.ID(), n.IP(), 0)
}
// Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
// just remove those again during revalidation.
for _, n := range r {
t.tab.addSeenNode(n)
}
reply <- r
}
// Resolve searches for a specific node with the given ID and tries to get the most recent // Resolve searches for a specific node with the given ID and tries to get the most recent
// version of the node record for it. It returns n if the node could not be resolved. // version of the node record for it. It returns n if the node could not be resolved.
func (t *UDPv4) Resolve(n *enode.Node) *enode.Node { func (t *UDPv4) Resolve(n *enode.Node) *enode.Node {
@ -498,6 +390,45 @@ func (t *UDPv4) makePing(toaddr *net.UDPAddr) *pingV4 {
} }
} }
// LookupPubkey finds the closest nodes to the given public key.
func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node {
if t.tab.len() == 0 {
// All nodes were dropped, refresh. The very first query will hit this
// case and run the bootstrapping logic.
<-t.tab.refresh()
}
return t.newLookup(t.closeCtx, encodePubkey(key)).run()
}
// RandomNodes is an iterator yielding nodes from a random walk of the DHT.
func (t *UDPv4) RandomNodes() enode.Iterator {
return newLookupIterator(t.closeCtx, t.newRandomLookup)
}
// lookupRandom implements transport.
func (t *UDPv4) lookupRandom() []*enode.Node {
return t.newRandomLookup(t.closeCtx).run()
}
// lookupSelf implements transport.
func (t *UDPv4) lookupSelf() []*enode.Node {
return t.newLookup(t.closeCtx, encodePubkey(&t.priv.PublicKey)).run()
}
func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup {
var target encPubkey
crand.Read(target[:])
return t.newLookup(ctx, target)
}
func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup {
target := enode.ID(crypto.Keccak256Hash(targetKey[:]))
it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) {
return t.findnode(n.ID(), n.addr(), targetKey)
})
return it
}
// findnode sends a findnode request to the given node and waits until // findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors. // the node has sent up to k neighbors.
func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
@ -575,7 +506,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF
select { select {
case t.addReplyMatcher <- p: case t.addReplyMatcher <- p:
// loop will handle it // loop will handle it
case <-t.closing: case <-t.closeCtx.Done():
ch <- errClosed ch <- errClosed
} }
return p return p
@ -589,7 +520,7 @@ func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req packetV4) bool {
case t.gotreply <- reply{from, fromIP, req, matched}: case t.gotreply <- reply{from, fromIP, req, matched}:
// loop will handle it // loop will handle it
return <-matched return <-matched
case <-t.closing: case <-t.closeCtx.Done():
return false return false
} }
} }
@ -635,7 +566,7 @@ func (t *UDPv4) loop() {
resetTimeout() resetTimeout()
select { select {
case <-t.closing: case <-t.closeCtx.Done():
for el := plist.Front(); el != nil; el = el.Next() { for el := plist.Front(); el != nil; el = el.Next() {
el.Value.(*replyMatcher).errc <- errClosed el.Value.(*replyMatcher).errc <- errClosed
} }

286
p2p/enode/iter.go Normal file
View File

@ -0,0 +1,286 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package enode
import (
"sync"
"time"
)
// Iterator represents a sequence of nodes. The Next method moves to the next node in the
// sequence. It returns false when the sequence has ended or the iterator is closed. Close
// may be called concurrently with Next and Node, and interrupts Next if it is blocked.
type Iterator interface {
Next() bool // moves to next node
Node() *Node // returns current node
Close() // ends the iterator
}
// ReadNodes reads at most n nodes from the given iterator. The return value contains no
// duplicates and no nil values. To prevent looping indefinitely for small repeating node
// sequences, this function calls Next at most n times.
func ReadNodes(it Iterator, n int) []*Node {
seen := make(map[ID]*Node, n)
for i := 0; i < n && it.Next(); i++ {
// Remove duplicates, keeping the node with higher seq.
node := it.Node()
prevNode, ok := seen[node.ID()]
if ok && prevNode.Seq() > node.Seq() {
continue
}
seen[node.ID()] = node
}
result := make([]*Node, 0, len(seen))
for _, node := range seen {
result = append(result, node)
}
return result
}
// IterNodes makes an iterator which runs through the given nodes once.
func IterNodes(nodes []*Node) Iterator {
return &sliceIter{nodes: nodes, index: -1}
}
// CycleNodes makes an iterator which cycles through the given nodes indefinitely.
func CycleNodes(nodes []*Node) Iterator {
return &sliceIter{nodes: nodes, index: -1, cycle: true}
}
type sliceIter struct {
mu sync.Mutex
nodes []*Node
index int
cycle bool
}
func (it *sliceIter) Next() bool {
it.mu.Lock()
defer it.mu.Unlock()
if len(it.nodes) == 0 {
return false
}
it.index++
if it.index == len(it.nodes) {
if it.cycle {
it.index = 0
} else {
it.nodes = nil
return false
}
}
return true
}
func (it *sliceIter) Node() *Node {
if len(it.nodes) == 0 {
return nil
}
return it.nodes[it.index]
}
func (it *sliceIter) Close() {
it.mu.Lock()
defer it.mu.Unlock()
it.nodes = nil
}
// Filter wraps an iterator such that Next only returns nodes for which
// the 'check' function returns true.
func Filter(it Iterator, check func(*Node) bool) Iterator {
return &filterIter{it, check}
}
type filterIter struct {
Iterator
check func(*Node) bool
}
func (f *filterIter) Next() bool {
for f.Iterator.Next() {
if f.check(f.Node()) {
return true
}
}
return false
}
// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends
// only when Close is called. Source iterators added via AddSource are removed from the
// mix when they end.
//
// The distribution of nodes returned by Next is approximately fair, i.e. FairMix
// attempts to draw from all sources equally often. However, if a certain source is slow
// and doesn't return a node within the configured timeout, a node from any other source
// will be returned.
//
// It's safe to call AddSource and Close concurrently with Next.
type FairMix struct {
wg sync.WaitGroup
fromAny chan *Node
timeout time.Duration
cur *Node
mu sync.Mutex
closed chan struct{}
sources []*mixSource
last int
}
type mixSource struct {
it Iterator
next chan *Node
timeout time.Duration
}
// NewFairMix creates a mixer.
//
// The timeout specifies how long the mixer will wait for the next fairly-chosen source
// before giving up and taking a node from any other source. A good way to set the timeout
// is deciding how long you'd want to wait for a node on average. Passing a negative
// timeout makes the mixer completely fair.
func NewFairMix(timeout time.Duration) *FairMix {
m := &FairMix{
fromAny: make(chan *Node),
closed: make(chan struct{}),
timeout: timeout,
}
return m
}
// AddSource adds a source of nodes.
func (m *FairMix) AddSource(it Iterator) {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed == nil {
return
}
m.wg.Add(1)
source := &mixSource{it, make(chan *Node), m.timeout}
m.sources = append(m.sources, source)
go m.runSource(m.closed, source)
}
// Close shuts down the mixer and all current sources.
// Calling this is required to release resources associated with the mixer.
func (m *FairMix) Close() {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed == nil {
return
}
for _, s := range m.sources {
s.it.Close()
}
close(m.closed)
m.wg.Wait()
close(m.fromAny)
m.sources = nil
m.closed = nil
}
// Next returns a node from a random source.
func (m *FairMix) Next() bool {
m.cur = nil
var timeout <-chan time.Time
if m.timeout >= 0 {
timer := time.NewTimer(m.timeout)
timeout = timer.C
defer timer.Stop()
}
for {
source := m.pickSource()
if source == nil {
return m.nextFromAny()
}
select {
case n, ok := <-source.next:
if ok {
m.cur = n
source.timeout = m.timeout
return true
}
// This source has ended.
m.deleteSource(source)
case <-timeout:
source.timeout /= 2
return m.nextFromAny()
}
}
}
// Node returns the current node.
func (m *FairMix) Node() *Node {
return m.cur
}
// nextFromAny is used when there are no sources or when the 'fair' choice
// doesn't turn up a node quickly enough.
func (m *FairMix) nextFromAny() bool {
n, ok := <-m.fromAny
if ok {
m.cur = n
}
return ok
}
// pickSource chooses the next source to read from, cycling through them in order.
func (m *FairMix) pickSource() *mixSource {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.sources) == 0 {
return nil
}
m.last = (m.last + 1) % len(m.sources)
return m.sources[m.last]
}
// deleteSource deletes a source.
func (m *FairMix) deleteSource(s *mixSource) {
m.mu.Lock()
defer m.mu.Unlock()
for i := range m.sources {
if m.sources[i] == s {
copy(m.sources[i:], m.sources[i+1:])
m.sources[len(m.sources)-1] = nil
m.sources = m.sources[:len(m.sources)-1]
break
}
}
}
// runSource reads a single source in a loop.
func (m *FairMix) runSource(closed chan struct{}, s *mixSource) {
defer m.wg.Done()
defer close(s.next)
for s.it.Next() {
n := s.it.Node()
select {
case s.next <- n:
case m.fromAny <- n:
case <-closed:
return
}
}
}

291
p2p/enode/iter_test.go Normal file
View File

@ -0,0 +1,291 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package enode
import (
"encoding/binary"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ethereum/go-ethereum/p2p/enr"
)
func TestReadNodes(t *testing.T) {
nodes := ReadNodes(new(genIter), 10)
checkNodes(t, nodes, 10)
}
// This test checks that ReadNodes terminates when reading N nodes from an iterator
// which returns less than N nodes in an endless cycle.
func TestReadNodesCycle(t *testing.T) {
iter := &callCountIter{
Iterator: CycleNodes([]*Node{
testNode(0, 0),
testNode(1, 0),
testNode(2, 0),
}),
}
nodes := ReadNodes(iter, 10)
checkNodes(t, nodes, 3)
if iter.count != 10 {
t.Fatalf("%d calls to Next, want %d", iter.count, 100)
}
}
func TestFilterNodes(t *testing.T) {
nodes := make([]*Node, 100)
for i := range nodes {
nodes[i] = testNode(uint64(i), uint64(i))
}
it := Filter(IterNodes(nodes), func(n *Node) bool {
return n.Seq() >= 50
})
for i := 50; i < len(nodes); i++ {
if !it.Next() {
t.Fatal("Next returned false")
}
if it.Node() != nodes[i] {
t.Fatalf("iterator returned wrong node %v\nwant %v", it.Node(), nodes[i])
}
}
if it.Next() {
t.Fatal("Next returned true after underlying iterator has ended")
}
}
func checkNodes(t *testing.T, nodes []*Node, wantLen int) {
if len(nodes) != wantLen {
t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen)
return
}
seen := make(map[ID]bool)
for i, e := range nodes {
if e == nil {
t.Errorf("nil node at index %d", i)
return
}
if seen[e.ID()] {
t.Errorf("slice has duplicate node %v", e.ID())
return
}
seen[e.ID()] = true
}
}
// This test checks fairness of FairMix in the happy case where all sources return nodes
// within the context's deadline.
func TestFairMix(t *testing.T) {
for i := 0; i < 500; i++ {
testMixerFairness(t)
}
}
func testMixerFairness(t *testing.T) {
mix := NewFairMix(1 * time.Second)
mix.AddSource(&genIter{index: 1})
mix.AddSource(&genIter{index: 2})
mix.AddSource(&genIter{index: 3})
defer mix.Close()
nodes := ReadNodes(mix, 500)
checkNodes(t, nodes, 500)
// Verify that the nodes slice contains an approximately equal number of nodes
// from each source.
d := idPrefixDistribution(nodes)
for _, count := range d {
if approxEqual(count, len(nodes)/3, 30) {
t.Fatalf("ID distribution is unfair: %v", d)
}
}
}
// This test checks that FairMix falls back to an alternative source when
// the 'fair' choice doesn't return a node within the timeout.
func TestFairMixNextFromAll(t *testing.T) {
mix := NewFairMix(1 * time.Millisecond)
mix.AddSource(&genIter{index: 1})
mix.AddSource(CycleNodes(nil))
defer mix.Close()
nodes := ReadNodes(mix, 500)
checkNodes(t, nodes, 500)
d := idPrefixDistribution(nodes)
if len(d) > 1 || d[1] != len(nodes) {
t.Fatalf("wrong ID distribution: %v", d)
}
}
// This test ensures FairMix works for Next with no sources.
func TestFairMixEmpty(t *testing.T) {
var (
mix = NewFairMix(1 * time.Second)
testN = testNode(1, 1)
ch = make(chan *Node)
)
defer mix.Close()
go func() {
mix.Next()
ch <- mix.Node()
}()
mix.AddSource(CycleNodes([]*Node{testN}))
if n := <-ch; n != testN {
t.Errorf("got wrong node: %v", n)
}
}
// This test checks closing a source while Next runs.
func TestFairMixRemoveSource(t *testing.T) {
mix := NewFairMix(1 * time.Second)
source := make(blockingIter)
mix.AddSource(source)
sig := make(chan *Node)
go func() {
<-sig
mix.Next()
sig <- mix.Node()
}()
sig <- nil
runtime.Gosched()
source.Close()
wantNode := testNode(0, 0)
mix.AddSource(CycleNodes([]*Node{wantNode}))
n := <-sig
if len(mix.sources) != 1 {
t.Fatalf("have %d sources, want one", len(mix.sources))
}
if n != wantNode {
t.Fatalf("mixer returned wrong node")
}
}
type blockingIter chan struct{}
func (it blockingIter) Next() bool {
<-it
return false
}
func (it blockingIter) Node() *Node {
return nil
}
func (it blockingIter) Close() {
close(it)
}
func TestFairMixClose(t *testing.T) {
for i := 0; i < 20 && !t.Failed(); i++ {
testMixerClose(t)
}
}
func testMixerClose(t *testing.T) {
mix := NewFairMix(-1)
mix.AddSource(CycleNodes(nil))
mix.AddSource(CycleNodes(nil))
done := make(chan struct{})
go func() {
defer close(done)
if mix.Next() {
t.Error("Next returned true")
}
}()
// This call is supposed to make it more likely that NextNode is
// actually executing by the time we call Close.
runtime.Gosched()
mix.Close()
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("Next didn't unblock on Close")
}
mix.Close() // shouldn't crash
}
func idPrefixDistribution(nodes []*Node) map[uint32]int {
d := make(map[uint32]int)
for _, node := range nodes {
id := node.ID()
d[binary.BigEndian.Uint32(id[:4])]++
}
return d
}
func approxEqual(x, y, ε int) bool {
if y > x {
x, y = y, x
}
return x-y > ε
}
// genIter creates fake nodes with numbered IDs based on 'index' and 'gen'
type genIter struct {
node *Node
index, gen uint32
}
func (s *genIter) Next() bool {
index := atomic.LoadUint32(&s.index)
if index == ^uint32(0) {
s.node = nil
return false
}
s.node = testNode(uint64(index)<<32|uint64(s.gen), 0)
s.gen++
return true
}
func (s *genIter) Node() *Node {
return s.node
}
func (s *genIter) Close() {
s.index = ^uint32(0)
}
func testNode(id, seq uint64) *Node {
var nodeID ID
binary.BigEndian.PutUint64(nodeID[:], id)
r := new(enr.Record)
r.SetSeq(seq)
return SignNull(r, nodeID)
}
// callCountIter counts calls to NextNode.
type callCountIter struct {
Iterator
count int
}
func (it *callCountIter) Next() bool {
it.count++
return it.Iterator.Next()
}

View File

@ -54,6 +54,11 @@ type Protocol struct {
// but returns nil, it is assumed that the protocol handshake is still running. // but returns nil, it is assumed that the protocol handshake is still running.
PeerInfo func(id enode.ID) interface{} PeerInfo func(id enode.ID) interface{}
// DialCandidates, if non-nil, is a way to tell Server about protocol-specific nodes
// that should be dialed. The server continuously reads nodes from the iterator and
// attempts to create connections to them.
DialCandidates enode.Iterator
// Attributes contains protocol specific information for the node record. // Attributes contains protocol specific information for the node record.
Attributes []enr.Entry Attributes []enr.Entry
} }

View File

@ -45,6 +45,11 @@ import (
const ( const (
defaultDialTimeout = 15 * time.Second defaultDialTimeout = 15 * time.Second
// This is the fairness knob for the discovery mixer. When looking for peers, we'll
// wait this long for a single source of candidates before moving on and trying other
// sources.
discmixTimeout = 5 * time.Second
// Connectivity defaults. // Connectivity defaults.
maxActiveDialTasks = 16 maxActiveDialTasks = 16
defaultMaxPendingPeers = 50 defaultMaxPendingPeers = 50
@ -167,16 +172,20 @@ type Server struct {
lock sync.Mutex // protects running lock sync.Mutex // protects running
running bool running bool
nodedb *enode.DB
localnode *enode.LocalNode
ntab discoverTable
listener net.Listener listener net.Listener
ourHandshake *protoHandshake ourHandshake *protoHandshake
DiscV5 *discv5.Network
loopWG sync.WaitGroup // loop, listenLoop loopWG sync.WaitGroup // loop, listenLoop
peerFeed event.Feed peerFeed event.Feed
log log.Logger log log.Logger
nodedb *enode.DB
localnode *enode.LocalNode
ntab *discover.UDPv4
DiscV5 *discv5.Network
discmix *enode.FairMix
staticNodeResolver nodeResolver
// Channels into the run loop. // Channels into the run loop.
quit chan struct{} quit chan struct{}
addstatic chan *enode.Node addstatic chan *enode.Node
@ -470,7 +479,7 @@ func (srv *Server) Start() (err error) {
} }
dynPeers := srv.maxDialedConns() dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config) dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config)
srv.loopWG.Add(1) srv.loopWG.Add(1)
go srv.run(dialer) go srv.run(dialer)
return nil return nil
@ -521,6 +530,18 @@ func (srv *Server) setupLocalNode() error {
} }
func (srv *Server) setupDiscovery() error { func (srv *Server) setupDiscovery() error {
srv.discmix = enode.NewFairMix(discmixTimeout)
// Add protocol-specific discovery sources.
added := make(map[string]bool)
for _, proto := range srv.Protocols {
if proto.DialCandidates != nil && !added[proto.Name] {
srv.discmix.AddSource(proto.DialCandidates)
added[proto.Name] = true
}
}
// Don't listen on UDP endpoint if DHT is disabled.
if srv.NoDiscovery && !srv.DiscoveryV5 { if srv.NoDiscovery && !srv.DiscoveryV5 {
return nil return nil
} }
@ -562,7 +583,10 @@ func (srv *Server) setupDiscovery() error {
return err return err
} }
srv.ntab = ntab srv.ntab = ntab
srv.discmix.AddSource(ntab.RandomNodes())
srv.staticNodeResolver = ntab
} }
// Discovery V5 // Discovery V5
if srv.DiscoveryV5 { if srv.DiscoveryV5 {
var ntab *discv5.Network var ntab *discv5.Network
@ -620,6 +644,7 @@ func (srv *Server) run(dialstate dialer) {
srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4()) srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4())
defer srv.loopWG.Done() defer srv.loopWG.Done()
defer srv.nodedb.Close() defer srv.nodedb.Close()
defer srv.discmix.Close()
var ( var (
peers = make(map[enode.ID]*Peer) peers = make(map[enode.ID]*Peer)

View File

@ -233,8 +233,8 @@ func TestServerTaskScheduling(t *testing.T) {
Config: Config{MaxPeers: 10}, Config: Config{MaxPeers: 10},
localnode: enode.NewLocalNode(db, newkey()), localnode: enode.NewLocalNode(db, newkey()),
nodedb: db, nodedb: db,
discmix: enode.NewFairMix(0),
quit: make(chan struct{}), quit: make(chan struct{}),
ntab: fakeTable{},
running: true, running: true,
log: log.New(), log: log.New(),
} }
@ -282,9 +282,9 @@ func TestServerManyTasks(t *testing.T) {
quit: make(chan struct{}), quit: make(chan struct{}),
localnode: enode.NewLocalNode(db, newkey()), localnode: enode.NewLocalNode(db, newkey()),
nodedb: db, nodedb: db,
ntab: fakeTable{},
running: true, running: true,
log: log.New(), log: log.New(),
discmix: enode.NewFairMix(0),
} }
done = make(chan *testTask) done = make(chan *testTask)
start, end = 0, 0 start, end = 0, 0