diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index 6f6233e4d..59a2c9f83 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -1563,19 +1563,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.NetworkId = 3
}
cfg.Genesis = core.DefaultRopstenGenesisBlock()
- setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RopstenGenesisHash])
+ setDNSDiscoveryDefaults(cfg, params.RopstenGenesisHash)
case ctx.GlobalBool(RinkebyFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 4
}
cfg.Genesis = core.DefaultRinkebyGenesisBlock()
- setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RinkebyGenesisHash])
+ setDNSDiscoveryDefaults(cfg, params.RinkebyGenesisHash)
case ctx.GlobalBool(GoerliFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 5
}
cfg.Genesis = core.DefaultGoerliGenesisBlock()
- setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.GoerliGenesisHash])
+ setDNSDiscoveryDefaults(cfg, params.GoerliGenesisHash)
case ctx.GlobalBool(DeveloperFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 1337
@@ -1604,18 +1604,25 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
}
default:
if cfg.NetworkId == 1 {
- setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.MainnetGenesisHash])
+ setDNSDiscoveryDefaults(cfg, params.MainnetGenesisHash)
}
}
}
// setDNSDiscoveryDefaults configures DNS discovery with the given URL if
// no URLs are set.
-func setDNSDiscoveryDefaults(cfg *eth.Config, url string) {
+func setDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) {
if cfg.DiscoveryURLs != nil {
- return
+ return // already set through flags/config
+ }
+
+ protocol := "eth"
+ if cfg.SyncMode == downloader.LightSync {
+ protocol = "les"
+ }
+ if url := params.KnownDNSNetwork(genesis, protocol); url != "" {
+ cfg.DiscoveryURLs = []string{url}
}
- cfg.DiscoveryURLs = []string{url}
}
// RegisterEthService adds an Ethereum client to the stack.
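Note on the hunk above: the DNS discovery default is now derived from the genesis hash plus the configured sync mode ("les" for light sync, "eth" otherwise), and explicitly configured URLs always take precedence. A minimal caller-side sketch, assuming a placeholder tree URL (not a real one):

// Sketch only, inside cmd/utils: an explicit URL survives the defaulting step,
// while an empty config receives params.KnownDNSNetwork(genesis, "les"/"eth").
cfgA := &eth.Config{
	SyncMode:      downloader.LightSync,
	DiscoveryURLs: []string{"enrtree://EXAMPLEKEY@les.example.org"}, // placeholder
}
setDNSDiscoveryDefaults(cfgA, params.MainnetGenesisHash) // keeps the explicit URL

cfgB := &eth.Config{SyncMode: downloader.LightSync}
setDNSDiscoveryDefaults(cfgB, params.MainnetGenesisHash) // fills in the "les" tree for mainnet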
diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go
index e433db446..08a948510 100644
--- a/core/forkid/forkid.go
+++ b/core/forkid/forkid.go
@@ -27,7 +27,7 @@ import (
"strings"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"
)
@@ -44,6 +44,18 @@ var (
ErrLocalIncompatibleOrStale = errors.New("local incompatible or needs update")
)
+// Blockchain defines all necessary methods to build a forkID.
+type Blockchain interface {
+ // Config retrieves the chain's fork configuration.
+ Config() *params.ChainConfig
+
+ // Genesis retrieves the chain's genesis block.
+ Genesis() *types.Block
+
+ // CurrentHeader retrieves the current head header of the canonical chain.
+ CurrentHeader() *types.Header
+}
+
// ID is a fork identifier as defined by EIP-2124.
type ID struct {
Hash [4]byte // CRC32 checksum of the genesis block and passed fork block numbers
@@ -54,7 +66,7 @@ type ID struct {
type Filter func(id ID) error
// NewID calculates the Ethereum fork ID from the chain config and head.
-func NewID(chain *core.BlockChain) ID {
+func NewID(chain Blockchain) ID {
return newID(
chain.Config(),
chain.Genesis().Hash(),
@@ -85,7 +97,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
// NewFilter creates a filter that returns if a fork ID should be rejected or not
// based on the local chain's status.
-func NewFilter(chain *core.BlockChain) Filter {
+func NewFilter(chain Blockchain) Filter {
return newFilter(
chain.Config(),
chain.Genesis().Hash(),
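Decoupling forkid from *core.BlockChain means anything exposing these three accessors can feed NewID and NewFilter. A minimal sketch of a test double satisfying the new interface (mockChain is a hypothetical name, not part of this patch):

package forkid

import (
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/params"
)

// mockChain is an illustrative stand-in for the Blockchain interface, allowing
// NewID/NewFilter to be exercised without constructing a full core.BlockChain.
type mockChain struct {
	config  *params.ChainConfig
	genesis *types.Block
	head    *types.Header
}

func (c *mockChain) Config() *params.ChainConfig  { return c.config }
func (c *mockChain) Genesis() *types.Block        { return c.genesis }
func (c *mockChain) CurrentHeader() *types.Header { return c.head }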
diff --git a/eth/backend.go b/eth/backend.go
index 391e3c0e6..e4f98360d 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -72,7 +72,7 @@ type Ethereum struct {
blockchain *core.BlockChain
protocolManager *ProtocolManager
lesServer LesServer
- dialCandiates enode.Iterator
+ dialCandidates enode.Iterator
// DB interfaces
chainDb ethdb.Database // Block chain database
@@ -226,7 +226,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
}
eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams)
- eth.dialCandiates, err = eth.setupDiscovery(&ctx.Config.P2P)
+ eth.dialCandidates, err = eth.setupDiscovery(&ctx.Config.P2P)
if err != nil {
return nil, err
}
@@ -523,7 +523,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol {
for i, vsn := range ProtocolVersions {
protos[i] = s.protocolManager.makeProtocol(vsn)
protos[i].Attributes = []enr.Entry{s.currentEthEntry()}
- protos[i].DialCandidates = s.dialCandiates
+ protos[i].DialCandidates = s.dialCandidates
}
if s.lesServer != nil {
protos = append(protos, s.lesServer.Protocols()...)
diff --git a/les/client.go b/les/client.go
index 2ac3285d8..34a654e22 100644
--- a/les/client.go
+++ b/les/client.go
@@ -51,16 +51,17 @@ import (
type LightEthereum struct {
lesCommons
- peers *serverPeerSet
- reqDist *requestDistributor
- retriever *retrieveManager
- odr *LesOdr
- relay *lesTxRelay
- handler *clientHandler
- txPool *light.TxPool
- blockchain *light.LightChain
- serverPool *serverPool
- valueTracker *lpc.ValueTracker
+ peers *serverPeerSet
+ reqDist *requestDistributor
+ retriever *retrieveManager
+ odr *LesOdr
+ relay *lesTxRelay
+ handler *clientHandler
+ txPool *light.TxPool
+ blockchain *light.LightChain
+ serverPool *serverPool
+ valueTracker *lpc.ValueTracker
+ dialCandidates enode.Iterator
bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests
bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports
@@ -104,11 +105,19 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb),
bloomRequests: make(chan chan *bloombits.Retrieval),
bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
- serverPool: newServerPool(chainDb, config.UltraLightServers),
valueTracker: lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)),
}
peers.subscribe((*vtSubscription)(leth.valueTracker))
- leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool)
+
+ dnsdisc, err := leth.setupDiscovery(&ctx.Config.P2P)
+ if err != nil {
+ return nil, err
+ }
+ leth.serverPool = newServerPool(lespayDb, []byte("serverpool:"), leth.valueTracker, dnsdisc, time.Second, nil, &mclock.System{}, config.UltraLightServers)
+ peers.subscribe(leth.serverPool)
+ leth.dialCandidates = leth.serverPool.dialIterator
+
+ leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.getTimeout)
leth.relay = newLesTxRelay(peers, leth.retriever)
leth.odr = NewLesOdr(chainDb, light.DefaultClientIndexerConfig, leth.retriever)
@@ -140,11 +149,6 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
leth.chtIndexer.Start(leth.blockchain)
leth.bloomIndexer.Start(leth.blockchain)
- leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
- if leth.handler.ulc != nil {
- log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
- leth.blockchain.DisableCheckFreq()
- }
// Rewind the chain in case of an incompatible config upgrade.
if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
log.Warn("Rewinding chain to upgrade configuration", "err", compat)
@@ -159,6 +163,11 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
}
leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
+ leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
+ if leth.handler.ulc != nil {
+ log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
+ leth.blockchain.DisableCheckFreq()
+ }
return leth, nil
}
@@ -260,7 +269,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
return p.Info()
}
return nil
- })
+ }, s.dialCandidates)
}
// Start implements node.Service, starting all internal goroutines needed by the
@@ -268,15 +277,12 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
func (s *LightEthereum) Start(srvr *p2p.Server) error {
log.Warn("Light client mode is an experimental feature")
+ s.serverPool.start()
// Start bloom request workers.
s.wg.Add(bloomServiceThreads)
s.startBloomHandlers(params.BloomBitsBlocksClient)
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId)
-
- // clients are searching for the first advertised protocol in the list
- protocolVersion := AdvertiseProtocolVersions[0]
- s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion))
return nil
}
@@ -284,6 +290,8 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error {
// Ethereum protocol.
func (s *LightEthereum) Stop() error {
close(s.closeCh)
+ s.serverPool.stop()
+ s.valueTracker.Stop()
s.peers.close()
s.reqDist.close()
s.odr.Stop()
@@ -295,8 +303,6 @@ func (s *LightEthereum) Stop() error {
s.txPool.Stop()
s.engine.Close()
s.eventMux.Stop()
- s.serverPool.stop()
- s.valueTracker.Stop()
s.chainDb.Close()
s.wg.Wait()
log.Info("Light ethereum stopped")
diff --git a/les/client_handler.go b/les/client_handler.go
index 6367fdb6b..abe472e46 100644
--- a/les/client_handler.go
+++ b/les/client_handler.go
@@ -64,7 +64,7 @@ func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.T
if checkpoint != nil {
height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1
}
- handler.fetcher = newLightFetcher(handler)
+ handler.fetcher = newLightFetcher(handler, backend.serverPool.getTimeout)
handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer)
handler.backend.peers.subscribe((*downloaderPeerNotify)(handler))
return handler
@@ -85,14 +85,9 @@ func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter)
}
peer := newServerPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version)))
defer peer.close()
- peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node())
- if peer.poolEntry == nil {
- return p2p.DiscRequested
- }
h.wg.Add(1)
defer h.wg.Done()
err := h.handle(peer)
- h.backend.serverPool.disconnect(peer.poolEntry)
return err
}
@@ -129,10 +124,6 @@ func (h *clientHandler) handle(p *serverPeer) error {
h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td})
- // pool entry can be nil during the unit test.
- if p.poolEntry != nil {
- h.backend.serverPool.registered(p.poolEntry)
- }
// Mark the peer starts to be served.
atomic.StoreUint32(&p.serving, 1)
defer atomic.StoreUint32(&p.serving, 0)
diff --git a/les/commons.go b/les/commons.go
index 29b5b7660..cd8a22834 100644
--- a/les/commons.go
+++ b/les/commons.go
@@ -81,7 +81,7 @@ type NodeInfo struct {
}
// makeProtocols creates protocol descriptors for the given LES versions.
-func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol {
+func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}, dialCandidates enode.Iterator) []p2p.Protocol {
protos := make([]p2p.Protocol, len(versions))
for i, version := range versions {
version := version
@@ -93,7 +93,8 @@ func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p
Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
return runPeer(version, peer, rw)
},
- PeerInfo: peerInfo,
+ PeerInfo: peerInfo,
+ DialCandidates: dialCandidates,
}
}
return protos
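For context, DialCandidates is just an enode.Iterator hung off the protocol descriptor; the p2p dialer drains it for outbound connection candidates. A hedged sketch with a static node slice (field values are placeholders; the LES client passes serverPool.dialIterator here, while the LES server passes nil):

// Sketch only: advertise dial candidates from a fixed list via enode.IterNodes.
var staticNodes []*enode.Node // assumed to be populated elsewhere
proto := p2p.Protocol{
	Name:           "les",
	Version:        3,
	Length:         24, // placeholder message count
	Run:            func(peer *p2p.Peer, rw p2p.MsgReadWriter) error { return nil },
	DialCandidates: enode.IterNodes(staticNodes),
}
_ = proto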
diff --git a/les/distributor.go b/les/distributor.go
index 97d2ccdfe..31150e4d7 100644
--- a/les/distributor.go
+++ b/les/distributor.go
@@ -180,12 +180,11 @@ func (d *requestDistributor) loop() {
type selectPeerItem struct {
peer distPeer
req *distReq
- weight int64
+ weight uint64
}
-// Weight implements wrsItem interface
-func (sp selectPeerItem) Weight() int64 {
- return sp.weight
+func selectPeerWeight(i interface{}) uint64 {
+ return i.(selectPeerItem).weight
}
// nextRequest returns the next possible request from any peer, along with the
@@ -220,9 +219,9 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
wait, bufRemain := peer.waitBefore(cost)
if wait == 0 {
if sel == nil {
- sel = utils.NewWeightedRandomSelect()
+ sel = utils.NewWeightedRandomSelect(selectPeerWeight)
}
- sel.Update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1})
+ sel.Update(selectPeerItem{peer: peer, req: req, weight: uint64(bufRemain*1000000) + 1})
} else {
if bestWait == 0 || wait < bestWait {
bestWait = wait
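The WeightedRandomSelect change above replaces the wrsItem Weight() method with a constructor-supplied callback, so arbitrary values (selectPeerItem here, enode IDs in the new WrsIterator) can be weighted without implementing an interface. A small usage sketch reusing selectPeerWeight from this file:

// Sketch: weights are now read through the callback passed to the constructor.
sel := utils.NewWeightedRandomSelect(selectPeerWeight)
sel.Update(selectPeerItem{weight: 42})
if chosen := sel.Choose(); chosen != nil {
	_ = chosen.(selectPeerItem)
}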
diff --git a/les/enr_entry.go b/les/enr_entry.go
index c2a92dd99..65d0d1fdb 100644
--- a/les/enr_entry.go
+++ b/les/enr_entry.go
@@ -17,6 +17,9 @@
package les
import (
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/dnsdisc"
+ "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -30,3 +33,12 @@ type lesEntry struct {
func (e lesEntry) ENRKey() string {
return "les"
}
+
+// setupDiscovery creates the node discovery source for the les protocol.
+func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) {
+ if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 {
+ return nil, nil
+ }
+ client := dnsdisc.NewClient(dnsdisc.Config{})
+ return client.NewIterator(eth.config.DiscoveryURLs...)
+}
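A rough consumption sketch for the iterator created above; the tree URL is a placeholder and Next blocks until a candidate has been resolved from DNS:

// Sketch only, not part of the patch.
func logDNSCandidates() error {
	client := dnsdisc.NewClient(dnsdisc.Config{})
	it, err := client.NewIterator("enrtree://EXAMPLEKEY@nodes.example.org") // placeholder URL
	if err != nil {
		return err
	}
	defer it.Close()
	for it.Next() {
		log.Debug("DNS discovery candidate", "node", it.Node().URLv4())
	}
	return nil
}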
diff --git a/les/fetcher.go b/les/fetcher.go
index 7fa81cbcb..aaf74aaa1 100644
--- a/les/fetcher.go
+++ b/les/fetcher.go
@@ -40,8 +40,9 @@ const (
// ODR system to ensure that we only request data related to a certain block from peers who have already processed
// and announced that block.
type lightFetcher struct {
- handler *clientHandler
- chain *light.LightChain
+ handler *clientHandler
+ chain *light.LightChain
+ softRequestTimeout func() time.Duration
lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests
maxConfirmedTd *big.Int
@@ -109,18 +110,19 @@ type fetchResponse struct {
}
// newLightFetcher creates a new light fetcher
-func newLightFetcher(h *clientHandler) *lightFetcher {
+func newLightFetcher(h *clientHandler, softRequestTimeout func() time.Duration) *lightFetcher {
f := &lightFetcher{
- handler: h,
- chain: h.backend.blockchain,
- peers: make(map[*serverPeer]*fetcherPeerInfo),
- deliverChn: make(chan fetchResponse, 100),
- requested: make(map[uint64]fetchRequest),
- timeoutChn: make(chan uint64),
- requestTrigger: make(chan struct{}, 1),
- syncDone: make(chan *serverPeer),
- closeCh: make(chan struct{}),
- maxConfirmedTd: big.NewInt(0),
+ handler: h,
+ chain: h.backend.blockchain,
+ peers: make(map[*serverPeer]*fetcherPeerInfo),
+ deliverChn: make(chan fetchResponse, 100),
+ requested: make(map[uint64]fetchRequest),
+ timeoutChn: make(chan uint64),
+ requestTrigger: make(chan struct{}, 1),
+ syncDone: make(chan *serverPeer),
+ closeCh: make(chan struct{}),
+ maxConfirmedTd: big.NewInt(0),
+ softRequestTimeout: softRequestTimeout,
}
h.backend.peers.subscribe(f)
@@ -163,7 +165,7 @@ func (f *lightFetcher) syncLoop() {
f.lock.Unlock()
} else {
go func() {
- time.Sleep(softRequestTimeout)
+ time.Sleep(f.softRequestTimeout())
f.reqMu.Lock()
req, ok := f.requested[reqID]
if ok {
@@ -187,7 +189,6 @@ func (f *lightFetcher) syncLoop() {
}
f.reqMu.Unlock()
if ok {
- f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true)
req.peer.Log().Debug("Fetching data timed out hard")
go f.handler.removePeer(req.peer.id)
}
@@ -201,9 +202,6 @@ func (f *lightFetcher) syncLoop() {
delete(f.requested, resp.reqID)
}
f.reqMu.Unlock()
- if ok {
- f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout)
- }
f.lock.Lock()
if !ok || !(f.syncing || f.processResponse(req, resp)) {
resp.peer.Log().Debug("Failed processing response")
@@ -879,12 +877,10 @@ func (f *lightFetcher) checkUpdateStats(p *serverPeer, newEntry *updateStatsEntr
fp.firstUpdateStats = newEntry
}
for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) {
- f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout)
fp.firstUpdateStats = fp.firstUpdateStats.next
}
if fp.confirmedTd != nil {
for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 {
- f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time))
fp.firstUpdateStats = fp.firstUpdateStats.next
}
}
diff --git a/les/lespay/client/fillset.go b/les/lespay/client/fillset.go
new file mode 100644
index 000000000..0da850bca
--- /dev/null
+++ b/les/lespay/client/fillset.go
@@ -0,0 +1,107 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package client
+
+import (
+ "sync"
+
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/nodestate"
+)
+
+// FillSet tries to read nodes from an input iterator and add them to a node set by
+// setting the specified node state flag(s) until the size of the set reaches the target.
+// Note that other mechanisms (like other FillSet instances reading from different inputs)
+// can also set the same flag(s) and FillSet will always care about the total number of
+// nodes having those flags.
+type FillSet struct {
+ lock sync.Mutex
+ cond *sync.Cond
+ ns *nodestate.NodeStateMachine
+ input enode.Iterator
+ closed bool
+ flags nodestate.Flags
+ count, target int
+}
+
+// NewFillSet creates a new FillSet
+func NewFillSet(ns *nodestate.NodeStateMachine, input enode.Iterator, flags nodestate.Flags) *FillSet {
+ fs := &FillSet{
+ ns: ns,
+ input: input,
+ flags: flags,
+ }
+ fs.cond = sync.NewCond(&fs.lock)
+
+ ns.SubscribeState(flags, func(n *enode.Node, oldState, newState nodestate.Flags) {
+ fs.lock.Lock()
+ if oldState.Equals(flags) {
+ fs.count--
+ }
+ if newState.Equals(flags) {
+ fs.count++
+ }
+ if fs.target > fs.count {
+ fs.cond.Signal()
+ }
+ fs.lock.Unlock()
+ })
+
+ go fs.readLoop()
+ return fs
+}
+
+// readLoop keeps reading nodes from the input and setting the specified flags for them
+// whenever the node set size is under the current target
+func (fs *FillSet) readLoop() {
+ for {
+ fs.lock.Lock()
+ for fs.target <= fs.count && !fs.closed {
+ fs.cond.Wait()
+ }
+
+ fs.lock.Unlock()
+ if !fs.input.Next() {
+ return
+ }
+ fs.ns.SetState(fs.input.Node(), fs.flags, nodestate.Flags{}, 0)
+ }
+}
+
+// SetTarget sets the current target for node set size. If the previous target was not
+// reached and FillSet was still waiting for the next node from the input then the next
+// incoming node will be added to the set regardless of the target. This ensures that
+// all nodes coming from the input are eventually added to the set.
+func (fs *FillSet) SetTarget(target int) {
+ fs.lock.Lock()
+ defer fs.lock.Unlock()
+
+ fs.target = target
+ if fs.target > fs.count {
+ fs.cond.Signal()
+ }
+}
+
+// Close shuts FillSet down and closes the input iterator
+func (fs *FillSet) Close() {
+ fs.lock.Lock()
+ defer fs.lock.Unlock()
+
+ fs.closed = true
+ fs.input.Close()
+ fs.cond.Signal()
+}
diff --git a/les/lespay/client/fillset_test.go b/les/lespay/client/fillset_test.go
new file mode 100644
index 000000000..58240682c
--- /dev/null
+++ b/les/lespay/client/fillset_test.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package client
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/p2p/nodestate"
+)
+
+type testIter struct {
+ waitCh chan struct{}
+ nodeCh chan *enode.Node
+ node *enode.Node
+}
+
+func (i *testIter) Next() bool {
+ i.waitCh <- struct{}{}
+ i.node = <-i.nodeCh
+ return i.node != nil
+}
+
+func (i *testIter) Node() *enode.Node {
+ return i.node
+}
+
+func (i *testIter) Close() {}
+
+func (i *testIter) push() {
+ var id enode.ID
+ rand.Read(id[:])
+ i.nodeCh <- enode.SignNull(new(enr.Record), id)
+}
+
+func (i *testIter) waiting(timeout time.Duration) bool {
+ select {
+ case <-i.waitCh:
+ return true
+ case <-time.After(timeout):
+ return false
+ }
+}
+
+func TestFillSet(t *testing.T) {
+ ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
+ iter := &testIter{
+ waitCh: make(chan struct{}),
+ nodeCh: make(chan *enode.Node),
+ }
+ fs := NewFillSet(ns, iter, sfTest1)
+ ns.Start()
+
+ expWaiting := func(i int, push bool) {
+ for ; i > 0; i-- {
+ if !iter.waiting(time.Second * 10) {
+ t.Fatalf("FillSet not waiting for new nodes")
+ }
+ if push {
+ iter.push()
+ }
+ }
+ }
+
+ expNotWaiting := func() {
+ if iter.waiting(time.Millisecond * 100) {
+ t.Fatalf("FillSet unexpectedly waiting for new nodes")
+ }
+ }
+
+ expNotWaiting()
+ fs.SetTarget(3)
+ expWaiting(3, true)
+ expNotWaiting()
+ fs.SetTarget(100)
+ expWaiting(2, true)
+ expWaiting(1, false)
+ // lower the target before the previous one has been filled up
+ fs.SetTarget(0)
+ iter.push()
+ expNotWaiting()
+ fs.SetTarget(10)
+ expWaiting(4, true)
+ expNotWaiting()
+ // remove all previously set flags
+ ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
+ ns.SetState(node, nodestate.Flags{}, sfTest1, 0)
+ })
+ // now expect FillSet to fill the set up again with 10 new nodes
+ expWaiting(10, true)
+ expNotWaiting()
+
+ fs.Close()
+ ns.Stop()
+}
diff --git a/les/lespay/client/queueiterator.go b/les/lespay/client/queueiterator.go
new file mode 100644
index 000000000..ad3f8df5b
--- /dev/null
+++ b/les/lespay/client/queueiterator.go
@@ -0,0 +1,123 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package client
+
+import (
+ "sync"
+
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/nodestate"
+)
+
+// QueueIterator returns nodes from the specified selectable set in the same order as
+// they entered the set.
+type QueueIterator struct {
+ lock sync.Mutex
+ cond *sync.Cond
+
+ ns *nodestate.NodeStateMachine
+ queue []*enode.Node
+ nextNode *enode.Node
+ waitCallback func(bool)
+ fifo, closed bool
+}
+
+// NewQueueIterator creates a new QueueIterator. Nodes are selectable if they have all the required
+// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
+// disables further selectability until it is removed or times out.
+func NewQueueIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, fifo bool, waitCallback func(bool)) *QueueIterator {
+ qi := &QueueIterator{
+ ns: ns,
+ fifo: fifo,
+ waitCallback: waitCallback,
+ }
+ qi.cond = sync.NewCond(&qi.lock)
+
+ ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
+ oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
+ newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
+ if newMatch == oldMatch {
+ return
+ }
+
+ qi.lock.Lock()
+ defer qi.lock.Unlock()
+
+ if newMatch {
+ qi.queue = append(qi.queue, n)
+ } else {
+ id := n.ID()
+ for i, qn := range qi.queue {
+ if qn.ID() == id {
+ copy(qi.queue[i:len(qi.queue)-1], qi.queue[i+1:])
+ qi.queue = qi.queue[:len(qi.queue)-1]
+ break
+ }
+ }
+ }
+ qi.cond.Signal()
+ })
+ return qi
+}
+
+// Next moves to the next selectable node.
+func (qi *QueueIterator) Next() bool {
+ qi.lock.Lock()
+ if !qi.closed && len(qi.queue) == 0 {
+ if qi.waitCallback != nil {
+ qi.waitCallback(true)
+ }
+ for !qi.closed && len(qi.queue) == 0 {
+ qi.cond.Wait()
+ }
+ if qi.waitCallback != nil {
+ qi.waitCallback(false)
+ }
+ }
+ if qi.closed {
+ qi.nextNode = nil
+ qi.lock.Unlock()
+ return false
+ }
+ // Move to the next node in queue.
+ if qi.fifo {
+ qi.nextNode = qi.queue[0]
+ copy(qi.queue[:len(qi.queue)-1], qi.queue[1:])
+ qi.queue = qi.queue[:len(qi.queue)-1]
+ } else {
+ qi.nextNode = qi.queue[len(qi.queue)-1]
+ qi.queue = qi.queue[:len(qi.queue)-1]
+ }
+ qi.lock.Unlock()
+ return true
+}
+
+// Close ends the iterator.
+func (qi *QueueIterator) Close() {
+ qi.lock.Lock()
+ qi.closed = true
+ qi.lock.Unlock()
+ qi.cond.Signal()
+}
+
+// Node returns the current node.
+func (qi *QueueIterator) Node() *enode.Node {
+ qi.lock.Lock()
+ defer qi.lock.Unlock()
+
+ return qi.nextNode
+}
diff --git a/les/lespay/client/queueiterator_test.go b/les/lespay/client/queueiterator_test.go
new file mode 100644
index 000000000..a74301c7d
--- /dev/null
+++ b/les/lespay/client/queueiterator_test.go
@@ -0,0 +1,106 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package client
+
+import (
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/p2p/nodestate"
+)
+
+func testNodeID(i int) enode.ID {
+ return enode.ID{42, byte(i % 256), byte(i / 256)}
+}
+
+func testNodeIndex(id enode.ID) int {
+ if id[0] != 42 {
+ return -1
+ }
+ return int(id[1]) + int(id[2])*256
+}
+
+func testNode(i int) *enode.Node {
+ return enode.SignNull(new(enr.Record), testNodeID(i))
+}
+
+func TestQueueIteratorFIFO(t *testing.T) {
+ testQueueIterator(t, true)
+}
+
+func TestQueueIteratorLIFO(t *testing.T) {
+ testQueueIterator(t, false)
+}
+
+func testQueueIterator(t *testing.T, fifo bool) {
+ ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
+ qi := NewQueueIterator(ns, sfTest2, sfTest3.Or(sfTest4), fifo, nil)
+ ns.Start()
+ for i := 1; i <= iterTestNodeCount; i++ {
+ ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
+ }
+ next := func() int {
+ ch := make(chan struct{})
+ go func() {
+ qi.Next()
+ close(ch)
+ }()
+ select {
+ case <-ch:
+ case <-time.After(time.Second * 5):
+ t.Fatalf("Iterator.Next() timeout")
+ }
+ node := qi.Node()
+ ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
+ return testNodeIndex(node.ID())
+ }
+ exp := func(i int) {
+ n := next()
+ if n != i {
+ t.Errorf("Wrong item returned by iterator (expected %d, got %d)", i, n)
+ }
+ }
+ explist := func(list []int) {
+ for i := range list {
+ if fifo {
+ exp(list[i])
+ } else {
+ exp(list[len(list)-1-i])
+ }
+ }
+ }
+
+ ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
+ explist([]int{1, 2, 3})
+ ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(5), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(5), sfTest3, nodestate.Flags{}, 0)
+ explist([]int{4, 6})
+ ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
+ ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
+ ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
+ ns.SetState(testNode(2), sfTest3, nodestate.Flags{}, 0)
+ ns.SetState(testNode(2), nodestate.Flags{}, sfTest3, 0)
+ explist([]int{1, 3, 2})
+ ns.Stop()
+}
diff --git a/les/lespay/client/valuetracker.go b/les/lespay/client/valuetracker.go
index 92bfd694e..4e67b31d9 100644
--- a/les/lespay/client/valuetracker.go
+++ b/les/lespay/client/valuetracker.go
@@ -213,6 +213,15 @@ func (vt *ValueTracker) StatsExpirer() *utils.Expirer {
return &vt.statsExpirer
}
+// StatsExpFactor returns the current expiration factor so that other values can be expired
+// with the same rate as the service value statistics.
+func (vt *ValueTracker) StatsExpFactor() utils.ExpirationFactor {
+ vt.statsExpLock.RLock()
+ defer vt.statsExpLock.RUnlock()
+
+ return vt.statsExpFactor
+}
+
// loadFromDb loads the value tracker's state from the database and converts saved
// request basket index mapping if it does not match the specified index to name mapping.
func (vt *ValueTracker) loadFromDb(mapping []string) error {
@@ -500,16 +509,3 @@ func (vt *ValueTracker) RequestStats() []RequestStatsItem {
}
return res
}
-
-// TotalServiceValue returns the total service value provided by the given node (as
-// a function of the weights which are calculated from the request timeout value).
-func (vt *ValueTracker) TotalServiceValue(nv *NodeValueTracker, weights ResponseTimeWeights) float64 {
- vt.statsExpLock.RLock()
- expFactor := vt.statsExpFactor
- vt.statsExpLock.RUnlock()
-
- nv.lock.Lock()
- defer nv.lock.Unlock()
-
- return nv.rtStats.Value(weights, expFactor)
-}
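With TotalServiceValue removed, callers combine a ResponseTimeStats snapshot with the tracker's current expiration factor themselves (the new server pool does this with the per-connection stats kept in a node field). A hedged sketch of that pattern; sessionValue is a hypothetical helper, not part of the patch:

// sessionValue values a stats snapshot using the tracker's current expiration
// factor, mirroring what the removed TotalServiceValue did internally.
func sessionValue(stats lpc.ResponseTimeStats, vt *lpc.ValueTracker, weights lpc.ResponseTimeWeights) float64 {
	return stats.Value(weights, vt.StatsExpFactor())
}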
diff --git a/les/lespay/client/wrsiterator.go b/les/lespay/client/wrsiterator.go
new file mode 100644
index 000000000..8a2e39ad4
--- /dev/null
+++ b/les/lespay/client/wrsiterator.go
@@ -0,0 +1,128 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package client
+
+import (
+ "sync"
+
+ "github.com/ethereum/go-ethereum/les/utils"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/nodestate"
+)
+
+// WrsIterator returns nodes from the specified selectable set with a weighted random
+// selection. Selection weights are provided by a callback function.
+type WrsIterator struct {
+ lock sync.Mutex
+ cond *sync.Cond
+
+ ns *nodestate.NodeStateMachine
+ wrs *utils.WeightedRandomSelect
+ nextNode *enode.Node
+ closed bool
+}
+
+// NewWrsIterator creates a new WrsIterator. Nodes are selectable if they have all the required
+// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
+// disables further selectability until it is removed or times out.
+func NewWrsIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, weightField nodestate.Field) *WrsIterator {
+ wfn := func(i interface{}) uint64 {
+ n := ns.GetNode(i.(enode.ID))
+ if n == nil {
+ return 0
+ }
+ wt, _ := ns.GetField(n, weightField).(uint64)
+ return wt
+ }
+
+ w := &WrsIterator{
+ ns: ns,
+ wrs: utils.NewWeightedRandomSelect(wfn),
+ }
+ w.cond = sync.NewCond(&w.lock)
+
+ ns.SubscribeField(weightField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
+ if state.HasAll(requireFlags) && state.HasNone(disableFlags) {
+ w.lock.Lock()
+ w.wrs.Update(n.ID())
+ w.lock.Unlock()
+ w.cond.Signal()
+ }
+ })
+
+ ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
+ oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
+ newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
+ if newMatch == oldMatch {
+ return
+ }
+
+ w.lock.Lock()
+ if newMatch {
+ w.wrs.Update(n.ID())
+ } else {
+ w.wrs.Remove(n.ID())
+ }
+ w.lock.Unlock()
+ w.cond.Signal()
+ })
+ return w
+}
+
+// Next selects the next node.
+func (w *WrsIterator) Next() bool {
+ w.nextNode = w.chooseNode()
+ return w.nextNode != nil
+}
+
+func (w *WrsIterator) chooseNode() *enode.Node {
+ w.lock.Lock()
+ defer w.lock.Unlock()
+
+ for {
+ for !w.closed && w.wrs.IsEmpty() {
+ w.cond.Wait()
+ }
+ if w.closed {
+ return nil
+ }
+ // Choose the next node at random. Even though w.wrs is guaranteed
+ // non-empty here, Choose might return nil if all items have weight
+ // zero.
+ if c := w.wrs.Choose(); c != nil {
+ id := c.(enode.ID)
+ w.wrs.Remove(id)
+ return w.ns.GetNode(id)
+ }
+ }
+
+}
+
+// Close ends the iterator.
+func (w *WrsIterator) Close() {
+ w.lock.Lock()
+ w.closed = true
+ w.lock.Unlock()
+ w.cond.Signal()
+}
+
+// Node returns the current node.
+func (w *WrsIterator) Node() *enode.Node {
+ w.lock.Lock()
+ defer w.lock.Unlock()
+ return w.nextNode
+}
diff --git a/les/lespay/client/wrsiterator_test.go b/les/lespay/client/wrsiterator_test.go
new file mode 100644
index 000000000..77bb5ee0c
--- /dev/null
+++ b/les/lespay/client/wrsiterator_test.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package client
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/p2p/nodestate"
+)
+
+var (
+ testSetup = &nodestate.Setup{}
+ sfTest1 = testSetup.NewFlag("test1")
+ sfTest2 = testSetup.NewFlag("test2")
+ sfTest3 = testSetup.NewFlag("test3")
+ sfTest4 = testSetup.NewFlag("test4")
+ sfiTestWeight = testSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0)))
+)
+
+const iterTestNodeCount = 6
+
+func TestWrsIterator(t *testing.T) {
+ ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
+ w := NewWrsIterator(ns, sfTest2, sfTest3.Or(sfTest4), sfiTestWeight)
+ ns.Start()
+ for i := 1; i <= iterTestNodeCount; i++ {
+ ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
+ ns.SetField(testNode(i), sfiTestWeight, uint64(1))
+ }
+ next := func() int {
+ ch := make(chan struct{})
+ go func() {
+ w.Next()
+ close(ch)
+ }()
+ select {
+ case <-ch:
+ case <-time.After(time.Second * 5):
+ t.Fatalf("Iterator.Next() timeout")
+ }
+ node := w.Node()
+ ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
+ return testNodeIndex(node.ID())
+ }
+ set := make(map[int]bool)
+ expset := func() {
+ for len(set) > 0 {
+ n := next()
+ if !set[n] {
+ t.Errorf("Item returned by iterator not in the expected set (got %d)", n)
+ }
+ delete(set, n)
+ }
+ }
+
+ ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
+ set[1] = true
+ set[2] = true
+ set[3] = true
+ expset()
+ ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
+ ns.SetState(testNode(5), sfTest2.Or(sfTest3), nodestate.Flags{}, 0)
+ ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
+ set[4] = true
+ set[6] = true
+ expset()
+ ns.SetField(testNode(2), sfiTestWeight, uint64(0))
+ ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
+ ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
+ ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
+ set[1] = true
+ set[3] = true
+ expset()
+ ns.SetField(testNode(2), sfiTestWeight, uint64(1))
+ ns.SetState(testNode(2), nodestate.Flags{}, sfTest2, 0)
+ ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
+ ns.SetState(testNode(2), sfTest2, sfTest4, 0)
+ ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
+ set[1] = true
+ set[2] = true
+ set[3] = true
+ expset()
+ ns.Stop()
+}
diff --git a/les/metrics.go b/les/metrics.go
index 9ef8c3651..c5edb61c3 100644
--- a/les/metrics.go
+++ b/les/metrics.go
@@ -107,6 +107,13 @@ var (
requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil)
requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil)
+
+ serverSelectableGauge = metrics.NewRegisteredGauge("les/client/serverPool/selectable", nil)
+ serverDialedMeter = metrics.NewRegisteredMeter("les/client/serverPool/dialed", nil)
+ serverConnectedGauge = metrics.NewRegisteredGauge("les/client/serverPool/connected", nil)
+ sessionValueMeter = metrics.NewRegisteredMeter("les/client/serverPool/sessionValue", nil)
+ totalValueGauge = metrics.NewRegisteredGauge("les/client/serverPool/totalValue", nil)
+ suggestedTimeoutGauge = metrics.NewRegisteredGauge("les/client/serverPool/timeout", nil)
)
// meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of
diff --git a/les/peer.go b/les/peer.go
index e92b4580d..4793d9026 100644
--- a/les/peer.go
+++ b/les/peer.go
@@ -336,7 +336,6 @@ type serverPeer struct {
checkpointNumber uint64 // The block height which the checkpoint is registered.
checkpoint params.TrustedCheckpoint // The advertised checkpoint sent by server.
- poolEntry *poolEntry // Statistic for server peer.
fcServer *flowcontrol.ServerNode // Client side mirror token bucket.
vtLock sync.Mutex
valueTracker *lpc.ValueTracker
diff --git a/les/protocol.go b/les/protocol.go
index f8ad94a7b..4fd19f9be 100644
--- a/les/protocol.go
+++ b/les/protocol.go
@@ -130,7 +130,6 @@ func init() {
}
requestMapping[uint32(code)] = rm
}
-
}
type errCode int
diff --git a/les/retrieve.go b/les/retrieve.go
index 5fa68b745..4f77004f2 100644
--- a/les/retrieve.go
+++ b/les/retrieve.go
@@ -24,22 +24,20 @@ import (
"sync"
"time"
- "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/light"
)
var (
retryQueue = time.Millisecond * 100
- softRequestTimeout = time.Millisecond * 500
hardRequestTimeout = time.Second * 10
)
// retrieveManager is a layer on top of requestDistributor which takes care of
// matching replies by request ID and handles timeouts and resends if necessary.
type retrieveManager struct {
- dist *requestDistributor
- peers *serverPeerSet
- serverPool peerSelector
+ dist *requestDistributor
+ peers *serverPeerSet
+ softRequestTimeout func() time.Duration
lock sync.RWMutex
sentReqs map[uint64]*sentReq
@@ -48,11 +46,6 @@ type retrieveManager struct {
// validatorFunc is a function that processes a reply message
type validatorFunc func(distPeer, *Msg) error
-// peerSelector receives feedback info about response times and timeouts
-type peerSelector interface {
- adjustResponseTime(*poolEntry, time.Duration, bool)
-}
-
// sentReq represents a request sent and tracked by retrieveManager
type sentReq struct {
rm *retrieveManager
@@ -99,12 +92,12 @@ const (
)
// newRetrieveManager creates the retrieve manager
-func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, serverPool peerSelector) *retrieveManager {
+func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, srto func() time.Duration) *retrieveManager {
return &retrieveManager{
- peers: peers,
- dist: dist,
- serverPool: serverPool,
- sentReqs: make(map[uint64]*sentReq),
+ peers: peers,
+ dist: dist,
+ sentReqs: make(map[uint64]*sentReq),
+ softRequestTimeout: srto,
}
}
@@ -325,8 +318,7 @@ func (r *sentReq) tryRequest() {
return
}
- reqSent := mclock.Now()
- srto, hrto := false, false
+ hrto := false
r.lock.RLock()
s, ok := r.sentTo[p]
@@ -338,11 +330,7 @@ func (r *sentReq) tryRequest() {
defer func() {
// send feedback to server pool and remove peer if hard timeout happened
pp, ok := p.(*serverPeer)
- if ok && r.rm.serverPool != nil {
- respTime := time.Duration(mclock.Now() - reqSent)
- r.rm.serverPool.adjustResponseTime(pp.poolEntry, respTime, srto)
- }
- if hrto {
+ if hrto && ok {
pp.Log().Debug("Request timed out hard")
if r.rm.peers != nil {
r.rm.peers.unregister(pp.id)
@@ -363,8 +351,7 @@ func (r *sentReq) tryRequest() {
}
r.eventsCh <- reqPeerEvent{event, p}
return
- case <-time.After(softRequestTimeout):
- srto = true
+ case <-time.After(r.rm.softRequestTimeout()):
r.eventsCh <- reqPeerEvent{rpSoftTimeout, p}
}
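Since the soft timeout is now injected as a closure, tests or alternative pools can plug in a constant. A minimal sketch (peers and dist are assumed to exist):

// Sketch: a retrieve manager with a fixed 500ms soft timeout instead of serverPool.getTimeout.
rm := newRetrieveManager(peers, dist, func() time.Duration {
	return 500 * time.Millisecond
})
_ = rm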
diff --git a/les/server.go b/les/server.go
index f72f31321..4b623f61e 100644
--- a/les/server.go
+++ b/les/server.go
@@ -157,7 +157,7 @@ func (s *LesServer) Protocols() []p2p.Protocol {
return p.Info()
}
return nil
- })
+ }, nil)
// Add "les" ENR entries.
for i := range ps {
ps[i].Attributes = []enr.Entry{&lesEntry{}}
diff --git a/les/serverpool.go b/les/serverpool.go
index d1c53295a..aff774324 100644
--- a/les/serverpool.go
+++ b/les/serverpool.go
@@ -1,4 +1,4 @@
-// Copyright 2016 The go-ethereum Authors
+// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
@@ -17,904 +17,457 @@
package les
import (
- "crypto/ecdsa"
- "fmt"
- "io"
- "math"
+ "errors"
"math/rand"
- "net"
- "strconv"
+ "reflect"
"sync"
+ "sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
- "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
+ lpc "github.com/ethereum/go-ethereum/les/lespay/client"
"github.com/ethereum/go-ethereum/les/utils"
"github.com/ethereum/go-ethereum/log"
- "github.com/ethereum/go-ethereum/p2p"
- "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/p2p/nodestate"
"github.com/ethereum/go-ethereum/rlp"
)
const (
- // After a connection has been ended or timed out, there is a waiting period
- // before it can be selected for connection again.
- // waiting period = base delay * (1 + random(1))
- // base delay = shortRetryDelay for the first shortRetryCnt times after a
- // successful connection, after that longRetryDelay is applied
- shortRetryCnt = 5
- shortRetryDelay = time.Second * 5
- longRetryDelay = time.Minute * 10
- // maxNewEntries is the maximum number of newly discovered (never connected) nodes.
- // If the limit is reached, the least recently discovered one is thrown out.
- maxNewEntries = 1000
- // maxKnownEntries is the maximum number of known (already connected) nodes.
- // If the limit is reached, the least recently connected one is thrown out.
- // (not that unlike new entries, known entries are persistent)
- maxKnownEntries = 1000
- // target for simultaneously connected servers
- targetServerCount = 5
- // target for servers selected from the known table
- // (we leave room for trying new ones if there is any)
- targetKnownSelect = 3
- // after dialTimeout, consider the server unavailable and adjust statistics
- dialTimeout = time.Second * 30
- // targetConnTime is the minimum expected connection duration before a server
- // drops a client without any specific reason
- targetConnTime = time.Minute * 10
- // new entry selection weight calculation based on most recent discovery time:
- // unity until discoverExpireStart, then exponential decay with discoverExpireConst
- discoverExpireStart = time.Minute * 20
- discoverExpireConst = time.Minute * 20
- // known entry selection weight is dropped by a factor of exp(-failDropLn) after
- // each unsuccessful connection (restored after a successful one)
- failDropLn = 0.1
- // known node connection success and quality statistics have a long term average
- // and a short term value which is adjusted exponentially with a factor of
- // pstatRecentAdjust with each dial/connection and also returned exponentially
- // to the average with the time constant pstatReturnToMeanTC
- pstatReturnToMeanTC = time.Hour
- // node address selection weight is dropped by a factor of exp(-addrFailDropLn) after
- // each unsuccessful connection (restored after a successful one)
- addrFailDropLn = math.Ln2
- // responseScoreTC and delayScoreTC are exponential decay time constants for
- // calculating selection chances from response times and block delay times
- responseScoreTC = time.Millisecond * 100
- delayScoreTC = time.Second * 5
- timeoutPow = 10
- // initStatsWeight is used to initialize previously unknown peers with good
- // statistics to give a chance to prove themselves
- initStatsWeight = 1
+ minTimeout = time.Millisecond * 500 // minimum request timeout suggested by the server pool
+ timeoutRefresh = time.Second * 5 // recalculate timeout if older than this
+ dialCost = 10000 // cost of a TCP dial (used for known node selection weight calculation)
+ dialWaitStep = 1.5 // exponential multiplier of redial wait time when no value was provided by the server
+ queryCost = 500 // cost of a UDP pre-negotiation query
+ queryWaitStep = 1.02 // exponential multiplier of redial wait time when no value was provided by the server
+ waitThreshold = time.Hour * 2000 // drop node if waiting time is over the threshold
+ nodeWeightMul = 1000000 // multiplier constant for node weight calculation
+ nodeWeightThreshold = 100 // minimum weight for keeping a node in the known (valuable) set
+ minRedialWait = 10 // minimum redial wait time in seconds
+ preNegLimit = 5 // maximum number of simultaneous pre-negotiation queries
+ maxQueryFails = 100 // number of consecutive UDP query failures before we print a warning
)
-// connReq represents a request for peer connection.
-type connReq struct {
- p *serverPeer
- node *enode.Node
- result chan *poolEntry
-}
-
-// disconnReq represents a request for peer disconnection.
-type disconnReq struct {
- entry *poolEntry
- stopped bool
- done chan struct{}
-}
-
-// registerReq represents a request for peer registration.
-type registerReq struct {
- entry *poolEntry
- done chan struct{}
-}
-
-// serverPool implements a pool for storing and selecting newly discovered and already
-// known light server nodes. It received discovered nodes, stores statistics about
-// known nodes and takes care of always having enough good quality servers connected.
+// serverPool provides a node iterator for dial candidates. The output is a mix of newly discovered
+// nodes, a weighted random selection of known (previously valuable) nodes and trusted/paid nodes.
type serverPool struct {
- db ethdb.Database
- dbKey []byte
- server *p2p.Server
- connWg sync.WaitGroup
+ clock mclock.Clock
+ unixTime func() int64
+ db ethdb.KeyValueStore
- topic discv5.Topic
+ ns *nodestate.NodeStateMachine
+ vt *lpc.ValueTracker
+ mixer *enode.FairMix
+ mixSources []enode.Iterator
+ dialIterator enode.Iterator
+ validSchemes enr.IdentityScheme
+ trustedURLs []string
+ fillSet *lpc.FillSet
+ queryFails uint32
- discSetPeriod chan time.Duration
- discNodes chan *enode.Node
- discLookups chan bool
-
- trustedNodes map[enode.ID]*enode.Node
- entries map[enode.ID]*poolEntry
- timeout, enableRetry chan *poolEntry
- adjustStats chan poolStatAdjust
-
- knownQueue, newQueue poolEntryQueue
- knownSelect, newSelect *utils.WeightedRandomSelect
- knownSelected, newSelected int
- fastDiscover bool
- connCh chan *connReq
- disconnCh chan *disconnReq
- registerCh chan *registerReq
-
- closeCh chan struct{}
- wg sync.WaitGroup
+ timeoutLock sync.RWMutex
+ timeout time.Duration
+ timeWeights lpc.ResponseTimeWeights
+ timeoutRefreshed mclock.AbsTime
}
-// newServerPool creates a new serverPool instance
-func newServerPool(db ethdb.Database, ulcServers []string) *serverPool {
- pool := &serverPool{
+// nodeHistory keeps track of dial costs which determine node weight together with the
+// service value calculated by lpc.ValueTracker.
+type nodeHistory struct {
+ dialCost utils.ExpiredValue
+ redialWaitStart, redialWaitEnd int64 // unix time (seconds)
+}
+
+type nodeHistoryEnc struct {
+ DialCost utils.ExpiredValue
+ RedialWaitStart, RedialWaitEnd uint64
+}
+
+// queryFunc sends a pre-negotiation query and blocks until a response arrives or timeout occurs.
+// It returns 1 if the remote node has confirmed that connection is possible, 0 if not
+// possible and -1 if no response arrived (timeout).
+type queryFunc func(*enode.Node) int
+
+var (
+ serverPoolSetup = &nodestate.Setup{Version: 1}
+ sfHasValue = serverPoolSetup.NewPersistentFlag("hasValue")
+ sfQueried = serverPoolSetup.NewFlag("queried")
+ sfCanDial = serverPoolSetup.NewFlag("canDial")
+ sfDialing = serverPoolSetup.NewFlag("dialed")
+ sfWaitDialTimeout = serverPoolSetup.NewFlag("dialTimeout")
+ sfConnected = serverPoolSetup.NewFlag("connected")
+ sfRedialWait = serverPoolSetup.NewFlag("redialWait")
+ sfAlwaysConnect = serverPoolSetup.NewFlag("alwaysConnect")
+ sfDisableSelection = nodestate.MergeFlags(sfQueried, sfCanDial, sfDialing, sfConnected, sfRedialWait)
+
+ sfiNodeHistory = serverPoolSetup.NewPersistentField("nodeHistory", reflect.TypeOf(nodeHistory{}),
+ func(field interface{}) ([]byte, error) {
+ if n, ok := field.(nodeHistory); ok {
+ ne := nodeHistoryEnc{
+ DialCost: n.dialCost,
+ RedialWaitStart: uint64(n.redialWaitStart),
+ RedialWaitEnd: uint64(n.redialWaitEnd),
+ }
+ enc, err := rlp.EncodeToBytes(&ne)
+ return enc, err
+ } else {
+ return nil, errors.New("invalid field type")
+ }
+ },
+ func(enc []byte) (interface{}, error) {
+ var ne nodeHistoryEnc
+ err := rlp.DecodeBytes(enc, &ne)
+ n := nodeHistory{
+ dialCost: ne.DialCost,
+ redialWaitStart: int64(ne.RedialWaitStart),
+ redialWaitEnd: int64(ne.RedialWaitEnd),
+ }
+ return n, err
+ },
+ )
+ sfiNodeWeight = serverPoolSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0)))
+ sfiConnectedStats = serverPoolSetup.NewField("connectedStats", reflect.TypeOf(lpc.ResponseTimeStats{}))
+)
+
+// newServerPool creates a new server pool
+func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *lpc.ValueTracker, discovery enode.Iterator, mixTimeout time.Duration, query queryFunc, clock mclock.Clock, trustedURLs []string) *serverPool {
+ s := &serverPool{
db: db,
- entries: make(map[enode.ID]*poolEntry),
- timeout: make(chan *poolEntry, 1),
- adjustStats: make(chan poolStatAdjust, 100),
- enableRetry: make(chan *poolEntry, 1),
- connCh: make(chan *connReq),
- disconnCh: make(chan *disconnReq),
- registerCh: make(chan *registerReq),
- closeCh: make(chan struct{}),
- knownSelect: utils.NewWeightedRandomSelect(),
- newSelect: utils.NewWeightedRandomSelect(),
- fastDiscover: true,
- trustedNodes: parseTrustedNodes(ulcServers),
+ clock: clock,
+ unixTime: func() int64 { return time.Now().Unix() },
+ validSchemes: enode.ValidSchemes,
+ trustedURLs: trustedURLs,
+ vt: vt,
+ ns: nodestate.NewNodeStateMachine(db, []byte(string(dbKey)+"ns:"), clock, serverPoolSetup),
+ }
+ s.recalTimeout()
+ s.mixer = enode.NewFairMix(mixTimeout)
+ knownSelector := lpc.NewWrsIterator(s.ns, sfHasValue, sfDisableSelection, sfiNodeWeight)
+ alwaysConnect := lpc.NewQueueIterator(s.ns, sfAlwaysConnect, sfDisableSelection, true, nil)
+ s.mixSources = append(s.mixSources, knownSelector)
+ s.mixSources = append(s.mixSources, alwaysConnect)
+ if discovery != nil {
+ s.mixSources = append(s.mixSources, discovery)
}
- pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry)
- pool.newQueue = newPoolEntryQueue(maxNewEntries, pool.removeEntry)
- return pool
+ iter := enode.Iterator(s.mixer)
+ if query != nil {
+ iter = s.addPreNegFilter(iter, query)
+ }
+ s.dialIterator = enode.Filter(iter, func(node *enode.Node) bool {
+ s.ns.SetState(node, sfDialing, sfCanDial, 0)
+ s.ns.SetState(node, sfWaitDialTimeout, nodestate.Flags{}, time.Second*10)
+ return true
+ })
+
+ s.ns.SubscribeState(nodestate.MergeFlags(sfWaitDialTimeout, sfConnected), func(n *enode.Node, oldState, newState nodestate.Flags) {
+ if oldState.Equals(sfWaitDialTimeout) && newState.IsEmpty() {
+ // dial timeout, no connection
+ s.setRedialWait(n, dialCost, dialWaitStep)
+ s.ns.SetState(n, nodestate.Flags{}, sfDialing, 0)
+ }
+ })
+
+ s.ns.AddLogMetrics(sfHasValue, sfDisableSelection, "selectable", nil, nil, serverSelectableGauge)
+ s.ns.AddLogMetrics(sfDialing, nodestate.Flags{}, "dialed", serverDialedMeter, nil, nil)
+ s.ns.AddLogMetrics(sfConnected, nodestate.Flags{}, "connected", nil, nil, serverConnectedGauge)
+ return s
}
-func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
- pool.server = server
- pool.topic = topic
- pool.dbKey = append([]byte("serverPool/"), []byte(topic)...)
- pool.loadNodes()
- pool.connectToTrustedNodes()
-
- if pool.server.DiscV5 != nil {
- pool.discSetPeriod = make(chan time.Duration, 1)
- pool.discNodes = make(chan *enode.Node, 100)
- pool.discLookups = make(chan bool, 100)
- go pool.discoverNodes()
- }
- pool.checkDial()
- pool.wg.Add(1)
- go pool.eventLoop()
-
- // Inject the bootstrap nodes as initial dial candiates.
- pool.wg.Add(1)
- go func() {
- defer pool.wg.Done()
- for _, n := range server.BootstrapNodes {
- select {
- case pool.discNodes <- n:
- case <-pool.closeCh:
+// addPreNegFilter installs a node filter mechanism that performs a pre-negotiation query.
+// Nodes that are filtered out and do not appear on the output iterator are put back
+// into redialWait state.
+func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enode.Iterator {
+ s.fillSet = lpc.NewFillSet(s.ns, input, sfQueried)
+ s.ns.SubscribeState(sfQueried, func(n *enode.Node, oldState, newState nodestate.Flags) {
+ if newState.Equals(sfQueried) {
+ fails := atomic.LoadUint32(&s.queryFails)
+ if fails == maxQueryFails {
+ log.Warn("UDP pre-negotiation query does not seem to work")
+ }
+ if fails > maxQueryFails {
+ fails = maxQueryFails
+ }
+ if rand.Intn(maxQueryFails*2) < int(fails) {
+ // skip pre-negotiation with increasing chance, max 50%
+ // this ensures that the client can operate even if UDP is not working at all
+ s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10)
+ // set canDial before resetting queried so that FillSet will not read more
+ // candidates unnecessarily
+ s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0)
return
}
- }
- }()
-}
-
-func (pool *serverPool) stop() {
- close(pool.closeCh)
- pool.wg.Wait()
-}
-
-// discoverNodes wraps SearchTopic, converting result nodes to enode.Node.
-func (pool *serverPool) discoverNodes() {
- ch := make(chan *discv5.Node)
- go func() {
- pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, ch, pool.discLookups)
- close(ch)
- }()
- for n := range ch {
- pubkey, err := decodePubkey64(n.ID[:])
- if err != nil {
- continue
- }
- pool.discNodes <- enode.NewV4(pubkey, n.IP, int(n.TCP), int(n.UDP))
- }
-}
-
-// connect should be called upon any incoming connection. If the connection has been
-// dialed by the server pool recently, the appropriate pool entry is returned.
-// Otherwise, the connection should be rejected.
-// Note that whenever a connection has been accepted and a pool entry has been returned,
-// disconnect should also always be called.
-func (pool *serverPool) connect(p *serverPeer, node *enode.Node) *poolEntry {
- log.Debug("Connect new entry", "enode", p.id)
- req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)}
- select {
- case pool.connCh <- req:
- case <-pool.closeCh:
- return nil
- }
- return <-req.result
-}
-
-// registered should be called after a successful handshake
-func (pool *serverPool) registered(entry *poolEntry) {
- log.Debug("Registered new entry", "enode", entry.node.ID())
-	req := &registerReq{entry: entry, done: make(chan struct{})}
- select {
- case pool.registerCh <- req:
- case <-pool.closeCh:
- return
- }
- <-req.done
-}
-
-// disconnect should be called when ending a connection. Service quality statistics
-// can be updated optionally (not updated if no registration happened, in this case
-// only connection statistics are updated, just like in case of timeout)
-func (pool *serverPool) disconnect(entry *poolEntry) {
- stopped := false
- select {
- case <-pool.closeCh:
- stopped = true
- default:
- }
- log.Debug("Disconnected old entry", "enode", entry.node.ID())
- req := &disconnReq{entry: entry, stopped: stopped, done: make(chan struct{})}
-
- // Block until disconnection request is served.
- pool.disconnCh <- req
- <-req.done
-}
-
-const (
- pseBlockDelay = iota
- pseResponseTime
- pseResponseTimeout
-)
-
-// poolStatAdjust records are sent to adjust peer block delay/response time statistics
-type poolStatAdjust struct {
- adjustType int
- entry *poolEntry
- time time.Duration
-}
-
-// adjustBlockDelay adjusts the block announce delay statistics of a node
-func (pool *serverPool) adjustBlockDelay(entry *poolEntry, time time.Duration) {
- if entry == nil {
- return
- }
- pool.adjustStats <- poolStatAdjust{pseBlockDelay, entry, time}
-}
-
-// adjustResponseTime adjusts the request response time statistics of a node
-func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration, timeout bool) {
- if entry == nil {
- return
- }
- if timeout {
- pool.adjustStats <- poolStatAdjust{pseResponseTimeout, entry, time}
- } else {
- pool.adjustStats <- poolStatAdjust{pseResponseTime, entry, time}
- }
-}
-
-// eventLoop handles pool events and mutex locking for all internal functions
-func (pool *serverPool) eventLoop() {
- defer pool.wg.Done()
- lookupCnt := 0
- var convTime mclock.AbsTime
- if pool.discSetPeriod != nil {
- pool.discSetPeriod <- time.Millisecond * 100
- }
-
- // disconnect updates service quality statistics depending on the connection time
- // and disconnection initiator.
- disconnect := func(req *disconnReq, stopped bool) {
- // Handle peer disconnection requests.
- entry := req.entry
- if entry.state == psRegistered {
- connAdjust := float64(mclock.Now()-entry.regTime) / float64(targetConnTime)
- if connAdjust > 1 {
- connAdjust = 1
- }
- if stopped {
- // disconnect requested by ourselves.
- entry.connectStats.add(1, connAdjust)
- } else {
- // disconnect requested by server side.
- entry.connectStats.add(connAdjust, 1)
- }
- }
- entry.state = psNotConnected
-
- if entry.knownSelected {
- pool.knownSelected--
- } else {
- pool.newSelected--
- }
- pool.setRetryDial(entry)
- pool.connWg.Done()
- close(req.done)
- }
-
- for {
- select {
- case entry := <-pool.timeout:
- if !entry.removed {
- pool.checkDialTimeout(entry)
- }
-
- case entry := <-pool.enableRetry:
- if !entry.removed {
- entry.delayedRetry = false
- pool.updateCheckDial(entry)
- }
-
- case adj := <-pool.adjustStats:
- switch adj.adjustType {
- case pseBlockDelay:
- adj.entry.delayStats.add(float64(adj.time), 1)
- case pseResponseTime:
- adj.entry.responseStats.add(float64(adj.time), 1)
- adj.entry.timeoutStats.add(0, 1)
- case pseResponseTimeout:
- adj.entry.timeoutStats.add(1, 1)
- }
-
- case node := <-pool.discNodes:
- if pool.trustedNodes[node.ID()] == nil {
- entry := pool.findOrNewNode(node)
- pool.updateCheckDial(entry)
- }
-
- case conv := <-pool.discLookups:
- if conv {
- if lookupCnt == 0 {
- convTime = mclock.Now()
- }
- lookupCnt++
- if pool.fastDiscover && (lookupCnt == 50 || time.Duration(mclock.Now()-convTime) > time.Minute) {
- pool.fastDiscover = false
- if pool.discSetPeriod != nil {
- pool.discSetPeriod <- time.Minute
- }
- }
- }
-
- case req := <-pool.connCh:
- if pool.trustedNodes[req.p.ID()] != nil {
- // ignore trusted nodes
- req.result <- &poolEntry{trusted: true}
- } else {
- // Handle peer connection requests.
- entry := pool.entries[req.p.ID()]
- if entry == nil {
- entry = pool.findOrNewNode(req.node)
- }
- if entry.state == psConnected || entry.state == psRegistered {
- req.result <- nil
- continue
- }
- pool.connWg.Add(1)
- entry.peer = req.p
- entry.state = psConnected
- addr := &poolEntryAddress{
- ip: req.node.IP(),
- port: uint16(req.node.TCP()),
- lastSeen: mclock.Now(),
- }
- entry.lastConnected = addr
- entry.addr = make(map[string]*poolEntryAddress)
- entry.addr[addr.strKey()] = addr
- entry.addrSelect = *utils.NewWeightedRandomSelect()
- entry.addrSelect.Update(addr)
- req.result <- entry
- }
-
- case req := <-pool.registerCh:
- if req.entry.trusted {
- continue
- }
- // Handle peer registration requests.
- entry := req.entry
- entry.state = psRegistered
- entry.regTime = mclock.Now()
- if !entry.known {
- pool.newQueue.remove(entry)
- entry.known = true
- }
- pool.knownQueue.setLatest(entry)
- entry.shortRetry = shortRetryCnt
- close(req.done)
-
- case req := <-pool.disconnCh:
- if req.entry.trusted {
- continue
- }
- // Handle peer disconnection requests.
- disconnect(req, req.stopped)
-
- case <-pool.closeCh:
- if pool.discSetPeriod != nil {
- close(pool.discSetPeriod)
- }
-
- // Spawn a goroutine to close the disconnCh after all connections are disconnected.
go func() {
- pool.connWg.Wait()
- close(pool.disconnCh)
+ q := query(n)
+ if q == -1 {
+ atomic.AddUint32(&s.queryFails, 1)
+ } else {
+ atomic.StoreUint32(&s.queryFails, 0)
+ }
+ if q == 1 {
+ s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10)
+ } else {
+ s.setRedialWait(n, queryCost, queryWaitStep)
+ }
+ s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0)
}()
-
- // Handle all remaining disconnection requests before exit.
- for req := range pool.disconnCh {
- disconnect(req, true)
- }
- pool.saveNodes()
- return
}
- }
-}
-
-func (pool *serverPool) findOrNewNode(node *enode.Node) *poolEntry {
- now := mclock.Now()
- entry := pool.entries[node.ID()]
- if entry == nil {
- log.Debug("Discovered new entry", "id", node.ID())
- entry = &poolEntry{
- node: node,
- addr: make(map[string]*poolEntryAddress),
- addrSelect: *utils.NewWeightedRandomSelect(),
- shortRetry: shortRetryCnt,
+ })
+ return lpc.NewQueueIterator(s.ns, sfCanDial, nodestate.Flags{}, false, func(waiting bool) {
+ if waiting {
+ s.fillSet.SetTarget(preNegLimit)
+ } else {
+ s.fillSet.SetTarget(0)
}
- pool.entries[node.ID()] = entry
- // initialize previously unknown peers with good statistics to give a chance to prove themselves
- entry.connectStats.add(1, initStatsWeight)
- entry.delayStats.add(0, initStatsWeight)
- entry.responseStats.add(0, initStatsWeight)
- entry.timeoutStats.add(0, initStatsWeight)
- }
- entry.lastDiscovered = now
- addr := &poolEntryAddress{ip: node.IP(), port: uint16(node.TCP())}
- if a, ok := entry.addr[addr.strKey()]; ok {
- addr = a
- } else {
- entry.addr[addr.strKey()] = addr
- }
- addr.lastSeen = now
- entry.addrSelect.Update(addr)
- if !entry.known {
- pool.newQueue.setLatest(entry)
- }
- return entry
-}
-
-// loadNodes loads known nodes and their statistics from the database
-func (pool *serverPool) loadNodes() {
- enc, err := pool.db.Get(pool.dbKey)
- if err != nil {
- return
- }
- var list []*poolEntry
- err = rlp.DecodeBytes(enc, &list)
- if err != nil {
- log.Debug("Failed to decode node list", "err", err)
- return
- }
- for _, e := range list {
- log.Debug("Loaded server stats", "id", e.node.ID(), "fails", e.lastConnected.fails,
- "conn", fmt.Sprintf("%v/%v", e.connectStats.avg, e.connectStats.weight),
- "delay", fmt.Sprintf("%v/%v", time.Duration(e.delayStats.avg), e.delayStats.weight),
- "response", fmt.Sprintf("%v/%v", time.Duration(e.responseStats.avg), e.responseStats.weight),
- "timeout", fmt.Sprintf("%v/%v", e.timeoutStats.avg, e.timeoutStats.weight))
- pool.entries[e.node.ID()] = e
- if pool.trustedNodes[e.node.ID()] == nil {
- pool.knownQueue.setLatest(e)
- pool.knownSelect.Update((*knownEntry)(e))
- }
- }
-}
-
-// connectToTrustedNodes adds trusted server nodes as static trusted peers.
-//
-// Note: trusted nodes are not handled by the server pool logic, they are not
-// added to either the known or new selection pools. They are connected/reconnected
-// by p2p.Server whenever possible.
-func (pool *serverPool) connectToTrustedNodes() {
- //connect to trusted nodes
- for _, node := range pool.trustedNodes {
- pool.server.AddTrustedPeer(node)
- pool.server.AddPeer(node)
- log.Debug("Added trusted node", "id", node.ID().String())
- }
-}
-
-// parseTrustedNodes returns valid and parsed enodes
-func parseTrustedNodes(trustedNodes []string) map[enode.ID]*enode.Node {
- nodes := make(map[enode.ID]*enode.Node)
-
- for _, node := range trustedNodes {
- node, err := enode.Parse(enode.ValidSchemes, node)
- if err != nil {
- log.Warn("Trusted node URL invalid", "enode", node, "err", err)
- continue
- }
- nodes[node.ID()] = node
- }
- return nodes
-}
-
-// saveNodes saves known nodes and their statistics into the database. Nodes are
-// ordered from least to most recently connected.
-func (pool *serverPool) saveNodes() {
- list := make([]*poolEntry, len(pool.knownQueue.queue))
- for i := range list {
- list[i] = pool.knownQueue.fetchOldest()
- }
- enc, err := rlp.EncodeToBytes(list)
- if err == nil {
- pool.db.Put(pool.dbKey, enc)
- }
-}
-
-// removeEntry removes a pool entry when the entry count limit is reached.
-// Note that it is called by the new/known queues from which the entry has already
-// been removed so removing it from the queues is not necessary.
-func (pool *serverPool) removeEntry(entry *poolEntry) {
- pool.newSelect.Remove((*discoveredEntry)(entry))
- pool.knownSelect.Remove((*knownEntry)(entry))
- entry.removed = true
- delete(pool.entries, entry.node.ID())
-}
-
-// setRetryDial starts the timer which will enable dialing a certain node again
-func (pool *serverPool) setRetryDial(entry *poolEntry) {
- delay := longRetryDelay
- if entry.shortRetry > 0 {
- entry.shortRetry--
- delay = shortRetryDelay
- }
- delay += time.Duration(rand.Int63n(int64(delay) + 1))
- entry.delayedRetry = true
- go func() {
- select {
- case <-pool.closeCh:
- case <-time.After(delay):
- select {
- case <-pool.closeCh:
- case pool.enableRetry <- entry:
- }
- }
- }()
-}
-
-// updateCheckDial is called when an entry can potentially be dialed again. It updates
-// its selection weights and checks if new dials can/should be made.
-func (pool *serverPool) updateCheckDial(entry *poolEntry) {
- pool.newSelect.Update((*discoveredEntry)(entry))
- pool.knownSelect.Update((*knownEntry)(entry))
- pool.checkDial()
-}
-
-// checkDial checks if new dials can/should be made. It tries to select servers both
-// based on good statistics and recent discovery.
-func (pool *serverPool) checkDial() {
- fillWithKnownSelects := !pool.fastDiscover
- for pool.knownSelected < targetKnownSelect {
- entry := pool.knownSelect.Choose()
- if entry == nil {
- fillWithKnownSelects = false
- break
- }
- pool.dial((*poolEntry)(entry.(*knownEntry)), true)
- }
- for pool.knownSelected+pool.newSelected < targetServerCount {
- entry := pool.newSelect.Choose()
- if entry == nil {
- break
- }
- pool.dial((*poolEntry)(entry.(*discoveredEntry)), false)
- }
- if fillWithKnownSelects {
- // no more newly discovered nodes to select and since fast discover period
- // is over, we probably won't find more in the near future so select more
- // known entries if possible
- for pool.knownSelected < targetServerCount {
- entry := pool.knownSelect.Choose()
- if entry == nil {
- break
- }
- pool.dial((*poolEntry)(entry.(*knownEntry)), true)
- }
- }
-}
-
-// dial initiates a new connection
-func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) {
- if pool.server == nil || entry.state != psNotConnected {
- return
- }
- entry.state = psDialed
- entry.knownSelected = knownSelected
- if knownSelected {
- pool.knownSelected++
- } else {
- pool.newSelected++
- }
- addr := entry.addrSelect.Choose().(*poolEntryAddress)
- log.Debug("Dialing new peer", "lesaddr", entry.node.ID().String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected)
- entry.dialed = addr
- go func() {
- pool.server.AddPeer(entry.node)
- select {
- case <-pool.closeCh:
- case <-time.After(dialTimeout):
- select {
- case <-pool.closeCh:
- case pool.timeout <- entry:
- }
- }
- }()
-}
-
-// checkDialTimeout checks if the node is still in dialed state and if so, resets it
-// and adjusts connection statistics accordingly.
-func (pool *serverPool) checkDialTimeout(entry *poolEntry) {
- if entry.state != psDialed {
- return
- }
- log.Debug("Dial timeout", "lesaddr", entry.node.ID().String()+"@"+entry.dialed.strKey())
- entry.state = psNotConnected
- if entry.knownSelected {
- pool.knownSelected--
- } else {
- pool.newSelected--
- }
- entry.connectStats.add(0, 1)
- entry.dialed.fails++
- pool.setRetryDial(entry)
-}
-
-const (
- psNotConnected = iota
- psDialed
- psConnected
- psRegistered
-)
-
-// poolEntry represents a server node and stores its current state and statistics.
-type poolEntry struct {
- peer *serverPeer
- pubkey [64]byte // secp256k1 key of the node
- addr map[string]*poolEntryAddress
- node *enode.Node
- lastConnected, dialed *poolEntryAddress
- addrSelect utils.WeightedRandomSelect
-
- lastDiscovered mclock.AbsTime
- known, knownSelected, trusted bool
- connectStats, delayStats poolStats
- responseStats, timeoutStats poolStats
- state int
- regTime mclock.AbsTime
- queueIdx int
- removed bool
-
- delayedRetry bool
- shortRetry int
-}
-
-// poolEntryEnc is the RLP encoding of poolEntry.
-type poolEntryEnc struct {
- Pubkey []byte
- IP net.IP
- Port uint16
- Fails uint
- CStat, DStat, RStat, TStat poolStats
-}
-
-func (e *poolEntry) EncodeRLP(w io.Writer) error {
- return rlp.Encode(w, &poolEntryEnc{
- Pubkey: encodePubkey64(e.node.Pubkey()),
- IP: e.lastConnected.ip,
- Port: e.lastConnected.port,
- Fails: e.lastConnected.fails,
- CStat: e.connectStats,
- DStat: e.delayStats,
- RStat: e.responseStats,
- TStat: e.timeoutStats,
})
}
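
A small sketch of the skip probability used in the callback above: after a run of failed UDP queries, the pre-negotiation step is skipped with probability fails/(2*maxQueryFails), capped at 50%, so the client can still dial when UDP is completely broken. The maxQueryFails value below is an illustrative stand-in.

package main

import (
	"fmt"
	"math/rand"
)

const maxQueryFails = 100 // illustrative cap on counted failures

// skipQuery reports whether the pre-negotiation query should be skipped,
// given the number of consecutive query failures seen so far.
func skipQuery(fails uint32) bool {
	if fails > maxQueryFails {
		fails = maxQueryFails
	}
	return rand.Intn(maxQueryFails*2) < int(fails)
}

func main() {
	for _, f := range []uint32{0, 50, 100, 500} {
		skipped := 0
		for i := 0; i < 10000; i++ {
			if skipQuery(f) {
				skipped++
			}
		}
		fmt.Printf("fails=%3d -> skipped ~%d%% of queries\n", f, skipped/100)
	}
}
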
-func (e *poolEntry) DecodeRLP(s *rlp.Stream) error {
- var entry poolEntryEnc
- if err := s.Decode(&entry); err != nil {
- return err
+// start starts the server pool. Note that NodeStateMachine should be started first.
+func (s *serverPool) start() {
+ s.ns.Start()
+ for _, iter := range s.mixSources {
+ // add sources to mixer at startup because the mixer instantly tries to read them
+ // which should only happen after NodeStateMachine has been started
+ s.mixer.AddSource(iter)
}
- pubkey, err := decodePubkey64(entry.Pubkey)
- if err != nil {
- return err
- }
- addr := &poolEntryAddress{ip: entry.IP, port: entry.Port, fails: entry.Fails, lastSeen: mclock.Now()}
- e.node = enode.NewV4(pubkey, entry.IP, int(entry.Port), int(entry.Port))
- e.addr = make(map[string]*poolEntryAddress)
- e.addr[addr.strKey()] = addr
- e.addrSelect = *utils.NewWeightedRandomSelect()
- e.addrSelect.Update(addr)
- e.lastConnected = addr
- e.connectStats = entry.CStat
- e.delayStats = entry.DStat
- e.responseStats = entry.RStat
- e.timeoutStats = entry.TStat
- e.shortRetry = shortRetryCnt
- e.known = true
- return nil
-}
-
-func encodePubkey64(pub *ecdsa.PublicKey) []byte {
- return crypto.FromECDSAPub(pub)[1:]
-}
-
-func decodePubkey64(b []byte) (*ecdsa.PublicKey, error) {
- return crypto.UnmarshalPubkey(append([]byte{0x04}, b...))
-}
-
-// discoveredEntry implements wrsItem
-type discoveredEntry poolEntry
-
-// Weight calculates random selection weight for newly discovered entries
-func (e *discoveredEntry) Weight() int64 {
- if e.state != psNotConnected || e.delayedRetry {
- return 0
- }
- t := time.Duration(mclock.Now() - e.lastDiscovered)
- if t <= discoverExpireStart {
- return 1000000000
- }
- return int64(1000000000 * math.Exp(-float64(t-discoverExpireStart)/float64(discoverExpireConst)))
-}
-
-// knownEntry implements wrsItem
-type knownEntry poolEntry
-
-// Weight calculates random selection weight for known entries
-func (e *knownEntry) Weight() int64 {
- if e.state != psNotConnected || !e.known || e.delayedRetry {
- return 0
- }
- return int64(1000000000 * e.connectStats.recentAvg() * math.Exp(-float64(e.lastConnected.fails)*failDropLn-e.responseStats.recentAvg()/float64(responseScoreTC)-e.delayStats.recentAvg()/float64(delayScoreTC)) * math.Pow(1-e.timeoutStats.recentAvg(), timeoutPow))
-}
-
-// poolEntryAddress is a separate object because currently it is necessary to remember
-// multiple potential network addresses for a pool entry. This will be removed after
-// the final implementation of v5 discovery which will retrieve signed and serial
-// numbered advertisements, making it clear which IP/port is the latest one.
-type poolEntryAddress struct {
- ip net.IP
- port uint16
- lastSeen mclock.AbsTime // last time it was discovered, connected or loaded from db
- fails uint // connection failures since last successful connection (persistent)
-}
-
-func (a *poolEntryAddress) Weight() int64 {
- t := time.Duration(mclock.Now() - a.lastSeen)
- return int64(1000000*math.Exp(-float64(t)/float64(discoverExpireConst)-float64(a.fails)*addrFailDropLn)) + 1
-}
-
-func (a *poolEntryAddress) strKey() string {
- return a.ip.String() + ":" + strconv.Itoa(int(a.port))
-}
-
-// poolStats implement statistics for a certain quantity with a long term average
-// and a short term value which is adjusted exponentially with a factor of
-// pstatRecentAdjust with each update and also returned exponentially to the
-// average with the time constant pstatReturnToMeanTC
-type poolStats struct {
- sum, weight, avg, recent float64
- lastRecalc mclock.AbsTime
-}
-
-// init initializes stats with a long term sum/update count pair retrieved from the database
-func (s *poolStats) init(sum, weight float64) {
- s.sum = sum
- s.weight = weight
- var avg float64
- if weight > 0 {
- avg = s.sum / weight
- }
- s.avg = avg
- s.recent = avg
- s.lastRecalc = mclock.Now()
-}
-
-// recalc recalculates recent value return-to-mean and long term average
-func (s *poolStats) recalc() {
- now := mclock.Now()
- s.recent = s.avg + (s.recent-s.avg)*math.Exp(-float64(now-s.lastRecalc)/float64(pstatReturnToMeanTC))
- if s.sum == 0 {
- s.avg = 0
- } else {
- if s.sum > s.weight*1e30 {
- s.avg = 1e30
+ for _, url := range s.trustedURLs {
+ if node, err := enode.Parse(s.validSchemes, url); err == nil {
+ s.ns.SetState(node, sfAlwaysConnect, nodestate.Flags{}, 0)
} else {
- s.avg = s.sum / s.weight
+ log.Error("Invalid trusted server URL", "url", url, "error", err)
}
}
- s.lastRecalc = now
-}
-
-// add updates the stats with a new value
-func (s *poolStats) add(value, weight float64) {
- s.weight += weight
- s.sum += value * weight
- s.recalc()
-}
-
-// recentAvg returns the short-term adjusted average
-func (s *poolStats) recentAvg() float64 {
- s.recalc()
- return s.recent
-}
-
-func (s *poolStats) EncodeRLP(w io.Writer) error {
- return rlp.Encode(w, []interface{}{math.Float64bits(s.sum), math.Float64bits(s.weight)})
-}
-
-func (s *poolStats) DecodeRLP(st *rlp.Stream) error {
- var stats struct {
- SumUint, WeightUint uint64
- }
- if err := st.Decode(&stats); err != nil {
- return err
- }
- s.init(math.Float64frombits(stats.SumUint), math.Float64frombits(stats.WeightUint))
- return nil
-}
-
-// poolEntryQueue keeps track of its least recently accessed entries and removes
-// them when the number of entries reaches the limit
-type poolEntryQueue struct {
- queue map[int]*poolEntry // known nodes indexed by their latest lastConnCnt value
- newPtr, oldPtr, maxCnt int
- removeFromPool func(*poolEntry)
-}
-
-// newPoolEntryQueue returns a new poolEntryQueue
-func newPoolEntryQueue(maxCnt int, removeFromPool func(*poolEntry)) poolEntryQueue {
- return poolEntryQueue{queue: make(map[int]*poolEntry), maxCnt: maxCnt, removeFromPool: removeFromPool}
-}
-
-// fetchOldest returns and removes the least recently accessed entry
-func (q *poolEntryQueue) fetchOldest() *poolEntry {
- if len(q.queue) == 0 {
- return nil
- }
- for {
- if e := q.queue[q.oldPtr]; e != nil {
- delete(q.queue, q.oldPtr)
- q.oldPtr++
- return e
+ unixTime := s.unixTime()
+ s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
+ s.calculateWeight(node)
+ if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime {
+ wait := n.redialWaitEnd - unixTime
+ lastWait := n.redialWaitEnd - n.redialWaitStart
+ if wait > lastWait {
+ // if the time until expiration is larger than the last suggested
+ // waiting time then the system clock was probably adjusted
+ wait = lastWait
+ }
+ s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second)
}
- q.oldPtr++
- }
+ })
}
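
The waiting-time clamp applied in the loop above can be expressed as a small helper; a sketch with illustrative Unix timestamps:

package main

import "fmt"

// remainingWait returns how long a node should still stay in redialWait,
// clamping against backwards clock adjustments (all values in seconds).
func remainingWait(now, waitStart, waitEnd int64) int64 {
	if now >= waitEnd {
		return 0
	}
	wait := waitEnd - now
	if last := waitEnd - waitStart; wait > last {
		wait = last // expiry is further away than the planned wait: clock moved back
	}
	return wait
}

func main() {
	fmt.Println(remainingWait(1000, 900, 1100)) // 100: normal case, 100s of a 200s wait left
	fmt.Println(remainingWait(500, 900, 1100))  // 200: clamped to the planned 200s wait
	fmt.Println(remainingWait(2000, 900, 1100)) // 0: wait already expired
}
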
-// remove removes an entry from the queue
-func (q *poolEntryQueue) remove(entry *poolEntry) {
- if q.queue[entry.queueIdx] == entry {
- delete(q.queue, entry.queueIdx)
+// stop stops the server pool
+func (s *serverPool) stop() {
+ s.dialIterator.Close()
+ if s.fillSet != nil {
+ s.fillSet.Close()
}
+ s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) {
+ // recalculate weight of connected nodes in order to update hasValue flag if necessary
+ s.calculateWeight(n)
+ })
+ s.ns.Stop()
}
-// setLatest adds or updates a recently accessed entry. It also checks if an old entry
-// needs to be removed and removes it from the parent pool too with a callback function.
-func (q *poolEntryQueue) setLatest(entry *poolEntry) {
- if q.queue[entry.queueIdx] == entry {
- delete(q.queue, entry.queueIdx)
+// registerPeer implements serverPeerSubscriber
+func (s *serverPool) registerPeer(p *serverPeer) {
+ s.ns.SetState(p.Node(), sfConnected, sfDialing.Or(sfWaitDialTimeout), 0)
+ nvt := s.vt.Register(p.ID())
+ s.ns.SetField(p.Node(), sfiConnectedStats, nvt.RtStats())
+ p.setValueTracker(s.vt, nvt)
+ p.updateVtParams()
+}
+
+// unregisterPeer implements serverPeerSubscriber
+func (s *serverPool) unregisterPeer(p *serverPeer) {
+ s.setRedialWait(p.Node(), dialCost, dialWaitStep)
+ s.ns.SetState(p.Node(), nodestate.Flags{}, sfConnected, 0)
+ s.ns.SetField(p.Node(), sfiConnectedStats, nil)
+ s.vt.Unregister(p.ID())
+ p.setValueTracker(nil, nil)
+}
+
+// recalTimeout calculates the current recommended timeout. This value is used by
+// the client as a "soft timeout" value. It also affects the service value calculation
+// of individual nodes.
+func (s *serverPool) recalTimeout() {
+ // Use cached result if possible, avoid recalculating too frequently.
+ s.timeoutLock.RLock()
+ refreshed := s.timeoutRefreshed
+ s.timeoutLock.RUnlock()
+ now := s.clock.Now()
+ if refreshed != 0 && time.Duration(now-refreshed) < timeoutRefresh {
+ return
+ }
+ // Cached result is stale, recalculate a new one.
+ rts := s.vt.RtStats()
+
+ // Add a fake statistic here. It is an easy way to initialize with some
+ // conservative values when the database is new. As soon as we have a
+ // considerable amount of real stats this small value won't matter.
+ rts.Add(time.Second*2, 10, s.vt.StatsExpFactor())
+
+ // Use either 10% failure rate timeout or twice the median response time
+ // as the recommended timeout.
+ timeout := minTimeout
+ if t := rts.Timeout(0.1); t > timeout {
+ timeout = t
+ }
+ if t := rts.Timeout(0.5) * 2; t > timeout {
+ timeout = t
+ }
+ s.timeoutLock.Lock()
+ if s.timeout != timeout {
+ s.timeout = timeout
+ s.timeWeights = lpc.TimeoutWeights(s.timeout)
+
+ suggestedTimeoutGauge.Update(int64(s.timeout / time.Millisecond))
+ totalValueGauge.Update(int64(rts.Value(s.timeWeights, s.vt.StatsExpFactor())))
+ }
+ s.timeoutRefreshed = now
+ s.timeoutLock.Unlock()
+}
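
A sketch of the selection rule above, with minTimeout and the two quantile results standing in for the real lpc.ResponseTimeStats values:

package main

import (
	"fmt"
	"time"
)

const minTimeout = 500 * time.Millisecond // illustrative lower bound

// recommendedTimeout picks the largest of the minimum timeout, the timeout
// with an expected 10% failure rate, and twice the median response time.
func recommendedTimeout(tenPercentFail, median time.Duration) time.Duration {
	timeout := minTimeout
	if tenPercentFail > timeout {
		timeout = tenPercentFail
	}
	if t := 2 * median; t > timeout {
		timeout = t
	}
	return timeout
}

func main() {
	fmt.Println(recommendedTimeout(800*time.Millisecond, 300*time.Millisecond)) // 800ms
	fmt.Println(recommendedTimeout(400*time.Millisecond, 600*time.Millisecond)) // 1.2s
	fmt.Println(recommendedTimeout(100*time.Millisecond, 100*time.Millisecond)) // 500ms floor
}
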
+
+// getTimeout returns the recommended request timeout.
+func (s *serverPool) getTimeout() time.Duration {
+ s.recalTimeout()
+ s.timeoutLock.RLock()
+ defer s.timeoutLock.RUnlock()
+ return s.timeout
+}
+
+// getTimeoutAndWeight returns the recommended request timeout as well as the
+// response time weight which is necessary to calculate service value.
+func (s *serverPool) getTimeoutAndWeight() (time.Duration, lpc.ResponseTimeWeights) {
+ s.recalTimeout()
+ s.timeoutLock.RLock()
+ defer s.timeoutLock.RUnlock()
+ return s.timeout, s.timeWeights
+}
+
+// addDialCost adds the given amount of dial cost to the node history and returns the current
+// amount of total dial cost
+func (s *serverPool) addDialCost(n *nodeHistory, amount int64) uint64 {
+ logOffset := s.vt.StatsExpirer().LogOffset(s.clock.Now())
+ if amount > 0 {
+ n.dialCost.Add(amount, logOffset)
+ }
+ totalDialCost := n.dialCost.Value(logOffset)
+ if totalDialCost < dialCost {
+ totalDialCost = dialCost
+ }
+ return totalDialCost
+}
+
+// serviceValue returns the service value accumulated in this session and in total
+func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue float64) {
+ nvt := s.vt.GetNode(node.ID())
+ if nvt == nil {
+ return 0, 0
+ }
+ currentStats := nvt.RtStats()
+ _, timeWeights := s.getTimeoutAndWeight()
+ expFactor := s.vt.StatsExpFactor()
+
+ totalValue = currentStats.Value(timeWeights, expFactor)
+ if connStats, ok := s.ns.GetField(node, sfiConnectedStats).(lpc.ResponseTimeStats); ok {
+ diff := currentStats
+ diff.SubStats(&connStats)
+ sessionValue = diff.Value(timeWeights, expFactor)
+ sessionValueMeter.Mark(int64(sessionValue))
+ }
+ return
+}
+
+// updateWeight calculates the node weight and updates the nodeWeight field and the
+// hasValue flag. It also saves the node state if necessary.
+func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) {
+ weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost))
+ if weight >= nodeWeightThreshold {
+ s.ns.SetState(node, sfHasValue, nodestate.Flags{}, 0)
+ s.ns.SetField(node, sfiNodeWeight, weight)
} else {
- if len(q.queue) == q.maxCnt {
- e := q.fetchOldest()
- q.remove(e)
- q.removeFromPool(e)
- }
+ s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0)
+ s.ns.SetField(node, sfiNodeWeight, nil)
}
- entry.queueIdx = q.newPtr
- q.queue[entry.queueIdx] = entry
- q.newPtr++
+ s.ns.Persist(node) // saved if node history or hasValue changed
+}
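
A sketch of the weight formula and the hasValue threshold; nodeWeightMul and nodeWeightThreshold below are illustrative constants, not necessarily the values defined elsewhere in the package:

package main

import "fmt"

const (
	nodeWeightMul       = 1000000 // illustrative weight multiplier
	nodeWeightThreshold = 100     // illustrative minimum weight to keep hasValue
)

// nodeWeight returns the selection weight and whether the node keeps the
// hasValue flag (i.e. stays in the weighted random selection set).
func nodeWeight(totalValue float64, totalDialCost uint64) (uint64, bool) {
	weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost))
	return weight, weight >= nodeWeightThreshold
}

func main() {
	fmt.Println(nodeWeight(5, 1000))       // 5000 true: valuable relative to dial cost
	fmt.Println(nodeWeight(0.00005, 1000)) // 0 false: dropped from the selection set
}
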
+
+// setRedialWait calculates and sets the redialWait timeout based on the service value
+// and dial cost accumulated during the last session/attempt and in total.
+// The waiting time is raised exponentially if no service value has been received in order
+// to prevent dialing an unresponsive node frequently for a very long time just because it
+// was useful in the past. It can still be occasionally dialed though and once it provides
+// a significant amount of service value again its waiting time is quickly reduced or reset
+// to the minimum.
+// Note: node weight is also recalculated and updated by this function.
+func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) {
+ n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
+ sessionValue, totalValue := s.serviceValue(node)
+ totalDialCost := s.addDialCost(&n, addDialCost)
+
+ // if the current dial session has yielded at least the average value/dial cost ratio
+ // then the waiting time should be reset to the minimum. If the session value
+ // is below average but still positive then timeout is limited to the ratio of
+ // average / current service value multiplied by the minimum timeout. If the attempt
+ // was unsuccessful then timeout is raised exponentially without limitation.
+ // Note: dialCost is used in the formula below even if dial was not attempted at all
+ // because the pre-negotiation query did not return a positive result. In this case
+ // the ratio has no meaning anyway and waitFactor is always raised, though in smaller
+ // steps because queries are cheaper and therefore we can allow more failed attempts.
+ unixTime := s.unixTime()
+ plannedTimeout := float64(n.redialWaitEnd - n.redialWaitStart) // last planned redialWait timeout
+ var actualWait float64 // actual waiting time elapsed
+ if unixTime > n.redialWaitEnd {
+ // the planned timeout has elapsed
+ actualWait = plannedTimeout
+ } else {
+ // if the node was redialed earlier then we do not raise the planned timeout
+ // exponentially because that could lead to the timeout rising very high in
+ // a short amount of time
+ // Note that in case of an early redial actualWait also includes the dial
+ // timeout or connection time of the last attempt but it still serves its
+ // purpose of preventing the timeout rising quicker than linearly as a function
+ // of total time elapsed without a successful connection.
+ actualWait = float64(unixTime - n.redialWaitStart)
+ }
+ // raise timeout exponentially if the last planned timeout has elapsed
+ // (use at least the last planned timeout otherwise)
+ nextTimeout := actualWait * waitStep
+ if plannedTimeout > nextTimeout {
+ nextTimeout = plannedTimeout
+ }
+ // we reduce the waiting time if the server has provided service value during the
+ // connection (but never under the minimum)
+ a := totalValue * dialCost * float64(minRedialWait)
+ b := float64(totalDialCost) * sessionValue
+ if a < b*nextTimeout {
+ nextTimeout = a / b
+ }
+ if nextTimeout < minRedialWait {
+ nextTimeout = minRedialWait
+ }
+ wait := time.Duration(float64(time.Second) * nextTimeout)
+ if wait < waitThreshold {
+ n.redialWaitStart = unixTime
+ n.redialWaitEnd = unixTime + int64(nextTimeout)
+ s.ns.SetField(node, sfiNodeHistory, n)
+ s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, wait)
+ s.updateWeight(node, totalValue, totalDialCost)
+ } else {
+ // discard known node statistics if waiting time is very long because the node
+ // hasn't been responsive for a very long time
+ s.ns.SetField(node, sfiNodeHistory, nil)
+ s.ns.SetField(node, sfiNodeWeight, nil)
+ s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0)
+ }
+}
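
A worked sketch of the waiting-time arithmetic with made-up numbers: the wait grows by waitStep when a session yields no service value, and is pulled back towards the minimum in proportion to the session's value per dial cost. All constants here are illustrative.

package main

import "fmt"

const (
	minRedialWait = 10  // seconds, illustrative
	waitStep      = 2.0 // illustrative exponential step
)

// nextWait mirrors the shape of the timeout update: exponential growth on
// failure, reduction proportional to the session's value per dial cost.
func nextWait(actualWait, plannedWait, sessionValue, totalValue, totalDialCost, dialCost float64) float64 {
	next := actualWait * waitStep
	if plannedWait > next {
		next = plannedWait
	}
	// shrink the wait if the last session actually provided service value
	a := totalValue * dialCost * minRedialWait
	b := totalDialCost * sessionValue
	if b > 0 && a < b*next {
		next = a / b
	}
	if next < minRedialWait {
		next = minRedialWait
	}
	return next
}

func main() {
	fmt.Println(nextWait(40, 40, 0, 100, 50, 1))  // 80: no value, wait doubles
	fmt.Println(nextWait(40, 40, 50, 100, 50, 1)) // 10: valuable session, back to the minimum
}
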
+
+// calculateWeight calculates and sets the node weight without altering the node history.
+// This function should be called during startup and shutdown only, otherwise setRedialWait
+// will keep the weights updated as the underlying statistics are adjusted.
+func (s *serverPool) calculateWeight(node *enode.Node) {
+ n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
+ _, totalValue := s.serviceValue(node)
+ totalDialCost := s.addDialCost(&n, 0)
+ s.updateWeight(node, totalValue, totalDialCost)
}
diff --git a/les/serverpool_test.go b/les/serverpool_test.go
new file mode 100644
index 000000000..3d0487d10
--- /dev/null
+++ b/les/serverpool_test.go
@@ -0,0 +1,352 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package les
+
+import (
+ "math/rand"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/ethdb/memorydb"
+ lpc "github.com/ethereum/go-ethereum/les/lespay/client"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+)
+
+const (
+ spTestNodes = 1000
+ spTestTarget = 5
+ spTestLength = 10000
+ spMinTotal = 40000
+ spMaxTotal = 50000
+)
+
+func testNodeID(i int) enode.ID {
+ return enode.ID{42, byte(i % 256), byte(i / 256)}
+}
+
+func testNodeIndex(id enode.ID) int {
+ if id[0] != 42 {
+ return -1
+ }
+ return int(id[1]) + int(id[2])*256
+}
+
+type serverPoolTest struct {
+ db ethdb.KeyValueStore
+ clock *mclock.Simulated
+ quit chan struct{}
+ preNeg, preNegFail bool
+ vt *lpc.ValueTracker
+ sp *serverPool
+ input enode.Iterator
+ testNodes []spTestNode
+ trusted []string
+ waitCount, waitEnded int32
+
+ cycle, conn, servedConn int
+ serviceCycles, dialCount int
+ disconnect map[int][]int
+}
+
+type spTestNode struct {
+ connectCycles, waitCycles int
+ nextConnCycle, totalConn int
+ connected, service bool
+ peer *serverPeer
+}
+
+func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest {
+ nodes := make([]*enode.Node, spTestNodes)
+ for i := range nodes {
+ nodes[i] = enode.SignNull(&enr.Record{}, testNodeID(i))
+ }
+ return &serverPoolTest{
+ clock: &mclock.Simulated{},
+ db: memorydb.New(),
+ input: enode.CycleNodes(nodes),
+ testNodes: make([]spTestNode, spTestNodes),
+ preNeg: preNeg,
+ preNegFail: preNegFail,
+ }
+}
+
+func (s *serverPoolTest) beginWait() {
+ // ensure that dialIterator and the maximal number of pre-neg queries are not all stuck in a waiting state
+ for atomic.AddInt32(&s.waitCount, 1) > preNegLimit {
+ atomic.AddInt32(&s.waitCount, -1)
+ s.clock.Run(time.Second)
+ }
+}
+
+func (s *serverPoolTest) endWait() {
+ atomic.AddInt32(&s.waitCount, -1)
+ atomic.AddInt32(&s.waitEnded, 1)
+}
+
+func (s *serverPoolTest) addTrusted(i int) {
+ s.trusted = append(s.trusted, enode.SignNull(&enr.Record{}, testNodeID(i)).String())
+}
+
+func (s *serverPoolTest) start() {
+ var testQuery queryFunc
+ if s.preNeg {
+ testQuery = func(node *enode.Node) int {
+ idx := testNodeIndex(node.ID())
+ n := &s.testNodes[idx]
+ canConnect := !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle
+ if s.preNegFail {
+ // simulate a scenario where UDP queries never work
+ s.beginWait()
+ s.clock.Sleep(time.Second * 5)
+ s.endWait()
+ return -1
+ } else {
+ switch idx % 3 {
+ case 0:
+ // pre-neg returns true only if connection is possible
+ if canConnect {
+ return 1
+ } else {
+ return 0
+ }
+ case 1:
+ // pre-neg returns true but connection might still fail
+ return 1
+ case 2:
+ // pre-neg returns true if connection is possible, otherwise timeout (node unresponsive)
+ if canConnect {
+ return 1
+ } else {
+ s.beginWait()
+ s.clock.Sleep(time.Second * 5)
+ s.endWait()
+ return -1
+ }
+ }
+ return -1
+ }
+ }
+ }
+
+ s.vt = lpc.NewValueTracker(s.db, s.clock, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000))
+ s.sp = newServerPool(s.db, []byte("serverpool:"), s.vt, s.input, 0, testQuery, s.clock, s.trusted)
+ s.sp.validSchemes = enode.ValidSchemesForTesting
+ s.sp.unixTime = func() int64 { return int64(s.clock.Now()) / int64(time.Second) }
+ s.disconnect = make(map[int][]int)
+ s.sp.start()
+ s.quit = make(chan struct{})
+ go func() {
+ last := int32(-1)
+ for {
+ select {
+ case <-time.After(time.Millisecond * 100):
+ c := atomic.LoadInt32(&s.waitEnded)
+ if c == last {
+ // advance clock if test is stuck (might happen in rare cases)
+ s.clock.Run(time.Second)
+ }
+ last = c
+ case <-s.quit:
+ return
+ }
+ }
+ }()
+}
+
+func (s *serverPoolTest) stop() {
+ close(s.quit)
+ s.sp.stop()
+ s.vt.Stop()
+ for i := range s.testNodes {
+ n := &s.testNodes[i]
+ if n.connected {
+ n.totalConn += s.cycle
+ }
+ n.connected = false
+ n.peer = nil
+ n.nextConnCycle = 0
+ }
+ s.conn, s.servedConn = 0, 0
+}
+
+func (s *serverPoolTest) run() {
+ for count := spTestLength; count > 0; count-- {
+ if dcList := s.disconnect[s.cycle]; dcList != nil {
+ for _, idx := range dcList {
+ n := &s.testNodes[idx]
+ s.sp.unregisterPeer(n.peer)
+ n.totalConn += s.cycle
+ n.connected = false
+ n.peer = nil
+ s.conn--
+ if n.service {
+ s.servedConn--
+ }
+ n.nextConnCycle = s.cycle + n.waitCycles
+ }
+ delete(s.disconnect, s.cycle)
+ }
+ if s.conn < spTestTarget {
+ s.dialCount++
+ s.beginWait()
+ s.sp.dialIterator.Next()
+ s.endWait()
+ dial := s.sp.dialIterator.Node()
+ id := dial.ID()
+ idx := testNodeIndex(id)
+ n := &s.testNodes[idx]
+ if !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle {
+ s.conn++
+ if n.service {
+ s.servedConn++
+ }
+ n.totalConn -= s.cycle
+ n.connected = true
+ dc := s.cycle + n.connectCycles
+ s.disconnect[dc] = append(s.disconnect[dc], idx)
+ n.peer = &serverPeer{peerCommons: peerCommons{Peer: p2p.NewPeer(id, "", nil)}}
+ s.sp.registerPeer(n.peer)
+ if n.service {
+ s.vt.Served(s.vt.GetNode(id), []lpc.ServedRequest{{ReqType: 0, Amount: 100}}, 0)
+ }
+ }
+ }
+ s.serviceCycles += s.servedConn
+ s.clock.Run(time.Second)
+ s.cycle++
+ }
+}
+
+func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) {
+ for ; count > 0; count-- {
+ idx := rand.Intn(spTestNodes)
+ for s.testNodes[idx].connectCycles != 0 || s.testNodes[idx].connected {
+ idx = rand.Intn(spTestNodes)
+ }
+ res = append(res, idx)
+ s.testNodes[idx] = spTestNode{
+ connectCycles: conn,
+ waitCycles: wait,
+ service: service,
+ }
+ if trusted {
+ s.addTrusted(idx)
+ }
+ }
+ return
+}
+
+func (s *serverPoolTest) resetNodes() {
+ for i, n := range s.testNodes {
+ if n.connected {
+ n.totalConn += s.cycle
+ s.sp.unregisterPeer(n.peer)
+ }
+ s.testNodes[i] = spTestNode{totalConn: n.totalConn}
+ }
+ s.conn, s.servedConn = 0, 0
+ s.disconnect = make(map[int][]int)
+ s.trusted = nil
+}
+
+func (s *serverPoolTest) checkNodes(t *testing.T, nodes []int) {
+ var sum int
+ for _, idx := range nodes {
+ n := &s.testNodes[idx]
+ if n.connected {
+ n.totalConn += s.cycle
+ }
+ sum += n.totalConn
+ n.totalConn = 0
+ if n.connected {
+ n.totalConn -= s.cycle
+ }
+ }
+ if sum < spMinTotal || sum > spMaxTotal {
+ t.Errorf("Total connection amount %d outside expected range %d to %d", sum, spMinTotal, spMaxTotal)
+ }
+}
+
+func TestServerPool(t *testing.T) { testServerPool(t, false, false) }
+func TestServerPoolWithPreNeg(t *testing.T) { testServerPool(t, true, false) }
+func TestServerPoolWithPreNegFail(t *testing.T) { testServerPool(t, true, true) }
+func testServerPool(t *testing.T, preNeg, fail bool) {
+ s := newServerPoolTest(preNeg, fail)
+ nodes := s.setNodes(100, 200, 200, true, false)
+ s.setNodes(100, 20, 20, false, false)
+ s.start()
+ s.run()
+ s.stop()
+ s.checkNodes(t, nodes)
+}
+
+func TestServerPoolChangedNodes(t *testing.T) { testServerPoolChangedNodes(t, false) }
+func TestServerPoolChangedNodesWithPreNeg(t *testing.T) { testServerPoolChangedNodes(t, true) }
+func testServerPoolChangedNodes(t *testing.T, preNeg bool) {
+ s := newServerPoolTest(preNeg, false)
+ nodes := s.setNodes(100, 200, 200, true, false)
+ s.setNodes(100, 20, 20, false, false)
+ s.start()
+ s.run()
+ s.checkNodes(t, nodes)
+ for i := 0; i < 3; i++ {
+ s.resetNodes()
+ nodes := s.setNodes(100, 200, 200, true, false)
+ s.setNodes(100, 20, 20, false, false)
+ s.run()
+ s.checkNodes(t, nodes)
+ }
+ s.stop()
+}
+
+func TestServerPoolRestartNoDiscovery(t *testing.T) { testServerPoolRestartNoDiscovery(t, false) }
+func TestServerPoolRestartNoDiscoveryWithPreNeg(t *testing.T) {
+ testServerPoolRestartNoDiscovery(t, true)
+}
+func testServerPoolRestartNoDiscovery(t *testing.T, preNeg bool) {
+ s := newServerPoolTest(preNeg, false)
+ nodes := s.setNodes(100, 200, 200, true, false)
+ s.setNodes(100, 20, 20, false, false)
+ s.start()
+ s.run()
+ s.stop()
+ s.checkNodes(t, nodes)
+ s.input = nil
+ s.start()
+ s.run()
+ s.stop()
+ s.checkNodes(t, nodes)
+}
+
+func TestServerPoolTrustedNoDiscovery(t *testing.T) { testServerPoolTrustedNoDiscovery(t, false) }
+func TestServerPoolTrustedNoDiscoveryWithPreNeg(t *testing.T) {
+ testServerPoolTrustedNoDiscovery(t, true)
+}
+func testServerPoolTrustedNoDiscovery(t *testing.T, preNeg bool) {
+ s := newServerPoolTest(preNeg, false)
+ trusted := s.setNodes(200, 200, 200, true, true)
+ s.input = nil
+ s.start()
+ s.run()
+ s.stop()
+ s.checkNodes(t, trusted)
+}
diff --git a/les/test_helper.go b/les/test_helper.go
index 1f02d2529..2a2bbb440 100644
--- a/les/test_helper.go
+++ b/les/test_helper.go
@@ -508,7 +508,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer
clock = &mclock.Simulated{}
}
dist := newRequestDistributor(speers, clock)
- rm := newRetrieveManager(speers, dist, nil)
+ rm := newRetrieveManager(speers, dist, func() time.Duration { return time.Millisecond * 500 })
odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm)
sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig)
diff --git a/les/utils/expiredvalue.go b/les/utils/expiredvalue.go
index 85f9b88b7..a58587368 100644
--- a/les/utils/expiredvalue.go
+++ b/les/utils/expiredvalue.go
@@ -63,14 +63,7 @@ func ExpFactor(logOffset Fixed64) ExpirationFactor {
// Value calculates the expired value based on a floating point base and integer
// power-of-2 exponent. This function should be used by multi-value expired structures.
func (e ExpirationFactor) Value(base float64, exp uint64) float64 {
- res := base / e.Factor
- if exp > e.Exp {
- res *= float64(uint64(1) << (exp - e.Exp))
- }
- if exp < e.Exp {
- res /= float64(uint64(1) << (e.Exp - exp))
- }
- return res
+ return base / e.Factor * math.Pow(2, float64(int64(exp-e.Exp)))
}
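
A quick standalone check, not part of the change, that the simplified expression matches the old branching logic for both positive and negative exponent differences: the uint64 subtraction wraps around, but converting the difference to int64 restores the signed value handed to math.Pow.

package main

import (
	"fmt"
	"math"
)

func newValue(base, factor float64, exp, eExp uint64) float64 {
	return base / factor * math.Pow(2, float64(int64(exp-eExp)))
}

func oldValue(base, factor float64, exp, eExp uint64) float64 {
	res := base / factor
	if exp > eExp {
		res *= float64(uint64(1) << (exp - eExp))
	}
	if exp < eExp {
		res /= float64(uint64(1) << (eExp - exp))
	}
	return res
}

func main() {
	for _, c := range [][2]uint64{{10, 7}, {7, 10}, {5, 5}} {
		fmt.Println(newValue(3, 1.5, c[0], c[1]), oldValue(3, 1.5, c[0], c[1]))
	}
}
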
// value calculates the value at the given moment.
diff --git a/les/utils/weighted_select.go b/les/utils/weighted_select.go
index fbf1f37d6..d6db3c0e6 100644
--- a/les/utils/weighted_select.go
+++ b/les/utils/weighted_select.go
@@ -16,28 +16,44 @@
package utils
-import "math/rand"
+import (
+ "math/rand"
+)
-// wrsItem interface should be implemented by any entries that are to be selected from
-// a WeightedRandomSelect set. Note that recalculating monotonously decreasing item
-// weights on-demand (without constantly calling Update) is allowed
-type wrsItem interface {
- Weight() int64
-}
-
-// WeightedRandomSelect is capable of weighted random selection from a set of items
-type WeightedRandomSelect struct {
- root *wrsNode
- idx map[wrsItem]int
-}
+type (
+ // WeightedRandomSelect is capable of weighted random selection from a set of items
+ WeightedRandomSelect struct {
+ root *wrsNode
+ idx map[WrsItem]int
+ wfn WeightFn
+ }
+ WrsItem interface{}
+ WeightFn func(interface{}) uint64
+)
// NewWeightedRandomSelect returns a new WeightedRandomSelect structure
-func NewWeightedRandomSelect() *WeightedRandomSelect {
- return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[wrsItem]int)}
+func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect {
+ return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn}
+}
+
+// Update updates an item's weight, adds it if it was non-existent or removes it if
+// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
+func (w *WeightedRandomSelect) Update(item WrsItem) {
+ w.setWeight(item, w.wfn(item))
+}
+
+// Remove removes an item from the set
+func (w *WeightedRandomSelect) Remove(item WrsItem) {
+ w.setWeight(item, 0)
+}
+
+// IsEmpty returns true if the set is empty
+func (w *WeightedRandomSelect) IsEmpty() bool {
+ return w.root.sumWeight == 0
}
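
A minimal usage sketch of the reworked API: weights are now supplied by a WeightFn passed at construction rather than by a Weight method on the items. The server type below is purely illustrative.

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/les/utils"
)

type server struct {
	name  string
	score uint64
}

func main() {
	wrs := utils.NewWeightedRandomSelect(func(i interface{}) uint64 {
		return i.(*server).score // weight is read through the function, not the item
	})
	wrs.Update(&server{"a", 1})
	wrs.Update(&server{"b", 99})

	counts := map[string]int{}
	for i := 0; i < 1000; i++ {
		counts[wrs.Choose().(*server).name]++
	}
	fmt.Println(counts["b"] > counts["a"]) // "b" is chosen far more often
}
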
// setWeight sets an item's weight to a specific value (removes it if zero)
-func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) {
+func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
idx, ok := w.idx[item]
if ok {
w.root.setWeight(idx, weight)
@@ -58,33 +74,22 @@ func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) {
}
}
-// Update updates an item's weight, adds it if it was non-existent or removes it if
-// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
-func (w *WeightedRandomSelect) Update(item wrsItem) {
- w.setWeight(item, item.Weight())
-}
-
-// Remove removes an item from the set
-func (w *WeightedRandomSelect) Remove(item wrsItem) {
- w.setWeight(item, 0)
-}
-
// Choose randomly selects an item from the set, with a chance proportional to its
// current weight. If the weight of the chosen element has been decreased since the
// last stored value, returns it with a newWeight/oldWeight chance, otherwise just
// updates its weight and selects another one
-func (w *WeightedRandomSelect) Choose() wrsItem {
+func (w *WeightedRandomSelect) Choose() WrsItem {
for {
if w.root.sumWeight == 0 {
return nil
}
- val := rand.Int63n(w.root.sumWeight)
+ val := uint64(rand.Int63n(int64(w.root.sumWeight)))
choice, lastWeight := w.root.choose(val)
- weight := choice.Weight()
+ weight := w.wfn(choice)
if weight != lastWeight {
w.setWeight(choice, weight)
}
- if weight >= lastWeight || rand.Int63n(lastWeight) < weight {
+ if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight {
return choice
}
}
@@ -92,16 +97,16 @@ func (w *WeightedRandomSelect) Choose() wrsItem {
const wrsBranches = 8 // max number of branches in the wrsNode tree
-// wrsNode is a node of a tree structure that can store wrsItems or further wrsNodes.
+// wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes.
type wrsNode struct {
items [wrsBranches]interface{}
- weights [wrsBranches]int64
- sumWeight int64
+ weights [wrsBranches]uint64
+ sumWeight uint64
level, itemCnt, maxItems int
}
// insert recursively inserts a new item to the tree and returns the item index
-func (n *wrsNode) insert(item wrsItem, weight int64) int {
+func (n *wrsNode) insert(item WrsItem, weight uint64) int {
branch := 0
for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) {
branch++
@@ -129,7 +134,7 @@ func (n *wrsNode) insert(item wrsItem, weight int64) int {
// setWeight updates the weight of a certain item (which should exist) and returns
// the change of the last weight value stored in the tree
-func (n *wrsNode) setWeight(idx int, weight int64) int64 {
+func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
if n.level == 0 {
oldWeight := n.weights[idx]
n.weights[idx] = weight
@@ -152,12 +157,12 @@ func (n *wrsNode) setWeight(idx int, weight int64) int64 {
return diff
}
-// Choose recursively selects an item from the tree and returns it along with its weight
-func (n *wrsNode) choose(val int64) (wrsItem, int64) {
+// choose recursively selects an item from the tree and returns it along with its weight
+func (n *wrsNode) choose(val uint64) (WrsItem, uint64) {
for i, w := range n.weights {
if val < w {
if n.level == 0 {
- return n.items[i].(wrsItem), n.weights[i]
+ return n.items[i].(WrsItem), n.weights[i]
}
return n.items[i].(*wrsNode).choose(val)
}
diff --git a/les/utils/weighted_select_test.go b/les/utils/weighted_select_test.go
index e1969e1a6..3e1c0ad98 100644
--- a/les/utils/weighted_select_test.go
+++ b/les/utils/weighted_select_test.go
@@ -26,17 +26,18 @@ type testWrsItem struct {
widx *int
}
-func (t *testWrsItem) Weight() int64 {
+func testWeight(i interface{}) uint64 {
+ t := i.(*testWrsItem)
w := *t.widx
if w == -1 || w == t.idx {
- return int64(t.idx + 1)
+ return uint64(t.idx + 1)
}
return 0
}
func TestWeightedRandomSelect(t *testing.T) {
testFn := func(cnt int) {
- s := NewWeightedRandomSelect()
+ s := NewWeightedRandomSelect(testWeight)
w := -1
list := make([]testWrsItem, cnt)
for i := range list {
diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go
new file mode 100644
index 000000000..7091281ae
--- /dev/null
+++ b/p2p/nodestate/nodestate.go
@@ -0,0 +1,880 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package nodestate
+
+import (
+ "errors"
+ "reflect"
+ "sync"
+ "time"
+ "unsafe"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+type (
+ // NodeStateMachine connects different system components operating on subsets of
+ // network nodes. Node states are represented by 64 bit vectors with each bit assigned
+ // to a state flag. Each state flag has a descriptor structure and the mapping is
+ // created automatically. It is possible to subscribe to subsets of state flags and
+ // receive a callback if one of the nodes has a relevant state flag changed.
+ // Callbacks can also modify further flags of the same node or other nodes. State
+ // updates only return after all immediate effects throughout the system have happened
+ // (deadlocks should be avoided by design of the implemented state logic). The caller
+ // can also add timeouts assigned to a certain node and a subset of state flags.
+ // If the timeout elapses, the flags are reset. If all relevant flags are reset then
+ // the timer is dropped. State flags with no timeout are persisted in the database
+ // if the flag descriptor enables saving. If a node has no state flags set at any
+ // moment then it is discarded.
+ //
+ // Extra node fields can also be registered so system components can also store more
+ // complex state for each node that is relevant to them, without creating a custom
+ // peer set. Fields can be shared across multiple components if they all know the
+ // field ID. Subscription to fields is also possible. Persistent fields should have
+ // an encoder and a decoder function.
+ NodeStateMachine struct {
+ started, stopped bool
+ lock sync.Mutex
+ clock mclock.Clock
+ db ethdb.KeyValueStore
+ dbNodeKey []byte
+ nodes map[enode.ID]*nodeInfo
+ offlineCallbackList []offlineCallback
+
+ // Registered state flags or fields. Modifications are allowed
+ // only when the node state machine has not been started.
+ setup *Setup
+ fields []*fieldInfo
+ saveFlags bitMask
+
+ // Installed callbacks. Modifications are allowed only when the
+ // node state machine has not been started.
+ stateSubs []stateSub
+
+ // Testing hooks, only for testing purposes.
+ saveNodeHook func(*nodeInfo)
+ }
+
+ // Flags represents a set of flags from a certain setup
+ Flags struct {
+ mask bitMask
+ setup *Setup
+ }
+
+ // Field represents a field from a certain setup
+ Field struct {
+ index int
+ setup *Setup
+ }
+
+ // flagDefinition describes a node state flag. Each registered instance is automatically
+ // mapped to a bit of the 64 bit node states.
+ // If persistent is true then the node is saved when state machine is shutdown.
+ flagDefinition struct {
+ name string
+ persistent bool
+ }
+
+ // fieldDefinition describes an optional node field of the given type. The contents
+ // of the field are only retained for each node as long as at least one of the
+ // state flags is set.
+ fieldDefinition struct {
+ name string
+ ftype reflect.Type
+ encode func(interface{}) ([]byte, error)
+ decode func([]byte) (interface{}, error)
+ }
+
+	// Setup contains the list of flags and fields used by the application
+ Setup struct {
+ Version uint
+ flags []flagDefinition
+ fields []fieldDefinition
+ }
+
+ // bitMask describes a node state or state mask. It represents a subset
+ // of node flags with each bit assigned to a flag index (LSB represents flag 0).
+ bitMask uint64
+
+ // StateCallback is a subscription callback which is called when one of the
+ // state flags that is included in the subscription state mask is changed.
+ // Note: oldState and newState are also masked with the subscription mask so only
+ // the relevant bits are included.
+ StateCallback func(n *enode.Node, oldState, newState Flags)
+
+ // FieldCallback is a subscription callback which is called when the value of
+ // a specific field is changed.
+ FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{})
+
+ // nodeInfo contains node state, fields and state timeouts
+ nodeInfo struct {
+ node *enode.Node
+ state bitMask
+ timeouts []*nodeStateTimeout
+ fields []interface{}
+ db, dirty bool
+ }
+
+ nodeInfoEnc struct {
+ Enr enr.Record
+ Version uint
+ State bitMask
+ Fields [][]byte
+ }
+
+ stateSub struct {
+ mask bitMask
+ callback StateCallback
+ }
+
+ nodeStateTimeout struct {
+ mask bitMask
+ timer mclock.Timer
+ }
+
+ fieldInfo struct {
+ fieldDefinition
+ subs []FieldCallback
+ }
+
+ offlineCallback struct {
+ node *enode.Node
+ state bitMask
+ fields []interface{}
+ }
+)
+
+// offlineState is a special state that is assumed to be set before a node is loaded from
+// the database and after it is shut down.
+const offlineState = bitMask(1)
+
+// NewFlag creates a new node state flag
+func (s *Setup) NewFlag(name string) Flags {
+ if s.flags == nil {
+ s.flags = []flagDefinition{{name: "offline"}}
+ }
+ f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
+ s.flags = append(s.flags, flagDefinition{name: name})
+ return f
+}
+
+// NewPersistentFlag creates a new persistent node state flag
+func (s *Setup) NewPersistentFlag(name string) Flags {
+ if s.flags == nil {
+ s.flags = []flagDefinition{{name: "offline"}}
+ }
+ f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
+ s.flags = append(s.flags, flagDefinition{name: name, persistent: true})
+ return f
+}
+
+// OfflineFlag returns the system-defined offline flag belonging to the given setup
+func (s *Setup) OfflineFlag() Flags {
+ return Flags{mask: offlineState, setup: s}
+}
+
+// NewField creates a new node state field
+func (s *Setup) NewField(name string, ftype reflect.Type) Field {
+ f := Field{index: len(s.fields), setup: s}
+ s.fields = append(s.fields, fieldDefinition{
+ name: name,
+ ftype: ftype,
+ })
+ return f
+}
+
+// NewPersistentField creates a new persistent node field
+func (s *Setup) NewPersistentField(name string, ftype reflect.Type, encode func(interface{}) ([]byte, error), decode func([]byte) (interface{}, error)) Field {
+ f := Field{index: len(s.fields), setup: s}
+ s.fields = append(s.fields, fieldDefinition{
+ name: name,
+ ftype: ftype,
+ encode: encode,
+ decode: decode,
+ })
+ return f
+}
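
A minimal sketch of declaring a Setup and driving a NodeStateMachine with it, in the style the server pool uses; the flag and field names are illustrative.

package main

import (
	"fmt"
	"reflect"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/ethdb/memorydb"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

var (
	setup       = &nodestate.Setup{}
	sfDialing   = setup.NewFlag("dialing")
	sfConnected = setup.NewFlag("connected")
	sfiScore    = setup.NewField("score", reflect.TypeOf(uint64(0)))
)

func main() {
	ns := nodestate.NewNodeStateMachine(memorydb.New(), []byte("ns:"), mclock.System{}, setup)
	ns.SubscribeState(sfConnected, func(n *enode.Node, oldState, newState nodestate.Flags) {
		fmt.Println("connected changed:", oldState, "->", newState)
	})
	ns.Start()

	node := enode.SignNull(&enr.Record{}, enode.ID{1})
	ns.SetState(node, sfDialing, nodestate.Flags{}, time.Second*10) // set with a timeout
	ns.SetState(node, sfConnected, sfDialing, 0)                    // set connected, reset dialing
	ns.SetField(node, sfiScore, uint64(42))
	fmt.Println("score:", ns.GetField(node, sfiScore))
	ns.Stop()
}
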
+
+// flagOp implements binary flag operations and also checks whether the operands belong to the same setup
+func flagOp(a, b Flags, trueIfA, trueIfB, trueIfBoth bool) Flags {
+ if a.setup == nil {
+ if a.mask != 0 {
+ panic("Node state flags have no setup reference")
+ }
+ a.setup = b.setup
+ }
+ if b.setup == nil {
+ if b.mask != 0 {
+ panic("Node state flags have no setup reference")
+ }
+ b.setup = a.setup
+ }
+ if a.setup != b.setup {
+ panic("Node state flags belong to a different setup")
+ }
+ res := Flags{setup: a.setup}
+ if trueIfA {
+ res.mask |= a.mask & ^b.mask
+ }
+ if trueIfB {
+ res.mask |= b.mask & ^a.mask
+ }
+ if trueIfBoth {
+ res.mask |= a.mask & b.mask
+ }
+ return res
+}
+
+// And returns the set of flags present in both a and b
+func (a Flags) And(b Flags) Flags { return flagOp(a, b, false, false, true) }
+
+// AndNot returns the set of flags present in a but not in b
+func (a Flags) AndNot(b Flags) Flags { return flagOp(a, b, true, false, false) }
+
+// Or returns the set of flags present in either a or b
+func (a Flags) Or(b Flags) Flags { return flagOp(a, b, true, true, true) }
+
+// Xor returns the set of flags present in either a or b but not both
+func (a Flags) Xor(b Flags) Flags { return flagOp(a, b, true, true, false) }
+
+// HasAll returns true if b is a subset of a
+func (a Flags) HasAll(b Flags) bool { return flagOp(a, b, false, true, false).mask == 0 }
+
+// HasNone returns true if a and b have no shared flags
+func (a Flags) HasNone(b Flags) bool { return flagOp(a, b, false, false, true).mask == 0 }
+
+// Equals returns true if a and b have the same flags set
+func (a Flags) Equals(b Flags) bool { return flagOp(a, b, true, true, false).mask == 0 }
+
+// IsEmpty returns true if a has no flags set
+func (a Flags) IsEmpty() bool { return a.mask == 0 }
+
+// MergeFlags merges multiple sets of state flags
+func MergeFlags(list ...Flags) Flags {
+ if len(list) == 0 {
+ return Flags{}
+ }
+ res := list[0]
+ for i := 1; i < len(list); i++ {
+ res = res.Or(list[i])
+ }
+ return res
+}
+
+// String returns a list of the names of the flags specified in the bit mask
+func (f Flags) String() string {
+ if f.mask == 0 {
+ return "[]"
+ }
+ s := "["
+ comma := false
+ for index, flag := range f.setup.flags {
+ if f.mask&(bitMask(1)<<uint(index)) != 0 {
+ if comma {
+ s = s + ", "
+ }
+ s = s + flag.name
+ comma = true
+ }
+ }
+ s = s + "]"
+ return s
+}
+
+// NewNodeStateMachine creates a new node state machine.
+// If db is not nil then the node states, fields and active timeouts are persisted.
+// Persistence can be enabled or disabled for each state flag and field.
+func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Clock, setup *Setup) *NodeStateMachine {
+ if setup.flags == nil {
+ panic("No state flags defined")
+ }
+ if len(setup.flags) > 8*int(unsafe.Sizeof(bitMask(0))) {
+ panic("Too many node state flags")
+ }
+ ns := &NodeStateMachine{
+ db: db,
+ dbNodeKey: dbKey,
+ clock: clock,
+ setup: setup,
+ nodes: make(map[enode.ID]*nodeInfo),
+ fields: make([]*fieldInfo, len(setup.fields)),
+ }
+ stateNameMap := make(map[string]int)
+ for index, flag := range setup.flags {
+ if _, ok := stateNameMap[flag.name]; ok {
+ panic("Node state flag name collision")
+ }
+ stateNameMap[flag.name] = index
+ if flag.persistent {
+ ns.saveFlags |= bitMask(1) << uint(index)
+ }
+ }
+ fieldNameMap := make(map[string]int)
+ for index, field := range setup.fields {
+ if _, ok := fieldNameMap[field.name]; ok {
+ panic("Node field name collision")
+ }
+ ns.fields[index] = &fieldInfo{fieldDefinition: field}
+ fieldNameMap[field.name] = index
+ }
+ return ns
+}
+
+// stateMask checks whether the set of flags belongs to the same setup and returns its internal bit mask
+func (ns *NodeStateMachine) stateMask(flags Flags) bitMask {
+ if flags.setup != ns.setup && flags.mask != 0 {
+ panic("Node state flags belong to a different setup")
+ }
+ return flags.mask
+}
+
+// fieldIndex checks whether the field belongs to the same setup and returns its internal index
+func (ns *NodeStateMachine) fieldIndex(field Field) int {
+ if field.setup != ns.setup {
+ panic("Node field belongs to a different setup")
+ }
+ return field.index
+}
+
+// SubscribeState adds a node state subscription. The callback is called while the state
+// machine mutex is not held and it is allowed to make further state updates. All immediate
+// changes throughout the system are processed in the same thread/goroutine. It is the
+// responsibility of the implemented state logic to avoid deadlocks caused by the callbacks,
+// infinite toggling of flags or hazardous/non-deterministic state changes.
+// State subscriptions should be installed before loading the node database or making the
+// first state update.
+func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) {
+ ns.lock.Lock()
+ defer ns.lock.Unlock()
+
+ if ns.started {
+ panic("state machine already started")
+ }
+ ns.stateSubs = append(ns.stateSubs, stateSub{ns.stateMask(flags), callback})
+}
+
+// SubscribeField adds a node field subscription. Same rules apply as for SubscribeState.
+func (ns *NodeStateMachine) SubscribeField(field Field, callback FieldCallback) {
+ ns.lock.Lock()
+ defer ns.lock.Unlock()
+
+ if ns.started {
+ panic("state machine already started")
+ }
+ f := ns.fields[ns.fieldIndex(field)]
+ f.subs = append(f.subs, callback)
+}
+
+// newNode creates a new nodeInfo
+func (ns *NodeStateMachine) newNode(n *enode.Node) *nodeInfo {
+ return &nodeInfo{node: n, fields: make([]interface{}, len(ns.fields))}
+}
+
+// checkStarted checks whether the state machine has already been started and panics otherwise.
+func (ns *NodeStateMachine) checkStarted() {
+ if !ns.started {
+ panic("state machine not started yet")
+ }
+}
+
+// Start starts the state machine, enabling state and field operations and disabling
+// further subscriptions.
+func (ns *NodeStateMachine) Start() {
+ ns.lock.Lock()
+ if ns.started {
+ panic("state machine already started")
+ }
+ ns.started = true
+ if ns.db != nil {
+ ns.loadFromDb()
+ }
+ ns.lock.Unlock()
+ ns.offlineCallbacks(true)
+}
+
+// Stop stops the state machine and saves its state if a database was supplied
+func (ns *NodeStateMachine) Stop() {
+ ns.lock.Lock()
+ for _, node := range ns.nodes {
+ fields := make([]interface{}, len(node.fields))
+ copy(fields, node.fields)
+ ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
+ }
+ ns.stopped = true
+ if ns.db != nil {
+ ns.saveToDb()
+ ns.lock.Unlock()
+ } else {
+ ns.lock.Unlock()
+ }
+ ns.offlineCallbacks(false)
+}
+
+// loadFromDb loads persisted node states from the database
+func (ns *NodeStateMachine) loadFromDb() {
+ it := ns.db.NewIterator(ns.dbNodeKey, nil)
+ for it.Next() {
+ var id enode.ID
+ if len(it.Key()) != len(ns.dbNodeKey)+len(id) {
+ log.Error("Node state db entry with invalid length", "found", len(it.Key()), "expected", len(ns.dbNodeKey)+len(id))
+ continue
+ }
+ copy(id[:], it.Key()[len(ns.dbNodeKey):])
+ ns.decodeNode(id, it.Value())
+ }
+}
+
+type dummyIdentity enode.ID
+
+func (id dummyIdentity) Verify(r *enr.Record, sig []byte) error { return nil }
+func (id dummyIdentity) NodeAddr(r *enr.Record) []byte { return id[:] }
+
+// decodeNode decodes a node database entry and adds it to the node set if successful
+func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) {
+ var enc nodeInfoEnc
+ if err := rlp.DecodeBytes(data, &enc); err != nil {
+ log.Error("Failed to decode node info", "id", id, "error", err)
+ return
+ }
+ n, _ := enode.New(dummyIdentity(id), &enc.Enr)
+ node := ns.newNode(n)
+ node.db = true
+
+ if enc.Version != ns.setup.Version {
+ log.Debug("Removing stored node with unknown version", "current", ns.setup.Version, "stored", enc.Version)
+ ns.deleteNode(id)
+ return
+ }
+ if len(enc.Fields) > len(ns.setup.fields) {
+ log.Error("Invalid node field count", "id", id, "stored", len(enc.Fields))
+ return
+ }
+ // Resolve persisted node fields
+ for i, encField := range enc.Fields {
+ if len(encField) == 0 {
+ continue
+ }
+ if decode := ns.fields[i].decode; decode != nil {
+ if field, err := decode(encField); err == nil {
+ node.fields[i] = field
+ } else {
+ log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err)
+ return
+ }
+ } else {
+ log.Error("Cannot decode node field", "id", id, "field name", ns.fields[i].name)
+ return
+ }
+ }
+ // It's a compatible node record, add it to the set.
+ ns.nodes[id] = node
+ node.state = enc.State
+ fields := make([]interface{}, len(node.fields))
+ copy(fields, node.fields)
+ ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
+ log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup})
+}
+
+// saveNode saves the given node info to the database
+func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error {
+ if ns.db == nil {
+ return nil
+ }
+
+ storedState := node.state & ns.saveFlags
+ for _, t := range node.timeouts {
+ storedState &= ^t.mask
+ }
+ if storedState == 0 {
+ if node.db {
+ node.db = false
+ ns.deleteNode(id)
+ }
+ node.dirty = false
+ return nil
+ }
+
+ enc := nodeInfoEnc{
+ Enr: *node.node.Record(),
+ Version: ns.setup.Version,
+ State: storedState,
+ Fields: make([][]byte, len(ns.fields)),
+ }
+ log.Debug("Saved node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup})
+ lastIndex := -1
+ for i, f := range node.fields {
+ if f == nil {
+ continue
+ }
+ encode := ns.fields[i].encode
+ if encode == nil {
+ continue
+ }
+ blob, err := encode(f)
+ if err != nil {
+ return err
+ }
+ enc.Fields[i] = blob
+ lastIndex = i
+ }
+ enc.Fields = enc.Fields[:lastIndex+1]
+ data, err := rlp.EncodeToBytes(&enc)
+ if err != nil {
+ return err
+ }
+ if err := ns.db.Put(append(ns.dbNodeKey, id[:]...), data); err != nil {
+ return err
+ }
+ node.dirty, node.db = false, true
+
+ if ns.saveNodeHook != nil {
+ ns.saveNodeHook(node)
+ }
+ return nil
+}
+
+// deleteNode removes a node info from the database
+func (ns *NodeStateMachine) deleteNode(id enode.ID) {
+ ns.db.Delete(append(ns.dbNodeKey, id[:]...))
+}
+
+// saveToDb saves the persistent flags and fields of all nodes that have been changed
+func (ns *NodeStateMachine) saveToDb() {
+ for id, node := range ns.nodes {
+ if node.dirty {
+ err := ns.saveNode(id, node)
+ if err != nil {
+ log.Error("Failed to save node", "id", id, "error", err)
+ }
+ }
+ }
+}
+
+// updateEnode updates the enode entry belonging to the given node if it already exists
+func (ns *NodeStateMachine) updateEnode(n *enode.Node) (enode.ID, *nodeInfo) {
+ id := n.ID()
+ node := ns.nodes[id]
+ if node != nil && n.Seq() > node.node.Seq() {
+ node.node = n
+ }
+ return id, node
+}
+
+// Persist saves the persistent state and fields of the given node immediately
+func (ns *NodeStateMachine) Persist(n *enode.Node) error {
+ ns.lock.Lock()
+ defer ns.lock.Unlock()
+
+ ns.checkStarted()
+ if id, node := ns.updateEnode(n); node != nil && node.dirty {
+ err := ns.saveNode(id, node)
+ if err != nil {
+ log.Error("Failed to save node", "id", id, "error", err)
+ }
+ return err
+ }
+ return nil
+}
+
+// SetState updates the given node state flags and processes all resulting callbacks.
+// It only returns after all subsequent immediate changes (including those changed by the
+// callbacks) have been processed. If a flag with a timeout is set again, the operation
+// removes or replaces the existing timeout.
+func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) {
+ ns.lock.Lock()
+ ns.checkStarted()
+ if ns.stopped {
+ ns.lock.Unlock()
+ return
+ }
+
+ set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags)
+ id, node := ns.updateEnode(n)
+ if node == nil {
+ if set == 0 {
+ ns.lock.Unlock()
+ return
+ }
+ node = ns.newNode(n)
+ ns.nodes[id] = node
+ }
+ oldState := node.state
+ newState := (node.state & (^reset)) | set
+ changed := oldState ^ newState
+ node.state = newState
+
+ // Remove the timeout callbacks for all reset and set flags,
+ // even if none exist (it's a no-op in that case).
+ ns.removeTimeouts(node, set|reset)
+
+ // Register the timeout callback if a timeout was requested
+ // and the new state is not empty.
+ if timeout != 0 && newState != 0 {
+ ns.addTimeout(n, set, timeout)
+ }
+ if newState == oldState {
+ ns.lock.Unlock()
+ return
+ }
+ if newState == 0 {
+ delete(ns.nodes, id)
+ if node.db {
+ ns.deleteNode(id)
+ }
+ } else {
+ if changed&ns.saveFlags != 0 {
+ node.dirty = true
+ }
+ }
+ ns.lock.Unlock()
+ // call state update subscription callbacks without holding the mutex
+ for _, sub := range ns.stateSubs {
+ if changed&sub.mask != 0 {
+ sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup})
+ }
+ }
+ if newState == 0 {
+ // call field subscriptions for discarded fields
+ for i, v := range node.fields {
+ if v != nil {
+ f := ns.fields[i]
+ if len(f.subs) > 0 {
+ for _, cb := range f.subs {
+ cb(n, Flags{setup: ns.setup}, v, nil)
+ }
+ }
+ }
+ }
+ }
+}
+
+// offlineCallbacks calls state update callbacks at startup or shutdown
+func (ns *NodeStateMachine) offlineCallbacks(start bool) {
+ for _, cb := range ns.offlineCallbackList {
+ for _, sub := range ns.stateSubs {
+ offState := offlineState & sub.mask
+ onState := cb.state & sub.mask
+ if offState != onState {
+ if start {
+ sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup})
+ } else {
+ sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup})
+ }
+ }
+ }
+ for i, f := range cb.fields {
+ if f != nil && ns.fields[i].subs != nil {
+ for _, fsub := range ns.fields[i].subs {
+ if start {
+ fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f)
+ } else {
+ fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil)
+ }
+ }
+ }
+ }
+ }
+ ns.offlineCallbackList = nil
+}
+
+// AddTimeout adds a node state timeout associated with the given state flag(s).
+// After the specified time interval, the relevant flags are reset.
+func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) {
+ ns.lock.Lock()
+ defer ns.lock.Unlock()
+
+ ns.checkStarted()
+ if ns.stopped {
+ return
+ }
+ ns.addTimeout(n, ns.stateMask(flags), timeout)
+}
+
+// addTimeout adds a node state timeout associated with the given state flag(s).
+func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time.Duration) {
+ _, node := ns.updateEnode(n)
+ if node == nil {
+ return
+ }
+ mask &= node.state
+ if mask == 0 {
+ return
+ }
+ ns.removeTimeouts(node, mask)
+ t := &nodeStateTimeout{mask: mask}
+ t.timer = ns.clock.AfterFunc(timeout, func() {
+ ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0)
+ })
+ node.timeouts = append(node.timeouts, t)
+ if mask&ns.saveFlags != 0 {
+ node.dirty = true
+ }
+}
+
+// removeTimeouts removes node state timeouts associated with the given state flag(s).
+// If a timeout was associated with multiple flags that are not all included in the
+// specified remove mask, then only the included flags are de-associated and the
+// timer stays active.
+func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) {
+ for i := 0; i < len(node.timeouts); i++ {
+ t := node.timeouts[i]
+ match := t.mask & mask
+ if match == 0 {
+ continue
+ }
+ t.mask -= match
+ if t.mask != 0 {
+ continue
+ }
+ t.timer.Stop()
+ node.timeouts[i] = node.timeouts[len(node.timeouts)-1]
+ node.timeouts = node.timeouts[:len(node.timeouts)-1]
+ i--
+ if match&ns.saveFlags != 0 {
+ node.dirty = true
+ }
+ }
+}
+
+// GetField retrieves the given field of the given node
+func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} {
+ ns.lock.Lock()
+ defer ns.lock.Unlock()
+
+ ns.checkStarted()
+ if ns.stopped {
+ return nil
+ }
+ if _, node := ns.updateEnode(n); node != nil {
+ return node.fields[ns.fieldIndex(field)]
+ }
+ return nil
+}
+
+// SetField sets the given field of the given node
+func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error {
+ ns.lock.Lock()
+ ns.checkStarted()
+ if ns.stopped {
+ ns.lock.Unlock()
+ return nil
+ }
+ _, node := ns.updateEnode(n)
+ if node == nil {
+ ns.lock.Unlock()
+ return nil
+ }
+ fieldIndex := ns.fieldIndex(field)
+ f := ns.fields[fieldIndex]
+ if value != nil && reflect.TypeOf(value) != f.ftype {
+ log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype)
+ ns.lock.Unlock()
+ return errors.New("invalid field type")
+ }
+ oldValue := node.fields[fieldIndex]
+ if value == oldValue {
+ ns.lock.Unlock()
+ return nil
+ }
+ node.fields[fieldIndex] = value
+ if f.encode != nil {
+ node.dirty = true
+ }
+
+ state := node.state
+ ns.lock.Unlock()
+ if len(f.subs) > 0 {
+ for _, cb := range f.subs {
+ cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value)
+ }
+ }
+ return nil
+}
+
+// ForEach calls the callback for each node having all of the required and none of the
+// disabled flags set
+func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) {
+ ns.lock.Lock()
+ ns.checkStarted()
+ type callback struct {
+ node *enode.Node
+ state bitMask
+ }
+ require, disable := ns.stateMask(requireFlags), ns.stateMask(disableFlags)
+ var callbacks []callback
+ for _, node := range ns.nodes {
+ if node.state&require == require && node.state&disable == 0 {
+ callbacks = append(callbacks, callback{node.node, node.state & (require | disable)})
+ }
+ }
+ ns.lock.Unlock()
+ for _, c := range callbacks {
+ cb(c.node, Flags{mask: c.state, setup: ns.setup})
+ }
+}
+
+// GetNode returns the enode currently associated with the given ID
+func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node {
+ ns.lock.Lock()
+ defer ns.lock.Unlock()
+
+ ns.checkStarted()
+ if node := ns.nodes[id]; node != nil {
+ return node.node
+ }
+ return nil
+}
+
+// AddLogMetrics adds logging and/or metrics for nodes entering, exiting and currently
+// being in a given set specified by required and disabled state flags
+func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) {
+ var count int64
+ ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) {
+ oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
+ newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
+ if newMatch == oldMatch {
+ return
+ }
+
+ if newMatch {
+ count++
+ if name != "" {
+ log.Debug("Node entered", "set", name, "id", n.ID(), "count", count)
+ }
+ if inMeter != nil {
+ inMeter.Mark(1)
+ }
+ } else {
+ count--
+ if name != "" {
+ log.Debug("Node left", "set", name, "id", n.ID(), "count", count)
+ }
+ if outMeter != nil {
+ outMeter.Mark(1)
+ }
+ }
+ if gauge != nil {
+ gauge.Update(count)
+ }
+ })
+}
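
A minimal usage sketch of the API added above. This is illustrative only and not part of the patch: the flag and field names, the callback bodies and the in-memory database are assumptions.

// Sketch: one flag, one field, two subscriptions and a timed state transition.
package example

import (
	"reflect"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

func runExample(node *enode.Node) {
	setup := &nodestate.Setup{}
	activeFlag := setup.NewFlag("active")                            // volatile flag
	scoreField := setup.NewField("score", reflect.TypeOf(uint64(0))) // volatile field

	ns := nodestate.NewNodeStateMachine(rawdb.NewMemoryDatabase(), []byte("ns:"), mclock.System{}, setup)
	ns.SubscribeState(activeFlag, func(n *enode.Node, oldState, newState nodestate.Flags) {
		// called whenever a node enters or leaves the "active" set
	})
	ns.SubscribeField(scoreField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
		// called whenever the "score" field of a node changes
	})
	ns.Start()

	ns.SetState(node, activeFlag, nodestate.Flags{}, 10*time.Second) // auto-resets after 10 seconds
	ns.SetField(node, scoreField, uint64(42))                        // no-op if the node has no flags set
	ns.Stop()
}

Subscriptions are installed before Start, as required by the SubscribeState/SubscribeField rules above.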
diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go
new file mode 100644
index 000000000..f6ff3ffc0
--- /dev/null
+++ b/p2p/nodestate/nodestate_test.go
@@ -0,0 +1,389 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package nodestate
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common/mclock"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) {
+ setup := &Setup{}
+ flags := make([]Flags, len(flagPersist))
+ for i, persist := range flagPersist {
+ if persist {
+ flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i))
+ } else {
+ flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i))
+ }
+ }
+ fields := make([]Field, len(fieldType))
+ for i, ftype := range fieldType {
+ switch ftype {
+ case reflect.TypeOf(uint64(0)):
+ fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec)
+ case reflect.TypeOf(""):
+ fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec)
+ default:
+ fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype)
+ }
+ }
+ return setup, flags, fields
+}
+
+func testNode(b byte) *enode.Node {
+ r := &enr.Record{}
+ r.SetSig(dummyIdentity{b}, []byte{42})
+ n, _ := enode.New(dummyIdentity{b}, r)
+ return n
+}
+
+func TestCallback(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, _ := testSetup([]bool{false, false, false}, nil)
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ set0 := make(chan struct{}, 1)
+ set1 := make(chan struct{}, 1)
+ set2 := make(chan struct{}, 1)
+ ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} })
+ ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} })
+ ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} })
+
+ ns.Start()
+
+ ns.SetState(testNode(1), flags[0], Flags{}, 0)
+ ns.SetState(testNode(1), flags[1], Flags{}, time.Second)
+ ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second)
+
+ for i := 0; i < 3; i++ {
+ select {
+ case <-set0:
+ case <-set1:
+ case <-set2:
+ case <-time.After(time.Second):
+ t.Fatalf("failed to invoke callback")
+ }
+ }
+}
+
+func TestPersistentFlags(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, _ := testSetup([]bool{true, true, true, false}, nil)
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ saveNode := make(chan *nodeInfo, 5)
+ ns.saveNodeHook = func(node *nodeInfo) {
+ saveNode <- node
+ }
+
+ ns.Start()
+
+ ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved
+ ns.SetState(testNode(2), flags[1], Flags{}, 0)
+ ns.SetState(testNode(3), flags[2], Flags{}, 0)
+ ns.SetState(testNode(4), flags[3], Flags{}, 0)
+ ns.SetState(testNode(5), flags[0], Flags{}, 0)
+ ns.Persist(testNode(5))
+ select {
+ case <-saveNode:
+ case <-time.After(time.Second):
+ t.Fatalf("Timeout")
+ }
+ ns.Stop()
+
+ for i := 0; i < 2; i++ {
+ select {
+ case <-saveNode:
+ case <-time.After(time.Second):
+ t.Fatalf("Timeout")
+ }
+ }
+ select {
+ case <-saveNode:
+ t.Fatalf("Unexpected saveNode")
+ case <-time.After(time.Millisecond * 100):
+ }
+}
+
+func TestSetField(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")})
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ saveNode := make(chan *nodeInfo, 1)
+ ns.saveNodeHook = func(node *nodeInfo) {
+ saveNode <- node
+ }
+
+ ns.Start()
+
+ // Set field before setting state
+ ns.SetField(testNode(1), fields[0], "hello world")
+ field := ns.GetField(testNode(1), fields[0])
+ if field != nil {
+ t.Fatalf("Field shouldn't be set before setting states")
+ }
+ // Set field after setting state
+ ns.SetState(testNode(1), flags[0], Flags{}, 0)
+ ns.SetField(testNode(1), fields[0], "hello world")
+ field = ns.GetField(testNode(1), fields[0])
+ if field == nil {
+ t.Fatalf("Field should be set after setting states")
+ }
+ if err := ns.SetField(testNode(1), fields[0], 123); err == nil {
+ t.Fatalf("Invalid field should be rejected")
+ }
+ // Dirty node should be written back
+ ns.Stop()
+ select {
+ case <-saveNode:
+ case <-time.After(time.Second):
+ t.Fatalf("Timeout")
+ }
+}
+
+func TestUnsetField(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")})
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ ns.Start()
+
+ ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
+ ns.SetField(testNode(1), fields[0], "hello world")
+
+ ns.SetState(testNode(1), Flags{}, flags[0], 0)
+ if field := ns.GetField(testNode(1), fields[0]); field != nil {
+ t.Fatalf("Field should be unset")
+ }
+}
+
+func TestSetState(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, _ := testSetup([]bool{false, false, false}, nil)
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ type change struct{ old, new Flags }
+ set := make(chan change, 1)
+ ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) {
+ set <- change{
+ old: oldState,
+ new: newState,
+ }
+ })
+
+ ns.Start()
+
+ check := func(expectOld, expectNew Flags, expectChange bool) {
+ if expectChange {
+ select {
+ case c := <-set:
+ if !c.old.Equals(expectOld) {
+ t.Fatalf("Old state mismatch")
+ }
+ if !c.new.Equals(expectNew) {
+ t.Fatalf("New state mismatch")
+ }
+ case <-time.After(time.Second):
+ t.Fatalf("Timeout")
+ }
+ return
+ }
+ select {
+ case <-set:
+ t.Fatalf("Unexpected change")
+ case <-time.After(time.Millisecond * 100):
+ return
+ }
+ }
+ ns.SetState(testNode(1), flags[0], Flags{}, 0)
+ check(Flags{}, flags[0], true)
+
+ ns.SetState(testNode(1), flags[1], Flags{}, 0)
+ check(flags[0], flags[0].Or(flags[1]), true)
+
+ ns.SetState(testNode(1), flags[2], Flags{}, 0)
+ check(Flags{}, Flags{}, false)
+
+ ns.SetState(testNode(1), Flags{}, flags[0], 0)
+ check(flags[0].Or(flags[1]), flags[1], true)
+
+ ns.SetState(testNode(1), Flags{}, flags[1], 0)
+ check(flags[1], Flags{}, true)
+
+ ns.SetState(testNode(1), Flags{}, flags[2], 0)
+ check(Flags{}, Flags{}, false)
+
+ ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second)
+ check(Flags{}, flags[0].Or(flags[1]), true)
+ clock.Run(time.Second)
+ check(flags[0].Or(flags[1]), Flags{}, true)
+}
+
+func uint64FieldEnc(field interface{}) ([]byte, error) {
+ if u, ok := field.(uint64); ok {
+ enc, err := rlp.EncodeToBytes(&u)
+ return enc, err
+ } else {
+ return nil, errors.New("invalid field type")
+ }
+}
+
+func uint64FieldDec(enc []byte) (interface{}, error) {
+ var u uint64
+ err := rlp.DecodeBytes(enc, &u)
+ return u, err
+}
+
+func stringFieldEnc(field interface{}) ([]byte, error) {
+ if s, ok := field.(string); ok {
+ return []byte(s), nil
+ } else {
+ return nil, errors.New("invalid field type")
+ }
+}
+
+func stringFieldDec(enc []byte) (interface{}, error) {
+ return string(enc), nil
+}
+
+func TestPersistentFields(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")})
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ ns.Start()
+ ns.SetState(testNode(1), flags[0], Flags{}, 0)
+ ns.SetField(testNode(1), fields[0], uint64(100))
+ ns.SetField(testNode(1), fields[1], "hello world")
+ ns.Stop()
+
+ ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ ns2.Start()
+ field0 := ns2.GetField(testNode(1), fields[0])
+ if !reflect.DeepEqual(field0, uint64(100)) {
+ t.Fatalf("Field changed")
+ }
+ field1 := ns2.GetField(testNode(1), fields[1])
+ if !reflect.DeepEqual(field1, "hello world") {
+ t.Fatalf("Field changed")
+ }
+
+ s.Version++
+ ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+ ns3.Start()
+ if ns3.GetField(testNode(1), fields[0]) != nil {
+ t.Fatalf("Old field version should have been discarded")
+ }
+}
+
+func TestFieldSub(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))})
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ var (
+ lastState Flags
+ lastOldValue, lastNewValue interface{}
+ )
+ ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
+ lastState, lastOldValue, lastNewValue = state, oldValue, newValue
+ })
+ check := func(state Flags, oldValue, newValue interface{}) {
+ if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue {
+ t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue)
+ }
+ }
+ ns.Start()
+ ns.SetState(testNode(1), flags[0], Flags{}, 0)
+ ns.SetField(testNode(1), fields[0], uint64(100))
+ check(flags[0], nil, uint64(100))
+ ns.Stop()
+ check(s.OfflineFlag(), uint64(100), nil)
+
+ ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+ ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
+ lastState, lastOldValue, lastNewValue = state, oldValue, newValue
+ })
+ ns2.Start()
+ check(s.OfflineFlag(), nil, uint64(100))
+ ns2.SetState(testNode(1), Flags{}, flags[0], 0)
+ check(Flags{}, uint64(100), nil)
+ ns2.Stop()
+}
+
+func TestDuplicatedFlags(t *testing.T) {
+ mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+ s, flags, _ := testSetup([]bool{true}, nil)
+ ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+ type change struct{ old, new Flags }
+ set := make(chan change, 1)
+ ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
+ set <- change{oldState, newState}
+ })
+
+ ns.Start()
+ defer ns.Stop()
+
+ check := func(expectOld, expectNew Flags, expectChange bool) {
+ if expectChange {
+ select {
+ case c := <-set:
+ if !c.old.Equals(expectOld) {
+ t.Fatalf("Old state mismatch")
+ }
+ if !c.new.Equals(expectNew) {
+ t.Fatalf("New state mismatch")
+ }
+ case <-time.After(time.Second):
+ t.Fatalf("Timeout")
+ }
+ return
+ }
+ select {
+ case <-set:
+ t.Fatalf("Unexpected change")
+ case <-time.After(time.Millisecond * 100):
+ return
+ }
+ }
+ ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
+ check(Flags{}, flags[0], true)
+ ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s
+ check(Flags{}, flags[0], false)
+
+ clock.Run(2 * time.Second)
+ check(flags[0], Flags{}, true)
+}
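
The uint64 and string encoder helpers above generalize to structured persistent fields. A hedged sketch follows; the nodeBalance type and the "balance" field name are hypothetical, only the RLP encoding pattern mirrors the helpers above.

// Hypothetical struct-valued persistent field, encoded with RLP in the same
// style as uint64FieldEnc/uint64FieldDec.
type nodeBalance struct {
	Pos, Neg uint64
}

func balanceFieldEnc(field interface{}) ([]byte, error) {
	if b, ok := field.(nodeBalance); ok {
		return rlp.EncodeToBytes(&b)
	}
	return nil, errors.New("invalid field type")
}

func balanceFieldDec(enc []byte) (interface{}, error) {
	var b nodeBalance
	err := rlp.DecodeBytes(enc, &b)
	return b, err
}

// Registration inside a Setup would then be:
//   balanceField := setup.NewPersistentField("balance", reflect.TypeOf(nodeBalance{}), balanceFieldEnc, balanceFieldDec)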
diff --git a/params/bootnodes.go b/params/bootnodes.go
index 0d72321b0..e1898d762 100644
--- a/params/bootnodes.go
+++ b/params/bootnodes.go
@@ -65,11 +65,22 @@ var GoerliBootnodes = []string{
const dnsPrefix = "enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@"
-// These DNS names provide bootstrap connectivity for public testnets and the mainnet.
-// See https://github.com/ethereum/discv4-dns-lists for more information.
-var KnownDNSNetworks = map[common.Hash]string{
- MainnetGenesisHash: dnsPrefix + "all.mainnet.ethdisco.net",
- RopstenGenesisHash: dnsPrefix + "all.ropsten.ethdisco.net",
- RinkebyGenesisHash: dnsPrefix + "all.rinkeby.ethdisco.net",
- GoerliGenesisHash: dnsPrefix + "all.goerli.ethdisco.net",
+// KnownDNSNetwork returns the address of a public DNS-based node list for the given
+// genesis hash and protocol. See https://github.com/ethereum/discv4-dns-lists for more
+// information.
+func KnownDNSNetwork(genesis common.Hash, protocol string) string {
+ var net string
+ switch genesis {
+ case MainnetGenesisHash:
+ net = "mainnet"
+ case RopstenGenesisHash:
+ net = "ropsten"
+ case RinkebyGenesisHash:
+ net = "rinkeby"
+ case GoerliGenesisHash:
+ net = "goerli"
+ default:
+ return ""
+ }
+ return dnsPrefix + protocol + "." + net + ".ethdisco.net"
}
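
A worked example of the new helper, derived from the dnsPrefix constant and switch above (the call site is illustrative):

// For the mainnet genesis hash and the "les" protocol the helper composes:
url := params.KnownDNSNetwork(params.MainnetGenesisHash, "les")
// url == "enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@les.mainnet.ethdisco.net"
// An unrecognized genesis hash returns "", in which case no discovery URL is configured.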