diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 6f6233e4d..59a2c9f83 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -1563,19 +1563,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { cfg.NetworkId = 3 } cfg.Genesis = core.DefaultRopstenGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RopstenGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.RopstenGenesisHash) case ctx.GlobalBool(RinkebyFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 4 } cfg.Genesis = core.DefaultRinkebyGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RinkebyGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.RinkebyGenesisHash) case ctx.GlobalBool(GoerliFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 5 } cfg.Genesis = core.DefaultGoerliGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.GoerliGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.GoerliGenesisHash) case ctx.GlobalBool(DeveloperFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 1337 @@ -1604,18 +1604,25 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { } default: if cfg.NetworkId == 1 { - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.MainnetGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.MainnetGenesisHash) } } } // setDNSDiscoveryDefaults configures DNS discovery with the given URL if // no URLs are set. -func setDNSDiscoveryDefaults(cfg *eth.Config, url string) { +func setDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) { if cfg.DiscoveryURLs != nil { - return + return // already set through flags/config + } + + protocol := "eth" + if cfg.SyncMode == downloader.LightSync { + protocol = "les" + } + if url := params.KnownDNSNetwork(genesis, protocol); url != "" { + cfg.DiscoveryURLs = []string{url} } - cfg.DiscoveryURLs = []string{url} } // RegisterEthService adds an Ethereum client to the stack. diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go index e433db446..08a948510 100644 --- a/core/forkid/forkid.go +++ b/core/forkid/forkid.go @@ -27,7 +27,7 @@ import ( "strings" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" ) @@ -44,6 +44,18 @@ var ( ErrLocalIncompatibleOrStale = errors.New("local incompatible or needs update") ) +// Blockchain defines all necessary method to build a forkID. +type Blockchain interface { + // Config retrieves the chain's fork configuration. + Config() *params.ChainConfig + + // Genesis retrieves the chain's genesis block. + Genesis() *types.Block + + // CurrentHeader retrieves the current head header of the canonical chain. + CurrentHeader() *types.Header +} + // ID is a fork identifier as defined by EIP-2124. type ID struct { Hash [4]byte // CRC32 checksum of the genesis block and passed fork block numbers @@ -54,7 +66,7 @@ type ID struct { type Filter func(id ID) error // NewID calculates the Ethereum fork ID from the chain config and head. -func NewID(chain *core.BlockChain) ID { +func NewID(chain Blockchain) ID { return newID( chain.Config(), chain.Genesis().Hash(), @@ -85,7 +97,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { // NewFilter creates a filter that returns if a fork ID should be rejected or not // based on the local chain's status. -func NewFilter(chain *core.BlockChain) Filter { +func NewFilter(chain Blockchain) Filter { return newFilter( chain.Config(), chain.Genesis().Hash(), diff --git a/eth/backend.go b/eth/backend.go index 391e3c0e6..e4f98360d 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -72,7 +72,7 @@ type Ethereum struct { blockchain *core.BlockChain protocolManager *ProtocolManager lesServer LesServer - dialCandiates enode.Iterator + dialCandidates enode.Iterator // DB interfaces chainDb ethdb.Database // Block chain database @@ -226,7 +226,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { } eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) - eth.dialCandiates, err = eth.setupDiscovery(&ctx.Config.P2P) + eth.dialCandidates, err = eth.setupDiscovery(&ctx.Config.P2P) if err != nil { return nil, err } @@ -523,7 +523,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol { for i, vsn := range ProtocolVersions { protos[i] = s.protocolManager.makeProtocol(vsn) protos[i].Attributes = []enr.Entry{s.currentEthEntry()} - protos[i].DialCandidates = s.dialCandiates + protos[i].DialCandidates = s.dialCandidates } if s.lesServer != nil { protos = append(protos, s.lesServer.Protocols()...) diff --git a/les/client.go b/les/client.go index 2ac3285d8..34a654e22 100644 --- a/les/client.go +++ b/les/client.go @@ -51,16 +51,17 @@ import ( type LightEthereum struct { lesCommons - peers *serverPeerSet - reqDist *requestDistributor - retriever *retrieveManager - odr *LesOdr - relay *lesTxRelay - handler *clientHandler - txPool *light.TxPool - blockchain *light.LightChain - serverPool *serverPool - valueTracker *lpc.ValueTracker + peers *serverPeerSet + reqDist *requestDistributor + retriever *retrieveManager + odr *LesOdr + relay *lesTxRelay + handler *clientHandler + txPool *light.TxPool + blockchain *light.LightChain + serverPool *serverPool + valueTracker *lpc.ValueTracker + dialCandidates enode.Iterator bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports @@ -104,11 +105,19 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb), bloomRequests: make(chan chan *bloombits.Retrieval), bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), - serverPool: newServerPool(chainDb, config.UltraLightServers), valueTracker: lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)), } peers.subscribe((*vtSubscription)(leth.valueTracker)) - leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) + + dnsdisc, err := leth.setupDiscovery(&ctx.Config.P2P) + if err != nil { + return nil, err + } + leth.serverPool = newServerPool(lespayDb, []byte("serverpool:"), leth.valueTracker, dnsdisc, time.Second, nil, &mclock.System{}, config.UltraLightServers) + peers.subscribe(leth.serverPool) + leth.dialCandidates = leth.serverPool.dialIterator + + leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.getTimeout) leth.relay = newLesTxRelay(peers, leth.retriever) leth.odr = NewLesOdr(chainDb, light.DefaultClientIndexerConfig, leth.retriever) @@ -140,11 +149,6 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { leth.chtIndexer.Start(leth.blockchain) leth.bloomIndexer.Start(leth.blockchain) - leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth) - if leth.handler.ulc != nil { - log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction) - leth.blockchain.DisableCheckFreq() - } // Rewind the chain in case of an incompatible config upgrade. if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) @@ -159,6 +163,11 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { } leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) + leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth) + if leth.handler.ulc != nil { + log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction) + leth.blockchain.DisableCheckFreq() + } return leth, nil } @@ -260,7 +269,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { return p.Info() } return nil - }) + }, s.dialCandidates) } // Start implements node.Service, starting all internal goroutines needed by the @@ -268,15 +277,12 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { func (s *LightEthereum) Start(srvr *p2p.Server) error { log.Warn("Light client mode is an experimental feature") + s.serverPool.start() // Start bloom request workers. s.wg.Add(bloomServiceThreads) s.startBloomHandlers(params.BloomBitsBlocksClient) s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId) - - // clients are searching for the first advertised protocol in the list - protocolVersion := AdvertiseProtocolVersions[0] - s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) return nil } @@ -284,6 +290,8 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error { // Ethereum protocol. func (s *LightEthereum) Stop() error { close(s.closeCh) + s.serverPool.stop() + s.valueTracker.Stop() s.peers.close() s.reqDist.close() s.odr.Stop() @@ -295,8 +303,6 @@ func (s *LightEthereum) Stop() error { s.txPool.Stop() s.engine.Close() s.eventMux.Stop() - s.serverPool.stop() - s.valueTracker.Stop() s.chainDb.Close() s.wg.Wait() log.Info("Light ethereum stopped") diff --git a/les/client_handler.go b/les/client_handler.go index 6367fdb6b..abe472e46 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -64,7 +64,7 @@ func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.T if checkpoint != nil { height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1 } - handler.fetcher = newLightFetcher(handler) + handler.fetcher = newLightFetcher(handler, backend.serverPool.getTimeout) handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer) handler.backend.peers.subscribe((*downloaderPeerNotify)(handler)) return handler @@ -85,14 +85,9 @@ func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) } peer := newServerPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version))) defer peer.close() - peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node()) - if peer.poolEntry == nil { - return p2p.DiscRequested - } h.wg.Add(1) defer h.wg.Done() err := h.handle(peer) - h.backend.serverPool.disconnect(peer.poolEntry) return err } @@ -129,10 +124,6 @@ func (h *clientHandler) handle(p *serverPeer) error { h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td}) - // pool entry can be nil during the unit test. - if p.poolEntry != nil { - h.backend.serverPool.registered(p.poolEntry) - } // Mark the peer starts to be served. atomic.StoreUint32(&p.serving, 1) defer atomic.StoreUint32(&p.serving, 0) diff --git a/les/commons.go b/les/commons.go index 29b5b7660..cd8a22834 100644 --- a/les/commons.go +++ b/les/commons.go @@ -81,7 +81,7 @@ type NodeInfo struct { } // makeProtocols creates protocol descriptors for the given LES versions. -func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol { +func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}, dialCandidates enode.Iterator) []p2p.Protocol { protos := make([]p2p.Protocol, len(versions)) for i, version := range versions { version := version @@ -93,7 +93,8 @@ func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error { return runPeer(version, peer, rw) }, - PeerInfo: peerInfo, + PeerInfo: peerInfo, + DialCandidates: dialCandidates, } } return protos diff --git a/les/distributor.go b/les/distributor.go index 97d2ccdfe..31150e4d7 100644 --- a/les/distributor.go +++ b/les/distributor.go @@ -180,12 +180,11 @@ func (d *requestDistributor) loop() { type selectPeerItem struct { peer distPeer req *distReq - weight int64 + weight uint64 } -// Weight implements wrsItem interface -func (sp selectPeerItem) Weight() int64 { - return sp.weight +func selectPeerWeight(i interface{}) uint64 { + return i.(selectPeerItem).weight } // nextRequest returns the next possible request from any peer, along with the @@ -220,9 +219,9 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) { wait, bufRemain := peer.waitBefore(cost) if wait == 0 { if sel == nil { - sel = utils.NewWeightedRandomSelect() + sel = utils.NewWeightedRandomSelect(selectPeerWeight) } - sel.Update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1}) + sel.Update(selectPeerItem{peer: peer, req: req, weight: uint64(bufRemain*1000000) + 1}) } else { if bestWait == 0 || wait < bestWait { bestWait = wait diff --git a/les/enr_entry.go b/les/enr_entry.go index c2a92dd99..65d0d1fdb 100644 --- a/les/enr_entry.go +++ b/les/enr_entry.go @@ -17,6 +17,9 @@ package les import ( + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/dnsdisc" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" ) @@ -30,3 +33,12 @@ type lesEntry struct { func (e lesEntry) ENRKey() string { return "les" } + +// setupDiscovery creates the node discovery source for the eth protocol. +func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) { + if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 { + return nil, nil + } + client := dnsdisc.NewClient(dnsdisc.Config{}) + return client.NewIterator(eth.config.DiscoveryURLs...) +} diff --git a/les/fetcher.go b/les/fetcher.go index 7fa81cbcb..aaf74aaa1 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -40,8 +40,9 @@ const ( // ODR system to ensure that we only request data related to a certain block from peers who have already processed // and announced that block. type lightFetcher struct { - handler *clientHandler - chain *light.LightChain + handler *clientHandler + chain *light.LightChain + softRequestTimeout func() time.Duration lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests maxConfirmedTd *big.Int @@ -109,18 +110,19 @@ type fetchResponse struct { } // newLightFetcher creates a new light fetcher -func newLightFetcher(h *clientHandler) *lightFetcher { +func newLightFetcher(h *clientHandler, softRequestTimeout func() time.Duration) *lightFetcher { f := &lightFetcher{ - handler: h, - chain: h.backend.blockchain, - peers: make(map[*serverPeer]*fetcherPeerInfo), - deliverChn: make(chan fetchResponse, 100), - requested: make(map[uint64]fetchRequest), - timeoutChn: make(chan uint64), - requestTrigger: make(chan struct{}, 1), - syncDone: make(chan *serverPeer), - closeCh: make(chan struct{}), - maxConfirmedTd: big.NewInt(0), + handler: h, + chain: h.backend.blockchain, + peers: make(map[*serverPeer]*fetcherPeerInfo), + deliverChn: make(chan fetchResponse, 100), + requested: make(map[uint64]fetchRequest), + timeoutChn: make(chan uint64), + requestTrigger: make(chan struct{}, 1), + syncDone: make(chan *serverPeer), + closeCh: make(chan struct{}), + maxConfirmedTd: big.NewInt(0), + softRequestTimeout: softRequestTimeout, } h.backend.peers.subscribe(f) @@ -163,7 +165,7 @@ func (f *lightFetcher) syncLoop() { f.lock.Unlock() } else { go func() { - time.Sleep(softRequestTimeout) + time.Sleep(f.softRequestTimeout()) f.reqMu.Lock() req, ok := f.requested[reqID] if ok { @@ -187,7 +189,6 @@ func (f *lightFetcher) syncLoop() { } f.reqMu.Unlock() if ok { - f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true) req.peer.Log().Debug("Fetching data timed out hard") go f.handler.removePeer(req.peer.id) } @@ -201,9 +202,6 @@ func (f *lightFetcher) syncLoop() { delete(f.requested, resp.reqID) } f.reqMu.Unlock() - if ok { - f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout) - } f.lock.Lock() if !ok || !(f.syncing || f.processResponse(req, resp)) { resp.peer.Log().Debug("Failed processing response") @@ -879,12 +877,10 @@ func (f *lightFetcher) checkUpdateStats(p *serverPeer, newEntry *updateStatsEntr fp.firstUpdateStats = newEntry } for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) { - f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout) fp.firstUpdateStats = fp.firstUpdateStats.next } if fp.confirmedTd != nil { for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 { - f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time)) fp.firstUpdateStats = fp.firstUpdateStats.next } } diff --git a/les/lespay/client/fillset.go b/les/lespay/client/fillset.go new file mode 100644 index 000000000..0da850bca --- /dev/null +++ b/les/lespay/client/fillset.go @@ -0,0 +1,107 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "sync" + + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +// FillSet tries to read nodes from an input iterator and add them to a node set by +// setting the specified node state flag(s) until the size of the set reaches the target. +// Note that other mechanisms (like other FillSet instances reading from different inputs) +// can also set the same flag(s) and FillSet will always care about the total number of +// nodes having those flags. +type FillSet struct { + lock sync.Mutex + cond *sync.Cond + ns *nodestate.NodeStateMachine + input enode.Iterator + closed bool + flags nodestate.Flags + count, target int +} + +// NewFillSet creates a new FillSet +func NewFillSet(ns *nodestate.NodeStateMachine, input enode.Iterator, flags nodestate.Flags) *FillSet { + fs := &FillSet{ + ns: ns, + input: input, + flags: flags, + } + fs.cond = sync.NewCond(&fs.lock) + + ns.SubscribeState(flags, func(n *enode.Node, oldState, newState nodestate.Flags) { + fs.lock.Lock() + if oldState.Equals(flags) { + fs.count-- + } + if newState.Equals(flags) { + fs.count++ + } + if fs.target > fs.count { + fs.cond.Signal() + } + fs.lock.Unlock() + }) + + go fs.readLoop() + return fs +} + +// readLoop keeps reading nodes from the input and setting the specified flags for them +// whenever the node set size is under the current target +func (fs *FillSet) readLoop() { + for { + fs.lock.Lock() + for fs.target <= fs.count && !fs.closed { + fs.cond.Wait() + } + + fs.lock.Unlock() + if !fs.input.Next() { + return + } + fs.ns.SetState(fs.input.Node(), fs.flags, nodestate.Flags{}, 0) + } +} + +// SetTarget sets the current target for node set size. If the previous target was not +// reached and FillSet was still waiting for the next node from the input then the next +// incoming node will be added to the set regardless of the target. This ensures that +// all nodes coming from the input are eventually added to the set. +func (fs *FillSet) SetTarget(target int) { + fs.lock.Lock() + defer fs.lock.Unlock() + + fs.target = target + if fs.target > fs.count { + fs.cond.Signal() + } +} + +// Close shuts FillSet down and closes the input iterator +func (fs *FillSet) Close() { + fs.lock.Lock() + defer fs.lock.Unlock() + + fs.closed = true + fs.input.Close() + fs.cond.Signal() +} diff --git a/les/lespay/client/fillset_test.go b/les/lespay/client/fillset_test.go new file mode 100644 index 000000000..58240682c --- /dev/null +++ b/les/lespay/client/fillset_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +type testIter struct { + waitCh chan struct{} + nodeCh chan *enode.Node + node *enode.Node +} + +func (i *testIter) Next() bool { + i.waitCh <- struct{}{} + i.node = <-i.nodeCh + return i.node != nil +} + +func (i *testIter) Node() *enode.Node { + return i.node +} + +func (i *testIter) Close() {} + +func (i *testIter) push() { + var id enode.ID + rand.Read(id[:]) + i.nodeCh <- enode.SignNull(new(enr.Record), id) +} + +func (i *testIter) waiting(timeout time.Duration) bool { + select { + case <-i.waitCh: + return true + case <-time.After(timeout): + return false + } +} + +func TestFillSet(t *testing.T) { + ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup) + iter := &testIter{ + waitCh: make(chan struct{}), + nodeCh: make(chan *enode.Node), + } + fs := NewFillSet(ns, iter, sfTest1) + ns.Start() + + expWaiting := func(i int, push bool) { + for ; i > 0; i-- { + if !iter.waiting(time.Second * 10) { + t.Fatalf("FillSet not waiting for new nodes") + } + if push { + iter.push() + } + } + } + + expNotWaiting := func() { + if iter.waiting(time.Millisecond * 100) { + t.Fatalf("FillSet unexpectedly waiting for new nodes") + } + } + + expNotWaiting() + fs.SetTarget(3) + expWaiting(3, true) + expNotWaiting() + fs.SetTarget(100) + expWaiting(2, true) + expWaiting(1, false) + // lower the target before the previous one has been filled up + fs.SetTarget(0) + iter.push() + expNotWaiting() + fs.SetTarget(10) + expWaiting(4, true) + expNotWaiting() + // remove all previosly set flags + ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { + ns.SetState(node, nodestate.Flags{}, sfTest1, 0) + }) + // now expect FillSet to fill the set up again with 10 new nodes + expWaiting(10, true) + expNotWaiting() + + fs.Close() + ns.Stop() +} diff --git a/les/lespay/client/queueiterator.go b/les/lespay/client/queueiterator.go new file mode 100644 index 000000000..ad3f8df5b --- /dev/null +++ b/les/lespay/client/queueiterator.go @@ -0,0 +1,123 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "sync" + + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +// QueueIterator returns nodes from the specified selectable set in the same order as +// they entered the set. +type QueueIterator struct { + lock sync.Mutex + cond *sync.Cond + + ns *nodestate.NodeStateMachine + queue []*enode.Node + nextNode *enode.Node + waitCallback func(bool) + fifo, closed bool +} + +// NewQueueIterator creates a new QueueIterator. Nodes are selectable if they have all the required +// and none of the disabled flags set. When a node is selected the selectedFlag is set which also +// disables further selectability until it is removed or times out. +func NewQueueIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, fifo bool, waitCallback func(bool)) *QueueIterator { + qi := &QueueIterator{ + ns: ns, + fifo: fifo, + waitCallback: waitCallback, + } + qi.cond = sync.NewCond(&qi.lock) + + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) { + oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) + newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) + if newMatch == oldMatch { + return + } + + qi.lock.Lock() + defer qi.lock.Unlock() + + if newMatch { + qi.queue = append(qi.queue, n) + } else { + id := n.ID() + for i, qn := range qi.queue { + if qn.ID() == id { + copy(qi.queue[i:len(qi.queue)-1], qi.queue[i+1:]) + qi.queue = qi.queue[:len(qi.queue)-1] + break + } + } + } + qi.cond.Signal() + }) + return qi +} + +// Next moves to the next selectable node. +func (qi *QueueIterator) Next() bool { + qi.lock.Lock() + if !qi.closed && len(qi.queue) == 0 { + if qi.waitCallback != nil { + qi.waitCallback(true) + } + for !qi.closed && len(qi.queue) == 0 { + qi.cond.Wait() + } + if qi.waitCallback != nil { + qi.waitCallback(false) + } + } + if qi.closed { + qi.nextNode = nil + qi.lock.Unlock() + return false + } + // Move to the next node in queue. + if qi.fifo { + qi.nextNode = qi.queue[0] + copy(qi.queue[:len(qi.queue)-1], qi.queue[1:]) + qi.queue = qi.queue[:len(qi.queue)-1] + } else { + qi.nextNode = qi.queue[len(qi.queue)-1] + qi.queue = qi.queue[:len(qi.queue)-1] + } + qi.lock.Unlock() + return true +} + +// Close ends the iterator. +func (qi *QueueIterator) Close() { + qi.lock.Lock() + qi.closed = true + qi.lock.Unlock() + qi.cond.Signal() +} + +// Node returns the current node. +func (qi *QueueIterator) Node() *enode.Node { + qi.lock.Lock() + defer qi.lock.Unlock() + + return qi.nextNode +} diff --git a/les/lespay/client/queueiterator_test.go b/les/lespay/client/queueiterator_test.go new file mode 100644 index 000000000..a74301c7d --- /dev/null +++ b/les/lespay/client/queueiterator_test.go @@ -0,0 +1,106 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +func testNodeID(i int) enode.ID { + return enode.ID{42, byte(i % 256), byte(i / 256)} +} + +func testNodeIndex(id enode.ID) int { + if id[0] != 42 { + return -1 + } + return int(id[1]) + int(id[2])*256 +} + +func testNode(i int) *enode.Node { + return enode.SignNull(new(enr.Record), testNodeID(i)) +} + +func TestQueueIteratorFIFO(t *testing.T) { + testQueueIterator(t, true) +} + +func TestQueueIteratorLIFO(t *testing.T) { + testQueueIterator(t, false) +} + +func testQueueIterator(t *testing.T, fifo bool) { + ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup) + qi := NewQueueIterator(ns, sfTest2, sfTest3.Or(sfTest4), fifo, nil) + ns.Start() + for i := 1; i <= iterTestNodeCount; i++ { + ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0) + } + next := func() int { + ch := make(chan struct{}) + go func() { + qi.Next() + close(ch) + }() + select { + case <-ch: + case <-time.After(time.Second * 5): + t.Fatalf("Iterator.Next() timeout") + } + node := qi.Node() + ns.SetState(node, sfTest4, nodestate.Flags{}, 0) + return testNodeIndex(node.ID()) + } + exp := func(i int) { + n := next() + if n != i { + t.Errorf("Wrong item returned by iterator (expected %d, got %d)", i, n) + } + } + explist := func(list []int) { + for i := range list { + if fifo { + exp(list[i]) + } else { + exp(list[len(list)-1-i]) + } + } + } + + ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0) + explist([]int{1, 2, 3}) + ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(5), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(5), sfTest3, nodestate.Flags{}, 0) + explist([]int{4, 6}) + ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), sfTest3, nodestate.Flags{}, 0) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest3, 0) + explist([]int{1, 3, 2}) + ns.Stop() +} diff --git a/les/lespay/client/valuetracker.go b/les/lespay/client/valuetracker.go index 92bfd694e..4e67b31d9 100644 --- a/les/lespay/client/valuetracker.go +++ b/les/lespay/client/valuetracker.go @@ -213,6 +213,15 @@ func (vt *ValueTracker) StatsExpirer() *utils.Expirer { return &vt.statsExpirer } +// StatsExpirer returns the current expiration factor so that other values can be expired +// with the same rate as the service value statistics. +func (vt *ValueTracker) StatsExpFactor() utils.ExpirationFactor { + vt.statsExpLock.RLock() + defer vt.statsExpLock.RUnlock() + + return vt.statsExpFactor +} + // loadFromDb loads the value tracker's state from the database and converts saved // request basket index mapping if it does not match the specified index to name mapping. func (vt *ValueTracker) loadFromDb(mapping []string) error { @@ -500,16 +509,3 @@ func (vt *ValueTracker) RequestStats() []RequestStatsItem { } return res } - -// TotalServiceValue returns the total service value provided by the given node (as -// a function of the weights which are calculated from the request timeout value). -func (vt *ValueTracker) TotalServiceValue(nv *NodeValueTracker, weights ResponseTimeWeights) float64 { - vt.statsExpLock.RLock() - expFactor := vt.statsExpFactor - vt.statsExpLock.RUnlock() - - nv.lock.Lock() - defer nv.lock.Unlock() - - return nv.rtStats.Value(weights, expFactor) -} diff --git a/les/lespay/client/wrsiterator.go b/les/lespay/client/wrsiterator.go new file mode 100644 index 000000000..8a2e39ad4 --- /dev/null +++ b/les/lespay/client/wrsiterator.go @@ -0,0 +1,128 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "sync" + + "github.com/ethereum/go-ethereum/les/utils" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +// WrsIterator returns nodes from the specified selectable set with a weighted random +// selection. Selection weights are provided by a callback function. +type WrsIterator struct { + lock sync.Mutex + cond *sync.Cond + + ns *nodestate.NodeStateMachine + wrs *utils.WeightedRandomSelect + nextNode *enode.Node + closed bool +} + +// NewWrsIterator creates a new WrsIterator. Nodes are selectable if they have all the required +// and none of the disabled flags set. When a node is selected the selectedFlag is set which also +// disables further selectability until it is removed or times out. +func NewWrsIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, weightField nodestate.Field) *WrsIterator { + wfn := func(i interface{}) uint64 { + n := ns.GetNode(i.(enode.ID)) + if n == nil { + return 0 + } + wt, _ := ns.GetField(n, weightField).(uint64) + return wt + } + + w := &WrsIterator{ + ns: ns, + wrs: utils.NewWeightedRandomSelect(wfn), + } + w.cond = sync.NewCond(&w.lock) + + ns.SubscribeField(weightField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) { + if state.HasAll(requireFlags) && state.HasNone(disableFlags) { + w.lock.Lock() + w.wrs.Update(n.ID()) + w.lock.Unlock() + w.cond.Signal() + } + }) + + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) { + oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) + newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) + if newMatch == oldMatch { + return + } + + w.lock.Lock() + if newMatch { + w.wrs.Update(n.ID()) + } else { + w.wrs.Remove(n.ID()) + } + w.lock.Unlock() + w.cond.Signal() + }) + return w +} + +// Next selects the next node. +func (w *WrsIterator) Next() bool { + w.nextNode = w.chooseNode() + return w.nextNode != nil +} + +func (w *WrsIterator) chooseNode() *enode.Node { + w.lock.Lock() + defer w.lock.Unlock() + + for { + for !w.closed && w.wrs.IsEmpty() { + w.cond.Wait() + } + if w.closed { + return nil + } + // Choose the next node at random. Even though w.wrs is guaranteed + // non-empty here, Choose might return nil if all items have weight + // zero. + if c := w.wrs.Choose(); c != nil { + id := c.(enode.ID) + w.wrs.Remove(id) + return w.ns.GetNode(id) + } + } + +} + +// Close ends the iterator. +func (w *WrsIterator) Close() { + w.lock.Lock() + w.closed = true + w.lock.Unlock() + w.cond.Signal() +} + +// Node returns the current node. +func (w *WrsIterator) Node() *enode.Node { + w.lock.Lock() + defer w.lock.Unlock() + return w.nextNode +} diff --git a/les/lespay/client/wrsiterator_test.go b/les/lespay/client/wrsiterator_test.go new file mode 100644 index 000000000..77bb5ee0c --- /dev/null +++ b/les/lespay/client/wrsiterator_test.go @@ -0,0 +1,103 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +var ( + testSetup = &nodestate.Setup{} + sfTest1 = testSetup.NewFlag("test1") + sfTest2 = testSetup.NewFlag("test2") + sfTest3 = testSetup.NewFlag("test3") + sfTest4 = testSetup.NewFlag("test4") + sfiTestWeight = testSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0))) +) + +const iterTestNodeCount = 6 + +func TestWrsIterator(t *testing.T) { + ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup) + w := NewWrsIterator(ns, sfTest2, sfTest3.Or(sfTest4), sfiTestWeight) + ns.Start() + for i := 1; i <= iterTestNodeCount; i++ { + ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0) + ns.SetField(testNode(i), sfiTestWeight, uint64(1)) + } + next := func() int { + ch := make(chan struct{}) + go func() { + w.Next() + close(ch) + }() + select { + case <-ch: + case <-time.After(time.Second * 5): + t.Fatalf("Iterator.Next() timeout") + } + node := w.Node() + ns.SetState(node, sfTest4, nodestate.Flags{}, 0) + return testNodeIndex(node.ID()) + } + set := make(map[int]bool) + expset := func() { + for len(set) > 0 { + n := next() + if !set[n] { + t.Errorf("Item returned by iterator not in the expected set (got %d)", n) + } + delete(set, n) + } + } + + ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0) + set[1] = true + set[2] = true + set[3] = true + expset() + ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(5), sfTest2.Or(sfTest3), nodestate.Flags{}, 0) + ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0) + set[4] = true + set[6] = true + expset() + ns.SetField(testNode(2), sfiTestWeight, uint64(0)) + ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0) + set[1] = true + set[3] = true + expset() + ns.SetField(testNode(2), sfiTestWeight, uint64(1)) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest2, 0) + ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), sfTest2, sfTest4, 0) + ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0) + set[1] = true + set[2] = true + set[3] = true + expset() + ns.Stop() +} diff --git a/les/metrics.go b/les/metrics.go index 9ef8c3651..c5edb61c3 100644 --- a/les/metrics.go +++ b/les/metrics.go @@ -107,6 +107,13 @@ var ( requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil) requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil) + + serverSelectableGauge = metrics.NewRegisteredGauge("les/client/serverPool/selectable", nil) + serverDialedMeter = metrics.NewRegisteredMeter("les/client/serverPool/dialed", nil) + serverConnectedGauge = metrics.NewRegisteredGauge("les/client/serverPool/connected", nil) + sessionValueMeter = metrics.NewRegisteredMeter("les/client/serverPool/sessionValue", nil) + totalValueGauge = metrics.NewRegisteredGauge("les/client/serverPool/totalValue", nil) + suggestedTimeoutGauge = metrics.NewRegisteredGauge("les/client/serverPool/timeout", nil) ) // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of diff --git a/les/peer.go b/les/peer.go index e92b4580d..4793d9026 100644 --- a/les/peer.go +++ b/les/peer.go @@ -336,7 +336,6 @@ type serverPeer struct { checkpointNumber uint64 // The block height which the checkpoint is registered. checkpoint params.TrustedCheckpoint // The advertised checkpoint sent by server. - poolEntry *poolEntry // Statistic for server peer. fcServer *flowcontrol.ServerNode // Client side mirror token bucket. vtLock sync.Mutex valueTracker *lpc.ValueTracker diff --git a/les/protocol.go b/les/protocol.go index f8ad94a7b..4fd19f9be 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -130,7 +130,6 @@ func init() { } requestMapping[uint32(code)] = rm } - } type errCode int diff --git a/les/retrieve.go b/les/retrieve.go index 5fa68b745..4f77004f2 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -24,22 +24,20 @@ import ( "sync" "time" - "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/light" ) var ( retryQueue = time.Millisecond * 100 - softRequestTimeout = time.Millisecond * 500 hardRequestTimeout = time.Second * 10 ) // retrieveManager is a layer on top of requestDistributor which takes care of // matching replies by request ID and handles timeouts and resends if necessary. type retrieveManager struct { - dist *requestDistributor - peers *serverPeerSet - serverPool peerSelector + dist *requestDistributor + peers *serverPeerSet + softRequestTimeout func() time.Duration lock sync.RWMutex sentReqs map[uint64]*sentReq @@ -48,11 +46,6 @@ type retrieveManager struct { // validatorFunc is a function that processes a reply message type validatorFunc func(distPeer, *Msg) error -// peerSelector receives feedback info about response times and timeouts -type peerSelector interface { - adjustResponseTime(*poolEntry, time.Duration, bool) -} - // sentReq represents a request sent and tracked by retrieveManager type sentReq struct { rm *retrieveManager @@ -99,12 +92,12 @@ const ( ) // newRetrieveManager creates the retrieve manager -func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, serverPool peerSelector) *retrieveManager { +func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, srto func() time.Duration) *retrieveManager { return &retrieveManager{ - peers: peers, - dist: dist, - serverPool: serverPool, - sentReqs: make(map[uint64]*sentReq), + peers: peers, + dist: dist, + sentReqs: make(map[uint64]*sentReq), + softRequestTimeout: srto, } } @@ -325,8 +318,7 @@ func (r *sentReq) tryRequest() { return } - reqSent := mclock.Now() - srto, hrto := false, false + hrto := false r.lock.RLock() s, ok := r.sentTo[p] @@ -338,11 +330,7 @@ func (r *sentReq) tryRequest() { defer func() { // send feedback to server pool and remove peer if hard timeout happened pp, ok := p.(*serverPeer) - if ok && r.rm.serverPool != nil { - respTime := time.Duration(mclock.Now() - reqSent) - r.rm.serverPool.adjustResponseTime(pp.poolEntry, respTime, srto) - } - if hrto { + if hrto && ok { pp.Log().Debug("Request timed out hard") if r.rm.peers != nil { r.rm.peers.unregister(pp.id) @@ -363,8 +351,7 @@ func (r *sentReq) tryRequest() { } r.eventsCh <- reqPeerEvent{event, p} return - case <-time.After(softRequestTimeout): - srto = true + case <-time.After(r.rm.softRequestTimeout()): r.eventsCh <- reqPeerEvent{rpSoftTimeout, p} } diff --git a/les/server.go b/les/server.go index f72f31321..4b623f61e 100644 --- a/les/server.go +++ b/les/server.go @@ -157,7 +157,7 @@ func (s *LesServer) Protocols() []p2p.Protocol { return p.Info() } return nil - }) + }, nil) // Add "les" ENR entries. for i := range ps { ps[i].Attributes = []enr.Entry{&lesEntry{}} diff --git a/les/serverpool.go b/les/serverpool.go index d1c53295a..aff774324 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -1,4 +1,4 @@ -// Copyright 2016 The go-ethereum Authors +// Copyright 2020 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify @@ -17,904 +17,457 @@ package les import ( - "crypto/ecdsa" - "fmt" - "io" - "math" + "errors" "math/rand" - "net" - "strconv" + "reflect" "sync" + "sync/atomic" "time" "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" + lpc "github.com/ethereum/go-ethereum/les/lespay/client" "github.com/ethereum/go-ethereum/les/utils" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nodestate" "github.com/ethereum/go-ethereum/rlp" ) const ( - // After a connection has been ended or timed out, there is a waiting period - // before it can be selected for connection again. - // waiting period = base delay * (1 + random(1)) - // base delay = shortRetryDelay for the first shortRetryCnt times after a - // successful connection, after that longRetryDelay is applied - shortRetryCnt = 5 - shortRetryDelay = time.Second * 5 - longRetryDelay = time.Minute * 10 - // maxNewEntries is the maximum number of newly discovered (never connected) nodes. - // If the limit is reached, the least recently discovered one is thrown out. - maxNewEntries = 1000 - // maxKnownEntries is the maximum number of known (already connected) nodes. - // If the limit is reached, the least recently connected one is thrown out. - // (not that unlike new entries, known entries are persistent) - maxKnownEntries = 1000 - // target for simultaneously connected servers - targetServerCount = 5 - // target for servers selected from the known table - // (we leave room for trying new ones if there is any) - targetKnownSelect = 3 - // after dialTimeout, consider the server unavailable and adjust statistics - dialTimeout = time.Second * 30 - // targetConnTime is the minimum expected connection duration before a server - // drops a client without any specific reason - targetConnTime = time.Minute * 10 - // new entry selection weight calculation based on most recent discovery time: - // unity until discoverExpireStart, then exponential decay with discoverExpireConst - discoverExpireStart = time.Minute * 20 - discoverExpireConst = time.Minute * 20 - // known entry selection weight is dropped by a factor of exp(-failDropLn) after - // each unsuccessful connection (restored after a successful one) - failDropLn = 0.1 - // known node connection success and quality statistics have a long term average - // and a short term value which is adjusted exponentially with a factor of - // pstatRecentAdjust with each dial/connection and also returned exponentially - // to the average with the time constant pstatReturnToMeanTC - pstatReturnToMeanTC = time.Hour - // node address selection weight is dropped by a factor of exp(-addrFailDropLn) after - // each unsuccessful connection (restored after a successful one) - addrFailDropLn = math.Ln2 - // responseScoreTC and delayScoreTC are exponential decay time constants for - // calculating selection chances from response times and block delay times - responseScoreTC = time.Millisecond * 100 - delayScoreTC = time.Second * 5 - timeoutPow = 10 - // initStatsWeight is used to initialize previously unknown peers with good - // statistics to give a chance to prove themselves - initStatsWeight = 1 + minTimeout = time.Millisecond * 500 // minimum request timeout suggested by the server pool + timeoutRefresh = time.Second * 5 // recalculate timeout if older than this + dialCost = 10000 // cost of a TCP dial (used for known node selection weight calculation) + dialWaitStep = 1.5 // exponential multiplier of redial wait time when no value was provided by the server + queryCost = 500 // cost of a UDP pre-negotiation query + queryWaitStep = 1.02 // exponential multiplier of redial wait time when no value was provided by the server + waitThreshold = time.Hour * 2000 // drop node if waiting time is over the threshold + nodeWeightMul = 1000000 // multiplier constant for node weight calculation + nodeWeightThreshold = 100 // minimum weight for keeping a node in the the known (valuable) set + minRedialWait = 10 // minimum redial wait time in seconds + preNegLimit = 5 // maximum number of simultaneous pre-negotiation queries + maxQueryFails = 100 // number of consecutive UDP query failures before we print a warning ) -// connReq represents a request for peer connection. -type connReq struct { - p *serverPeer - node *enode.Node - result chan *poolEntry -} - -// disconnReq represents a request for peer disconnection. -type disconnReq struct { - entry *poolEntry - stopped bool - done chan struct{} -} - -// registerReq represents a request for peer registration. -type registerReq struct { - entry *poolEntry - done chan struct{} -} - -// serverPool implements a pool for storing and selecting newly discovered and already -// known light server nodes. It received discovered nodes, stores statistics about -// known nodes and takes care of always having enough good quality servers connected. +// serverPool provides a node iterator for dial candidates. The output is a mix of newly discovered +// nodes, a weighted random selection of known (previously valuable) nodes and trusted/paid nodes. type serverPool struct { - db ethdb.Database - dbKey []byte - server *p2p.Server - connWg sync.WaitGroup + clock mclock.Clock + unixTime func() int64 + db ethdb.KeyValueStore - topic discv5.Topic + ns *nodestate.NodeStateMachine + vt *lpc.ValueTracker + mixer *enode.FairMix + mixSources []enode.Iterator + dialIterator enode.Iterator + validSchemes enr.IdentityScheme + trustedURLs []string + fillSet *lpc.FillSet + queryFails uint32 - discSetPeriod chan time.Duration - discNodes chan *enode.Node - discLookups chan bool - - trustedNodes map[enode.ID]*enode.Node - entries map[enode.ID]*poolEntry - timeout, enableRetry chan *poolEntry - adjustStats chan poolStatAdjust - - knownQueue, newQueue poolEntryQueue - knownSelect, newSelect *utils.WeightedRandomSelect - knownSelected, newSelected int - fastDiscover bool - connCh chan *connReq - disconnCh chan *disconnReq - registerCh chan *registerReq - - closeCh chan struct{} - wg sync.WaitGroup + timeoutLock sync.RWMutex + timeout time.Duration + timeWeights lpc.ResponseTimeWeights + timeoutRefreshed mclock.AbsTime } -// newServerPool creates a new serverPool instance -func newServerPool(db ethdb.Database, ulcServers []string) *serverPool { - pool := &serverPool{ +// nodeHistory keeps track of dial costs which determine node weight together with the +// service value calculated by lpc.ValueTracker. +type nodeHistory struct { + dialCost utils.ExpiredValue + redialWaitStart, redialWaitEnd int64 // unix time (seconds) +} + +type nodeHistoryEnc struct { + DialCost utils.ExpiredValue + RedialWaitStart, RedialWaitEnd uint64 +} + +// queryFunc sends a pre-negotiation query and blocks until a response arrives or timeout occurs. +// It returns 1 if the remote node has confirmed that connection is possible, 0 if not +// possible and -1 if no response arrived (timeout). +type queryFunc func(*enode.Node) int + +var ( + serverPoolSetup = &nodestate.Setup{Version: 1} + sfHasValue = serverPoolSetup.NewPersistentFlag("hasValue") + sfQueried = serverPoolSetup.NewFlag("queried") + sfCanDial = serverPoolSetup.NewFlag("canDial") + sfDialing = serverPoolSetup.NewFlag("dialed") + sfWaitDialTimeout = serverPoolSetup.NewFlag("dialTimeout") + sfConnected = serverPoolSetup.NewFlag("connected") + sfRedialWait = serverPoolSetup.NewFlag("redialWait") + sfAlwaysConnect = serverPoolSetup.NewFlag("alwaysConnect") + sfDisableSelection = nodestate.MergeFlags(sfQueried, sfCanDial, sfDialing, sfConnected, sfRedialWait) + + sfiNodeHistory = serverPoolSetup.NewPersistentField("nodeHistory", reflect.TypeOf(nodeHistory{}), + func(field interface{}) ([]byte, error) { + if n, ok := field.(nodeHistory); ok { + ne := nodeHistoryEnc{ + DialCost: n.dialCost, + RedialWaitStart: uint64(n.redialWaitStart), + RedialWaitEnd: uint64(n.redialWaitEnd), + } + enc, err := rlp.EncodeToBytes(&ne) + return enc, err + } else { + return nil, errors.New("invalid field type") + } + }, + func(enc []byte) (interface{}, error) { + var ne nodeHistoryEnc + err := rlp.DecodeBytes(enc, &ne) + n := nodeHistory{ + dialCost: ne.DialCost, + redialWaitStart: int64(ne.RedialWaitStart), + redialWaitEnd: int64(ne.RedialWaitEnd), + } + return n, err + }, + ) + sfiNodeWeight = serverPoolSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0))) + sfiConnectedStats = serverPoolSetup.NewField("connectedStats", reflect.TypeOf(lpc.ResponseTimeStats{})) +) + +// newServerPool creates a new server pool +func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *lpc.ValueTracker, discovery enode.Iterator, mixTimeout time.Duration, query queryFunc, clock mclock.Clock, trustedURLs []string) *serverPool { + s := &serverPool{ db: db, - entries: make(map[enode.ID]*poolEntry), - timeout: make(chan *poolEntry, 1), - adjustStats: make(chan poolStatAdjust, 100), - enableRetry: make(chan *poolEntry, 1), - connCh: make(chan *connReq), - disconnCh: make(chan *disconnReq), - registerCh: make(chan *registerReq), - closeCh: make(chan struct{}), - knownSelect: utils.NewWeightedRandomSelect(), - newSelect: utils.NewWeightedRandomSelect(), - fastDiscover: true, - trustedNodes: parseTrustedNodes(ulcServers), + clock: clock, + unixTime: func() int64 { return time.Now().Unix() }, + validSchemes: enode.ValidSchemes, + trustedURLs: trustedURLs, + vt: vt, + ns: nodestate.NewNodeStateMachine(db, []byte(string(dbKey)+"ns:"), clock, serverPoolSetup), + } + s.recalTimeout() + s.mixer = enode.NewFairMix(mixTimeout) + knownSelector := lpc.NewWrsIterator(s.ns, sfHasValue, sfDisableSelection, sfiNodeWeight) + alwaysConnect := lpc.NewQueueIterator(s.ns, sfAlwaysConnect, sfDisableSelection, true, nil) + s.mixSources = append(s.mixSources, knownSelector) + s.mixSources = append(s.mixSources, alwaysConnect) + if discovery != nil { + s.mixSources = append(s.mixSources, discovery) } - pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry) - pool.newQueue = newPoolEntryQueue(maxNewEntries, pool.removeEntry) - return pool + iter := enode.Iterator(s.mixer) + if query != nil { + iter = s.addPreNegFilter(iter, query) + } + s.dialIterator = enode.Filter(iter, func(node *enode.Node) bool { + s.ns.SetState(node, sfDialing, sfCanDial, 0) + s.ns.SetState(node, sfWaitDialTimeout, nodestate.Flags{}, time.Second*10) + return true + }) + + s.ns.SubscribeState(nodestate.MergeFlags(sfWaitDialTimeout, sfConnected), func(n *enode.Node, oldState, newState nodestate.Flags) { + if oldState.Equals(sfWaitDialTimeout) && newState.IsEmpty() { + // dial timeout, no connection + s.setRedialWait(n, dialCost, dialWaitStep) + s.ns.SetState(n, nodestate.Flags{}, sfDialing, 0) + } + }) + + s.ns.AddLogMetrics(sfHasValue, sfDisableSelection, "selectable", nil, nil, serverSelectableGauge) + s.ns.AddLogMetrics(sfDialing, nodestate.Flags{}, "dialed", serverDialedMeter, nil, nil) + s.ns.AddLogMetrics(sfConnected, nodestate.Flags{}, "connected", nil, nil, serverConnectedGauge) + return s } -func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { - pool.server = server - pool.topic = topic - pool.dbKey = append([]byte("serverPool/"), []byte(topic)...) - pool.loadNodes() - pool.connectToTrustedNodes() - - if pool.server.DiscV5 != nil { - pool.discSetPeriod = make(chan time.Duration, 1) - pool.discNodes = make(chan *enode.Node, 100) - pool.discLookups = make(chan bool, 100) - go pool.discoverNodes() - } - pool.checkDial() - pool.wg.Add(1) - go pool.eventLoop() - - // Inject the bootstrap nodes as initial dial candiates. - pool.wg.Add(1) - go func() { - defer pool.wg.Done() - for _, n := range server.BootstrapNodes { - select { - case pool.discNodes <- n: - case <-pool.closeCh: +// addPreNegFilter installs a node filter mechanism that performs a pre-negotiation query. +// Nodes that are filtered out and does not appear on the output iterator are put back +// into redialWait state. +func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enode.Iterator { + s.fillSet = lpc.NewFillSet(s.ns, input, sfQueried) + s.ns.SubscribeState(sfQueried, func(n *enode.Node, oldState, newState nodestate.Flags) { + if newState.Equals(sfQueried) { + fails := atomic.LoadUint32(&s.queryFails) + if fails == maxQueryFails { + log.Warn("UDP pre-negotiation query does not seem to work") + } + if fails > maxQueryFails { + fails = maxQueryFails + } + if rand.Intn(maxQueryFails*2) < int(fails) { + // skip pre-negotiation with increasing chance, max 50% + // this ensures that the client can operate even if UDP is not working at all + s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10) + // set canDial before resetting queried so that FillSet will not read more + // candidates unnecessarily + s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0) return } - } - }() -} - -func (pool *serverPool) stop() { - close(pool.closeCh) - pool.wg.Wait() -} - -// discoverNodes wraps SearchTopic, converting result nodes to enode.Node. -func (pool *serverPool) discoverNodes() { - ch := make(chan *discv5.Node) - go func() { - pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, ch, pool.discLookups) - close(ch) - }() - for n := range ch { - pubkey, err := decodePubkey64(n.ID[:]) - if err != nil { - continue - } - pool.discNodes <- enode.NewV4(pubkey, n.IP, int(n.TCP), int(n.UDP)) - } -} - -// connect should be called upon any incoming connection. If the connection has been -// dialed by the server pool recently, the appropriate pool entry is returned. -// Otherwise, the connection should be rejected. -// Note that whenever a connection has been accepted and a pool entry has been returned, -// disconnect should also always be called. -func (pool *serverPool) connect(p *serverPeer, node *enode.Node) *poolEntry { - log.Debug("Connect new entry", "enode", p.id) - req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)} - select { - case pool.connCh <- req: - case <-pool.closeCh: - return nil - } - return <-req.result -} - -// registered should be called after a successful handshake -func (pool *serverPool) registered(entry *poolEntry) { - log.Debug("Registered new entry", "enode", entry.node.ID()) - req := ®isterReq{entry: entry, done: make(chan struct{})} - select { - case pool.registerCh <- req: - case <-pool.closeCh: - return - } - <-req.done -} - -// disconnect should be called when ending a connection. Service quality statistics -// can be updated optionally (not updated if no registration happened, in this case -// only connection statistics are updated, just like in case of timeout) -func (pool *serverPool) disconnect(entry *poolEntry) { - stopped := false - select { - case <-pool.closeCh: - stopped = true - default: - } - log.Debug("Disconnected old entry", "enode", entry.node.ID()) - req := &disconnReq{entry: entry, stopped: stopped, done: make(chan struct{})} - - // Block until disconnection request is served. - pool.disconnCh <- req - <-req.done -} - -const ( - pseBlockDelay = iota - pseResponseTime - pseResponseTimeout -) - -// poolStatAdjust records are sent to adjust peer block delay/response time statistics -type poolStatAdjust struct { - adjustType int - entry *poolEntry - time time.Duration -} - -// adjustBlockDelay adjusts the block announce delay statistics of a node -func (pool *serverPool) adjustBlockDelay(entry *poolEntry, time time.Duration) { - if entry == nil { - return - } - pool.adjustStats <- poolStatAdjust{pseBlockDelay, entry, time} -} - -// adjustResponseTime adjusts the request response time statistics of a node -func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration, timeout bool) { - if entry == nil { - return - } - if timeout { - pool.adjustStats <- poolStatAdjust{pseResponseTimeout, entry, time} - } else { - pool.adjustStats <- poolStatAdjust{pseResponseTime, entry, time} - } -} - -// eventLoop handles pool events and mutex locking for all internal functions -func (pool *serverPool) eventLoop() { - defer pool.wg.Done() - lookupCnt := 0 - var convTime mclock.AbsTime - if pool.discSetPeriod != nil { - pool.discSetPeriod <- time.Millisecond * 100 - } - - // disconnect updates service quality statistics depending on the connection time - // and disconnection initiator. - disconnect := func(req *disconnReq, stopped bool) { - // Handle peer disconnection requests. - entry := req.entry - if entry.state == psRegistered { - connAdjust := float64(mclock.Now()-entry.regTime) / float64(targetConnTime) - if connAdjust > 1 { - connAdjust = 1 - } - if stopped { - // disconnect requested by ourselves. - entry.connectStats.add(1, connAdjust) - } else { - // disconnect requested by server side. - entry.connectStats.add(connAdjust, 1) - } - } - entry.state = psNotConnected - - if entry.knownSelected { - pool.knownSelected-- - } else { - pool.newSelected-- - } - pool.setRetryDial(entry) - pool.connWg.Done() - close(req.done) - } - - for { - select { - case entry := <-pool.timeout: - if !entry.removed { - pool.checkDialTimeout(entry) - } - - case entry := <-pool.enableRetry: - if !entry.removed { - entry.delayedRetry = false - pool.updateCheckDial(entry) - } - - case adj := <-pool.adjustStats: - switch adj.adjustType { - case pseBlockDelay: - adj.entry.delayStats.add(float64(adj.time), 1) - case pseResponseTime: - adj.entry.responseStats.add(float64(adj.time), 1) - adj.entry.timeoutStats.add(0, 1) - case pseResponseTimeout: - adj.entry.timeoutStats.add(1, 1) - } - - case node := <-pool.discNodes: - if pool.trustedNodes[node.ID()] == nil { - entry := pool.findOrNewNode(node) - pool.updateCheckDial(entry) - } - - case conv := <-pool.discLookups: - if conv { - if lookupCnt == 0 { - convTime = mclock.Now() - } - lookupCnt++ - if pool.fastDiscover && (lookupCnt == 50 || time.Duration(mclock.Now()-convTime) > time.Minute) { - pool.fastDiscover = false - if pool.discSetPeriod != nil { - pool.discSetPeriod <- time.Minute - } - } - } - - case req := <-pool.connCh: - if pool.trustedNodes[req.p.ID()] != nil { - // ignore trusted nodes - req.result <- &poolEntry{trusted: true} - } else { - // Handle peer connection requests. - entry := pool.entries[req.p.ID()] - if entry == nil { - entry = pool.findOrNewNode(req.node) - } - if entry.state == psConnected || entry.state == psRegistered { - req.result <- nil - continue - } - pool.connWg.Add(1) - entry.peer = req.p - entry.state = psConnected - addr := &poolEntryAddress{ - ip: req.node.IP(), - port: uint16(req.node.TCP()), - lastSeen: mclock.Now(), - } - entry.lastConnected = addr - entry.addr = make(map[string]*poolEntryAddress) - entry.addr[addr.strKey()] = addr - entry.addrSelect = *utils.NewWeightedRandomSelect() - entry.addrSelect.Update(addr) - req.result <- entry - } - - case req := <-pool.registerCh: - if req.entry.trusted { - continue - } - // Handle peer registration requests. - entry := req.entry - entry.state = psRegistered - entry.regTime = mclock.Now() - if !entry.known { - pool.newQueue.remove(entry) - entry.known = true - } - pool.knownQueue.setLatest(entry) - entry.shortRetry = shortRetryCnt - close(req.done) - - case req := <-pool.disconnCh: - if req.entry.trusted { - continue - } - // Handle peer disconnection requests. - disconnect(req, req.stopped) - - case <-pool.closeCh: - if pool.discSetPeriod != nil { - close(pool.discSetPeriod) - } - - // Spawn a goroutine to close the disconnCh after all connections are disconnected. go func() { - pool.connWg.Wait() - close(pool.disconnCh) + q := query(n) + if q == -1 { + atomic.AddUint32(&s.queryFails, 1) + } else { + atomic.StoreUint32(&s.queryFails, 0) + } + if q == 1 { + s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10) + } else { + s.setRedialWait(n, queryCost, queryWaitStep) + } + s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0) }() - - // Handle all remaining disconnection requests before exit. - for req := range pool.disconnCh { - disconnect(req, true) - } - pool.saveNodes() - return } - } -} - -func (pool *serverPool) findOrNewNode(node *enode.Node) *poolEntry { - now := mclock.Now() - entry := pool.entries[node.ID()] - if entry == nil { - log.Debug("Discovered new entry", "id", node.ID()) - entry = &poolEntry{ - node: node, - addr: make(map[string]*poolEntryAddress), - addrSelect: *utils.NewWeightedRandomSelect(), - shortRetry: shortRetryCnt, + }) + return lpc.NewQueueIterator(s.ns, sfCanDial, nodestate.Flags{}, false, func(waiting bool) { + if waiting { + s.fillSet.SetTarget(preNegLimit) + } else { + s.fillSet.SetTarget(0) } - pool.entries[node.ID()] = entry - // initialize previously unknown peers with good statistics to give a chance to prove themselves - entry.connectStats.add(1, initStatsWeight) - entry.delayStats.add(0, initStatsWeight) - entry.responseStats.add(0, initStatsWeight) - entry.timeoutStats.add(0, initStatsWeight) - } - entry.lastDiscovered = now - addr := &poolEntryAddress{ip: node.IP(), port: uint16(node.TCP())} - if a, ok := entry.addr[addr.strKey()]; ok { - addr = a - } else { - entry.addr[addr.strKey()] = addr - } - addr.lastSeen = now - entry.addrSelect.Update(addr) - if !entry.known { - pool.newQueue.setLatest(entry) - } - return entry -} - -// loadNodes loads known nodes and their statistics from the database -func (pool *serverPool) loadNodes() { - enc, err := pool.db.Get(pool.dbKey) - if err != nil { - return - } - var list []*poolEntry - err = rlp.DecodeBytes(enc, &list) - if err != nil { - log.Debug("Failed to decode node list", "err", err) - return - } - for _, e := range list { - log.Debug("Loaded server stats", "id", e.node.ID(), "fails", e.lastConnected.fails, - "conn", fmt.Sprintf("%v/%v", e.connectStats.avg, e.connectStats.weight), - "delay", fmt.Sprintf("%v/%v", time.Duration(e.delayStats.avg), e.delayStats.weight), - "response", fmt.Sprintf("%v/%v", time.Duration(e.responseStats.avg), e.responseStats.weight), - "timeout", fmt.Sprintf("%v/%v", e.timeoutStats.avg, e.timeoutStats.weight)) - pool.entries[e.node.ID()] = e - if pool.trustedNodes[e.node.ID()] == nil { - pool.knownQueue.setLatest(e) - pool.knownSelect.Update((*knownEntry)(e)) - } - } -} - -// connectToTrustedNodes adds trusted server nodes as static trusted peers. -// -// Note: trusted nodes are not handled by the server pool logic, they are not -// added to either the known or new selection pools. They are connected/reconnected -// by p2p.Server whenever possible. -func (pool *serverPool) connectToTrustedNodes() { - //connect to trusted nodes - for _, node := range pool.trustedNodes { - pool.server.AddTrustedPeer(node) - pool.server.AddPeer(node) - log.Debug("Added trusted node", "id", node.ID().String()) - } -} - -// parseTrustedNodes returns valid and parsed enodes -func parseTrustedNodes(trustedNodes []string) map[enode.ID]*enode.Node { - nodes := make(map[enode.ID]*enode.Node) - - for _, node := range trustedNodes { - node, err := enode.Parse(enode.ValidSchemes, node) - if err != nil { - log.Warn("Trusted node URL invalid", "enode", node, "err", err) - continue - } - nodes[node.ID()] = node - } - return nodes -} - -// saveNodes saves known nodes and their statistics into the database. Nodes are -// ordered from least to most recently connected. -func (pool *serverPool) saveNodes() { - list := make([]*poolEntry, len(pool.knownQueue.queue)) - for i := range list { - list[i] = pool.knownQueue.fetchOldest() - } - enc, err := rlp.EncodeToBytes(list) - if err == nil { - pool.db.Put(pool.dbKey, enc) - } -} - -// removeEntry removes a pool entry when the entry count limit is reached. -// Note that it is called by the new/known queues from which the entry has already -// been removed so removing it from the queues is not necessary. -func (pool *serverPool) removeEntry(entry *poolEntry) { - pool.newSelect.Remove((*discoveredEntry)(entry)) - pool.knownSelect.Remove((*knownEntry)(entry)) - entry.removed = true - delete(pool.entries, entry.node.ID()) -} - -// setRetryDial starts the timer which will enable dialing a certain node again -func (pool *serverPool) setRetryDial(entry *poolEntry) { - delay := longRetryDelay - if entry.shortRetry > 0 { - entry.shortRetry-- - delay = shortRetryDelay - } - delay += time.Duration(rand.Int63n(int64(delay) + 1)) - entry.delayedRetry = true - go func() { - select { - case <-pool.closeCh: - case <-time.After(delay): - select { - case <-pool.closeCh: - case pool.enableRetry <- entry: - } - } - }() -} - -// updateCheckDial is called when an entry can potentially be dialed again. It updates -// its selection weights and checks if new dials can/should be made. -func (pool *serverPool) updateCheckDial(entry *poolEntry) { - pool.newSelect.Update((*discoveredEntry)(entry)) - pool.knownSelect.Update((*knownEntry)(entry)) - pool.checkDial() -} - -// checkDial checks if new dials can/should be made. It tries to select servers both -// based on good statistics and recent discovery. -func (pool *serverPool) checkDial() { - fillWithKnownSelects := !pool.fastDiscover - for pool.knownSelected < targetKnownSelect { - entry := pool.knownSelect.Choose() - if entry == nil { - fillWithKnownSelects = false - break - } - pool.dial((*poolEntry)(entry.(*knownEntry)), true) - } - for pool.knownSelected+pool.newSelected < targetServerCount { - entry := pool.newSelect.Choose() - if entry == nil { - break - } - pool.dial((*poolEntry)(entry.(*discoveredEntry)), false) - } - if fillWithKnownSelects { - // no more newly discovered nodes to select and since fast discover period - // is over, we probably won't find more in the near future so select more - // known entries if possible - for pool.knownSelected < targetServerCount { - entry := pool.knownSelect.Choose() - if entry == nil { - break - } - pool.dial((*poolEntry)(entry.(*knownEntry)), true) - } - } -} - -// dial initiates a new connection -func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) { - if pool.server == nil || entry.state != psNotConnected { - return - } - entry.state = psDialed - entry.knownSelected = knownSelected - if knownSelected { - pool.knownSelected++ - } else { - pool.newSelected++ - } - addr := entry.addrSelect.Choose().(*poolEntryAddress) - log.Debug("Dialing new peer", "lesaddr", entry.node.ID().String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected) - entry.dialed = addr - go func() { - pool.server.AddPeer(entry.node) - select { - case <-pool.closeCh: - case <-time.After(dialTimeout): - select { - case <-pool.closeCh: - case pool.timeout <- entry: - } - } - }() -} - -// checkDialTimeout checks if the node is still in dialed state and if so, resets it -// and adjusts connection statistics accordingly. -func (pool *serverPool) checkDialTimeout(entry *poolEntry) { - if entry.state != psDialed { - return - } - log.Debug("Dial timeout", "lesaddr", entry.node.ID().String()+"@"+entry.dialed.strKey()) - entry.state = psNotConnected - if entry.knownSelected { - pool.knownSelected-- - } else { - pool.newSelected-- - } - entry.connectStats.add(0, 1) - entry.dialed.fails++ - pool.setRetryDial(entry) -} - -const ( - psNotConnected = iota - psDialed - psConnected - psRegistered -) - -// poolEntry represents a server node and stores its current state and statistics. -type poolEntry struct { - peer *serverPeer - pubkey [64]byte // secp256k1 key of the node - addr map[string]*poolEntryAddress - node *enode.Node - lastConnected, dialed *poolEntryAddress - addrSelect utils.WeightedRandomSelect - - lastDiscovered mclock.AbsTime - known, knownSelected, trusted bool - connectStats, delayStats poolStats - responseStats, timeoutStats poolStats - state int - regTime mclock.AbsTime - queueIdx int - removed bool - - delayedRetry bool - shortRetry int -} - -// poolEntryEnc is the RLP encoding of poolEntry. -type poolEntryEnc struct { - Pubkey []byte - IP net.IP - Port uint16 - Fails uint - CStat, DStat, RStat, TStat poolStats -} - -func (e *poolEntry) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &poolEntryEnc{ - Pubkey: encodePubkey64(e.node.Pubkey()), - IP: e.lastConnected.ip, - Port: e.lastConnected.port, - Fails: e.lastConnected.fails, - CStat: e.connectStats, - DStat: e.delayStats, - RStat: e.responseStats, - TStat: e.timeoutStats, }) } -func (e *poolEntry) DecodeRLP(s *rlp.Stream) error { - var entry poolEntryEnc - if err := s.Decode(&entry); err != nil { - return err +// start starts the server pool. Note that NodeStateMachine should be started first. +func (s *serverPool) start() { + s.ns.Start() + for _, iter := range s.mixSources { + // add sources to mixer at startup because the mixer instantly tries to read them + // which should only happen after NodeStateMachine has been started + s.mixer.AddSource(iter) } - pubkey, err := decodePubkey64(entry.Pubkey) - if err != nil { - return err - } - addr := &poolEntryAddress{ip: entry.IP, port: entry.Port, fails: entry.Fails, lastSeen: mclock.Now()} - e.node = enode.NewV4(pubkey, entry.IP, int(entry.Port), int(entry.Port)) - e.addr = make(map[string]*poolEntryAddress) - e.addr[addr.strKey()] = addr - e.addrSelect = *utils.NewWeightedRandomSelect() - e.addrSelect.Update(addr) - e.lastConnected = addr - e.connectStats = entry.CStat - e.delayStats = entry.DStat - e.responseStats = entry.RStat - e.timeoutStats = entry.TStat - e.shortRetry = shortRetryCnt - e.known = true - return nil -} - -func encodePubkey64(pub *ecdsa.PublicKey) []byte { - return crypto.FromECDSAPub(pub)[1:] -} - -func decodePubkey64(b []byte) (*ecdsa.PublicKey, error) { - return crypto.UnmarshalPubkey(append([]byte{0x04}, b...)) -} - -// discoveredEntry implements wrsItem -type discoveredEntry poolEntry - -// Weight calculates random selection weight for newly discovered entries -func (e *discoveredEntry) Weight() int64 { - if e.state != psNotConnected || e.delayedRetry { - return 0 - } - t := time.Duration(mclock.Now() - e.lastDiscovered) - if t <= discoverExpireStart { - return 1000000000 - } - return int64(1000000000 * math.Exp(-float64(t-discoverExpireStart)/float64(discoverExpireConst))) -} - -// knownEntry implements wrsItem -type knownEntry poolEntry - -// Weight calculates random selection weight for known entries -func (e *knownEntry) Weight() int64 { - if e.state != psNotConnected || !e.known || e.delayedRetry { - return 0 - } - return int64(1000000000 * e.connectStats.recentAvg() * math.Exp(-float64(e.lastConnected.fails)*failDropLn-e.responseStats.recentAvg()/float64(responseScoreTC)-e.delayStats.recentAvg()/float64(delayScoreTC)) * math.Pow(1-e.timeoutStats.recentAvg(), timeoutPow)) -} - -// poolEntryAddress is a separate object because currently it is necessary to remember -// multiple potential network addresses for a pool entry. This will be removed after -// the final implementation of v5 discovery which will retrieve signed and serial -// numbered advertisements, making it clear which IP/port is the latest one. -type poolEntryAddress struct { - ip net.IP - port uint16 - lastSeen mclock.AbsTime // last time it was discovered, connected or loaded from db - fails uint // connection failures since last successful connection (persistent) -} - -func (a *poolEntryAddress) Weight() int64 { - t := time.Duration(mclock.Now() - a.lastSeen) - return int64(1000000*math.Exp(-float64(t)/float64(discoverExpireConst)-float64(a.fails)*addrFailDropLn)) + 1 -} - -func (a *poolEntryAddress) strKey() string { - return a.ip.String() + ":" + strconv.Itoa(int(a.port)) -} - -// poolStats implement statistics for a certain quantity with a long term average -// and a short term value which is adjusted exponentially with a factor of -// pstatRecentAdjust with each update and also returned exponentially to the -// average with the time constant pstatReturnToMeanTC -type poolStats struct { - sum, weight, avg, recent float64 - lastRecalc mclock.AbsTime -} - -// init initializes stats with a long term sum/update count pair retrieved from the database -func (s *poolStats) init(sum, weight float64) { - s.sum = sum - s.weight = weight - var avg float64 - if weight > 0 { - avg = s.sum / weight - } - s.avg = avg - s.recent = avg - s.lastRecalc = mclock.Now() -} - -// recalc recalculates recent value return-to-mean and long term average -func (s *poolStats) recalc() { - now := mclock.Now() - s.recent = s.avg + (s.recent-s.avg)*math.Exp(-float64(now-s.lastRecalc)/float64(pstatReturnToMeanTC)) - if s.sum == 0 { - s.avg = 0 - } else { - if s.sum > s.weight*1e30 { - s.avg = 1e30 + for _, url := range s.trustedURLs { + if node, err := enode.Parse(s.validSchemes, url); err == nil { + s.ns.SetState(node, sfAlwaysConnect, nodestate.Flags{}, 0) } else { - s.avg = s.sum / s.weight + log.Error("Invalid trusted server URL", "url", url, "error", err) } } - s.lastRecalc = now -} - -// add updates the stats with a new value -func (s *poolStats) add(value, weight float64) { - s.weight += weight - s.sum += value * weight - s.recalc() -} - -// recentAvg returns the short-term adjusted average -func (s *poolStats) recentAvg() float64 { - s.recalc() - return s.recent -} - -func (s *poolStats) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, []interface{}{math.Float64bits(s.sum), math.Float64bits(s.weight)}) -} - -func (s *poolStats) DecodeRLP(st *rlp.Stream) error { - var stats struct { - SumUint, WeightUint uint64 - } - if err := st.Decode(&stats); err != nil { - return err - } - s.init(math.Float64frombits(stats.SumUint), math.Float64frombits(stats.WeightUint)) - return nil -} - -// poolEntryQueue keeps track of its least recently accessed entries and removes -// them when the number of entries reaches the limit -type poolEntryQueue struct { - queue map[int]*poolEntry // known nodes indexed by their latest lastConnCnt value - newPtr, oldPtr, maxCnt int - removeFromPool func(*poolEntry) -} - -// newPoolEntryQueue returns a new poolEntryQueue -func newPoolEntryQueue(maxCnt int, removeFromPool func(*poolEntry)) poolEntryQueue { - return poolEntryQueue{queue: make(map[int]*poolEntry), maxCnt: maxCnt, removeFromPool: removeFromPool} -} - -// fetchOldest returns and removes the least recently accessed entry -func (q *poolEntryQueue) fetchOldest() *poolEntry { - if len(q.queue) == 0 { - return nil - } - for { - if e := q.queue[q.oldPtr]; e != nil { - delete(q.queue, q.oldPtr) - q.oldPtr++ - return e + unixTime := s.unixTime() + s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { + s.calculateWeight(node) + if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime { + wait := n.redialWaitEnd - unixTime + lastWait := n.redialWaitEnd - n.redialWaitStart + if wait > lastWait { + // if the time until expiration is larger than the last suggested + // waiting time then the system clock was probably adjusted + wait = lastWait + } + s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second) } - q.oldPtr++ - } + }) } -// remove removes an entry from the queue -func (q *poolEntryQueue) remove(entry *poolEntry) { - if q.queue[entry.queueIdx] == entry { - delete(q.queue, entry.queueIdx) +// stop stops the server pool +func (s *serverPool) stop() { + s.dialIterator.Close() + if s.fillSet != nil { + s.fillSet.Close() } + s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) { + // recalculate weight of connected nodes in order to update hasValue flag if necessary + s.calculateWeight(n) + }) + s.ns.Stop() } -// setLatest adds or updates a recently accessed entry. It also checks if an old entry -// needs to be removed and removes it from the parent pool too with a callback function. -func (q *poolEntryQueue) setLatest(entry *poolEntry) { - if q.queue[entry.queueIdx] == entry { - delete(q.queue, entry.queueIdx) +// registerPeer implements serverPeerSubscriber +func (s *serverPool) registerPeer(p *serverPeer) { + s.ns.SetState(p.Node(), sfConnected, sfDialing.Or(sfWaitDialTimeout), 0) + nvt := s.vt.Register(p.ID()) + s.ns.SetField(p.Node(), sfiConnectedStats, nvt.RtStats()) + p.setValueTracker(s.vt, nvt) + p.updateVtParams() +} + +// unregisterPeer implements serverPeerSubscriber +func (s *serverPool) unregisterPeer(p *serverPeer) { + s.setRedialWait(p.Node(), dialCost, dialWaitStep) + s.ns.SetState(p.Node(), nodestate.Flags{}, sfConnected, 0) + s.ns.SetField(p.Node(), sfiConnectedStats, nil) + s.vt.Unregister(p.ID()) + p.setValueTracker(nil, nil) +} + +// recalTimeout calculates the current recommended timeout. This value is used by +// the client as a "soft timeout" value. It also affects the service value calculation +// of individual nodes. +func (s *serverPool) recalTimeout() { + // Use cached result if possible, avoid recalculating too frequently. + s.timeoutLock.RLock() + refreshed := s.timeoutRefreshed + s.timeoutLock.RUnlock() + now := s.clock.Now() + if refreshed != 0 && time.Duration(now-refreshed) < timeoutRefresh { + return + } + // Cached result is stale, recalculate a new one. + rts := s.vt.RtStats() + + // Add a fake statistic here. It is an easy way to initialize with some + // conservative values when the database is new. As soon as we have a + // considerable amount of real stats this small value won't matter. + rts.Add(time.Second*2, 10, s.vt.StatsExpFactor()) + + // Use either 10% failure rate timeout or twice the median response time + // as the recommended timeout. + timeout := minTimeout + if t := rts.Timeout(0.1); t > timeout { + timeout = t + } + if t := rts.Timeout(0.5) * 2; t > timeout { + timeout = t + } + s.timeoutLock.Lock() + if s.timeout != timeout { + s.timeout = timeout + s.timeWeights = lpc.TimeoutWeights(s.timeout) + + suggestedTimeoutGauge.Update(int64(s.timeout / time.Millisecond)) + totalValueGauge.Update(int64(rts.Value(s.timeWeights, s.vt.StatsExpFactor()))) + } + s.timeoutRefreshed = now + s.timeoutLock.Unlock() +} + +// getTimeout returns the recommended request timeout. +func (s *serverPool) getTimeout() time.Duration { + s.recalTimeout() + s.timeoutLock.RLock() + defer s.timeoutLock.RUnlock() + return s.timeout +} + +// getTimeoutAndWeight returns the recommended request timeout as well as the +// response time weight which is necessary to calculate service value. +func (s *serverPool) getTimeoutAndWeight() (time.Duration, lpc.ResponseTimeWeights) { + s.recalTimeout() + s.timeoutLock.RLock() + defer s.timeoutLock.RUnlock() + return s.timeout, s.timeWeights +} + +// addDialCost adds the given amount of dial cost to the node history and returns the current +// amount of total dial cost +func (s *serverPool) addDialCost(n *nodeHistory, amount int64) uint64 { + logOffset := s.vt.StatsExpirer().LogOffset(s.clock.Now()) + if amount > 0 { + n.dialCost.Add(amount, logOffset) + } + totalDialCost := n.dialCost.Value(logOffset) + if totalDialCost < dialCost { + totalDialCost = dialCost + } + return totalDialCost +} + +// serviceValue returns the service value accumulated in this session and in total +func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue float64) { + nvt := s.vt.GetNode(node.ID()) + if nvt == nil { + return 0, 0 + } + currentStats := nvt.RtStats() + _, timeWeights := s.getTimeoutAndWeight() + expFactor := s.vt.StatsExpFactor() + + totalValue = currentStats.Value(timeWeights, expFactor) + if connStats, ok := s.ns.GetField(node, sfiConnectedStats).(lpc.ResponseTimeStats); ok { + diff := currentStats + diff.SubStats(&connStats) + sessionValue = diff.Value(timeWeights, expFactor) + sessionValueMeter.Mark(int64(sessionValue)) + } + return +} + +// updateWeight calculates the node weight and updates the nodeWeight field and the +// hasValue flag. It also saves the node state if necessary. +func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) { + weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost)) + if weight >= nodeWeightThreshold { + s.ns.SetState(node, sfHasValue, nodestate.Flags{}, 0) + s.ns.SetField(node, sfiNodeWeight, weight) } else { - if len(q.queue) == q.maxCnt { - e := q.fetchOldest() - q.remove(e) - q.removeFromPool(e) - } + s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0) + s.ns.SetField(node, sfiNodeWeight, nil) } - entry.queueIdx = q.newPtr - q.queue[entry.queueIdx] = entry - q.newPtr++ + s.ns.Persist(node) // saved if node history or hasValue changed +} + +// setRedialWait calculates and sets the redialWait timeout based on the service value +// and dial cost accumulated during the last session/attempt and in total. +// The waiting time is raised exponentially if no service value has been received in order +// to prevent dialing an unresponsive node frequently for a very long time just because it +// was useful in the past. It can still be occasionally dialed though and once it provides +// a significant amount of service value again its waiting time is quickly reduced or reset +// to the minimum. +// Note: node weight is also recalculated and updated by this function. +func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) { + n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory) + sessionValue, totalValue := s.serviceValue(node) + totalDialCost := s.addDialCost(&n, addDialCost) + + // if the current dial session has yielded at least the average value/dial cost ratio + // then the waiting time should be reset to the minimum. If the session value + // is below average but still positive then timeout is limited to the ratio of + // average / current service value multiplied by the minimum timeout. If the attempt + // was unsuccessful then timeout is raised exponentially without limitation. + // Note: dialCost is used in the formula below even if dial was not attempted at all + // because the pre-negotiation query did not return a positive result. In this case + // the ratio has no meaning anyway and waitFactor is always raised, though in smaller + // steps because queries are cheaper and therefore we can allow more failed attempts. + unixTime := s.unixTime() + plannedTimeout := float64(n.redialWaitEnd - n.redialWaitStart) // last planned redialWait timeout + var actualWait float64 // actual waiting time elapsed + if unixTime > n.redialWaitEnd { + // the planned timeout has elapsed + actualWait = plannedTimeout + } else { + // if the node was redialed earlier then we do not raise the planned timeout + // exponentially because that could lead to the timeout rising very high in + // a short amount of time + // Note that in case of an early redial actualWait also includes the dial + // timeout or connection time of the last attempt but it still serves its + // purpose of preventing the timeout rising quicker than linearly as a function + // of total time elapsed without a successful connection. + actualWait = float64(unixTime - n.redialWaitStart) + } + // raise timeout exponentially if the last planned timeout has elapsed + // (use at least the last planned timeout otherwise) + nextTimeout := actualWait * waitStep + if plannedTimeout > nextTimeout { + nextTimeout = plannedTimeout + } + // we reduce the waiting time if the server has provided service value during the + // connection (but never under the minimum) + a := totalValue * dialCost * float64(minRedialWait) + b := float64(totalDialCost) * sessionValue + if a < b*nextTimeout { + nextTimeout = a / b + } + if nextTimeout < minRedialWait { + nextTimeout = minRedialWait + } + wait := time.Duration(float64(time.Second) * nextTimeout) + if wait < waitThreshold { + n.redialWaitStart = unixTime + n.redialWaitEnd = unixTime + int64(nextTimeout) + s.ns.SetField(node, sfiNodeHistory, n) + s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, wait) + s.updateWeight(node, totalValue, totalDialCost) + } else { + // discard known node statistics if waiting time is very long because the node + // hasn't been responsive for a very long time + s.ns.SetField(node, sfiNodeHistory, nil) + s.ns.SetField(node, sfiNodeWeight, nil) + s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0) + } +} + +// calculateWeight calculates and sets the node weight without altering the node history. +// This function should be called during startup and shutdown only, otherwise setRedialWait +// will keep the weights updated as the underlying statistics are adjusted. +func (s *serverPool) calculateWeight(node *enode.Node) { + n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory) + _, totalValue := s.serviceValue(node) + totalDialCost := s.addDialCost(&n, 0) + s.updateWeight(node, totalValue, totalDialCost) } diff --git a/les/serverpool_test.go b/les/serverpool_test.go new file mode 100644 index 000000000..3d0487d10 --- /dev/null +++ b/les/serverpool_test.go @@ -0,0 +1,352 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + lpc "github.com/ethereum/go-ethereum/les/lespay/client" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +const ( + spTestNodes = 1000 + spTestTarget = 5 + spTestLength = 10000 + spMinTotal = 40000 + spMaxTotal = 50000 +) + +func testNodeID(i int) enode.ID { + return enode.ID{42, byte(i % 256), byte(i / 256)} +} + +func testNodeIndex(id enode.ID) int { + if id[0] != 42 { + return -1 + } + return int(id[1]) + int(id[2])*256 +} + +type serverPoolTest struct { + db ethdb.KeyValueStore + clock *mclock.Simulated + quit chan struct{} + preNeg, preNegFail bool + vt *lpc.ValueTracker + sp *serverPool + input enode.Iterator + testNodes []spTestNode + trusted []string + waitCount, waitEnded int32 + + cycle, conn, servedConn int + serviceCycles, dialCount int + disconnect map[int][]int +} + +type spTestNode struct { + connectCycles, waitCycles int + nextConnCycle, totalConn int + connected, service bool + peer *serverPeer +} + +func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest { + nodes := make([]*enode.Node, spTestNodes) + for i := range nodes { + nodes[i] = enode.SignNull(&enr.Record{}, testNodeID(i)) + } + return &serverPoolTest{ + clock: &mclock.Simulated{}, + db: memorydb.New(), + input: enode.CycleNodes(nodes), + testNodes: make([]spTestNode, spTestNodes), + preNeg: preNeg, + preNegFail: preNegFail, + } +} + +func (s *serverPoolTest) beginWait() { + // ensure that dialIterator and the maximal number of pre-neg queries are not all stuck in a waiting state + for atomic.AddInt32(&s.waitCount, 1) > preNegLimit { + atomic.AddInt32(&s.waitCount, -1) + s.clock.Run(time.Second) + } +} + +func (s *serverPoolTest) endWait() { + atomic.AddInt32(&s.waitCount, -1) + atomic.AddInt32(&s.waitEnded, 1) +} + +func (s *serverPoolTest) addTrusted(i int) { + s.trusted = append(s.trusted, enode.SignNull(&enr.Record{}, testNodeID(i)).String()) +} + +func (s *serverPoolTest) start() { + var testQuery queryFunc + if s.preNeg { + testQuery = func(node *enode.Node) int { + idx := testNodeIndex(node.ID()) + n := &s.testNodes[idx] + canConnect := !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle + if s.preNegFail { + // simulate a scenario where UDP queries never work + s.beginWait() + s.clock.Sleep(time.Second * 5) + s.endWait() + return -1 + } else { + switch idx % 3 { + case 0: + // pre-neg returns true only if connection is possible + if canConnect { + return 1 + } else { + return 0 + } + case 1: + // pre-neg returns true but connection might still fail + return 1 + case 2: + // pre-neg returns true if connection is possible, otherwise timeout (node unresponsive) + if canConnect { + return 1 + } else { + s.beginWait() + s.clock.Sleep(time.Second * 5) + s.endWait() + return -1 + } + } + return -1 + } + } + } + + s.vt = lpc.NewValueTracker(s.db, s.clock, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)) + s.sp = newServerPool(s.db, []byte("serverpool:"), s.vt, s.input, 0, testQuery, s.clock, s.trusted) + s.sp.validSchemes = enode.ValidSchemesForTesting + s.sp.unixTime = func() int64 { return int64(s.clock.Now()) / int64(time.Second) } + s.disconnect = make(map[int][]int) + s.sp.start() + s.quit = make(chan struct{}) + go func() { + last := int32(-1) + for { + select { + case <-time.After(time.Millisecond * 100): + c := atomic.LoadInt32(&s.waitEnded) + if c == last { + // advance clock if test is stuck (might happen in rare cases) + s.clock.Run(time.Second) + } + last = c + case <-s.quit: + return + } + } + }() +} + +func (s *serverPoolTest) stop() { + close(s.quit) + s.sp.stop() + s.vt.Stop() + for i := range s.testNodes { + n := &s.testNodes[i] + if n.connected { + n.totalConn += s.cycle + } + n.connected = false + n.peer = nil + n.nextConnCycle = 0 + } + s.conn, s.servedConn = 0, 0 +} + +func (s *serverPoolTest) run() { + for count := spTestLength; count > 0; count-- { + if dcList := s.disconnect[s.cycle]; dcList != nil { + for _, idx := range dcList { + n := &s.testNodes[idx] + s.sp.unregisterPeer(n.peer) + n.totalConn += s.cycle + n.connected = false + n.peer = nil + s.conn-- + if n.service { + s.servedConn-- + } + n.nextConnCycle = s.cycle + n.waitCycles + } + delete(s.disconnect, s.cycle) + } + if s.conn < spTestTarget { + s.dialCount++ + s.beginWait() + s.sp.dialIterator.Next() + s.endWait() + dial := s.sp.dialIterator.Node() + id := dial.ID() + idx := testNodeIndex(id) + n := &s.testNodes[idx] + if !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle { + s.conn++ + if n.service { + s.servedConn++ + } + n.totalConn -= s.cycle + n.connected = true + dc := s.cycle + n.connectCycles + s.disconnect[dc] = append(s.disconnect[dc], idx) + n.peer = &serverPeer{peerCommons: peerCommons{Peer: p2p.NewPeer(id, "", nil)}} + s.sp.registerPeer(n.peer) + if n.service { + s.vt.Served(s.vt.GetNode(id), []lpc.ServedRequest{{ReqType: 0, Amount: 100}}, 0) + } + } + } + s.serviceCycles += s.servedConn + s.clock.Run(time.Second) + s.cycle++ + } +} + +func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) { + for ; count > 0; count-- { + idx := rand.Intn(spTestNodes) + for s.testNodes[idx].connectCycles != 0 || s.testNodes[idx].connected { + idx = rand.Intn(spTestNodes) + } + res = append(res, idx) + s.testNodes[idx] = spTestNode{ + connectCycles: conn, + waitCycles: wait, + service: service, + } + if trusted { + s.addTrusted(idx) + } + } + return +} + +func (s *serverPoolTest) resetNodes() { + for i, n := range s.testNodes { + if n.connected { + n.totalConn += s.cycle + s.sp.unregisterPeer(n.peer) + } + s.testNodes[i] = spTestNode{totalConn: n.totalConn} + } + s.conn, s.servedConn = 0, 0 + s.disconnect = make(map[int][]int) + s.trusted = nil +} + +func (s *serverPoolTest) checkNodes(t *testing.T, nodes []int) { + var sum int + for _, idx := range nodes { + n := &s.testNodes[idx] + if n.connected { + n.totalConn += s.cycle + } + sum += n.totalConn + n.totalConn = 0 + if n.connected { + n.totalConn -= s.cycle + } + } + if sum < spMinTotal || sum > spMaxTotal { + t.Errorf("Total connection amount %d outside expected range %d to %d", sum, spMinTotal, spMaxTotal) + } +} + +func TestServerPool(t *testing.T) { testServerPool(t, false, false) } +func TestServerPoolWithPreNeg(t *testing.T) { testServerPool(t, true, false) } +func TestServerPoolWithPreNegFail(t *testing.T) { testServerPool(t, true, true) } +func testServerPool(t *testing.T, preNeg, fail bool) { + s := newServerPoolTest(preNeg, fail) + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.start() + s.run() + s.stop() + s.checkNodes(t, nodes) +} + +func TestServerPoolChangedNodes(t *testing.T) { testServerPoolChangedNodes(t, false) } +func TestServerPoolChangedNodesWithPreNeg(t *testing.T) { testServerPoolChangedNodes(t, true) } +func testServerPoolChangedNodes(t *testing.T, preNeg bool) { + s := newServerPoolTest(preNeg, false) + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.start() + s.run() + s.checkNodes(t, nodes) + for i := 0; i < 3; i++ { + s.resetNodes() + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.run() + s.checkNodes(t, nodes) + } + s.stop() +} + +func TestServerPoolRestartNoDiscovery(t *testing.T) { testServerPoolRestartNoDiscovery(t, false) } +func TestServerPoolRestartNoDiscoveryWithPreNeg(t *testing.T) { + testServerPoolRestartNoDiscovery(t, true) +} +func testServerPoolRestartNoDiscovery(t *testing.T, preNeg bool) { + s := newServerPoolTest(preNeg, false) + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.start() + s.run() + s.stop() + s.checkNodes(t, nodes) + s.input = nil + s.start() + s.run() + s.stop() + s.checkNodes(t, nodes) +} + +func TestServerPoolTrustedNoDiscovery(t *testing.T) { testServerPoolTrustedNoDiscovery(t, false) } +func TestServerPoolTrustedNoDiscoveryWithPreNeg(t *testing.T) { + testServerPoolTrustedNoDiscovery(t, true) +} +func testServerPoolTrustedNoDiscovery(t *testing.T, preNeg bool) { + s := newServerPoolTest(preNeg, false) + trusted := s.setNodes(200, 200, 200, true, true) + s.input = nil + s.start() + s.run() + s.stop() + s.checkNodes(t, trusted) +} diff --git a/les/test_helper.go b/les/test_helper.go index 1f02d2529..2a2bbb440 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -508,7 +508,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer clock = &mclock.Simulated{} } dist := newRequestDistributor(speers, clock) - rm := newRetrieveManager(speers, dist, nil) + rm := newRetrieveManager(speers, dist, func() time.Duration { return time.Millisecond * 500 }) odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm) sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig) diff --git a/les/utils/expiredvalue.go b/les/utils/expiredvalue.go index 85f9b88b7..a58587368 100644 --- a/les/utils/expiredvalue.go +++ b/les/utils/expiredvalue.go @@ -63,14 +63,7 @@ func ExpFactor(logOffset Fixed64) ExpirationFactor { // Value calculates the expired value based on a floating point base and integer // power-of-2 exponent. This function should be used by multi-value expired structures. func (e ExpirationFactor) Value(base float64, exp uint64) float64 { - res := base / e.Factor - if exp > e.Exp { - res *= float64(uint64(1) << (exp - e.Exp)) - } - if exp < e.Exp { - res /= float64(uint64(1) << (e.Exp - exp)) - } - return res + return base / e.Factor * math.Pow(2, float64(int64(exp-e.Exp))) } // value calculates the value at the given moment. diff --git a/les/utils/weighted_select.go b/les/utils/weighted_select.go index fbf1f37d6..d6db3c0e6 100644 --- a/les/utils/weighted_select.go +++ b/les/utils/weighted_select.go @@ -16,28 +16,44 @@ package utils -import "math/rand" +import ( + "math/rand" +) -// wrsItem interface should be implemented by any entries that are to be selected from -// a WeightedRandomSelect set. Note that recalculating monotonously decreasing item -// weights on-demand (without constantly calling Update) is allowed -type wrsItem interface { - Weight() int64 -} - -// WeightedRandomSelect is capable of weighted random selection from a set of items -type WeightedRandomSelect struct { - root *wrsNode - idx map[wrsItem]int -} +type ( + // WeightedRandomSelect is capable of weighted random selection from a set of items + WeightedRandomSelect struct { + root *wrsNode + idx map[WrsItem]int + wfn WeightFn + } + WrsItem interface{} + WeightFn func(interface{}) uint64 +) // NewWeightedRandomSelect returns a new WeightedRandomSelect structure -func NewWeightedRandomSelect() *WeightedRandomSelect { - return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[wrsItem]int)} +func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect { + return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn} +} + +// Update updates an item's weight, adds it if it was non-existent or removes it if +// the new weight is zero. Note that explicitly updating decreasing weights is not necessary. +func (w *WeightedRandomSelect) Update(item WrsItem) { + w.setWeight(item, w.wfn(item)) +} + +// Remove removes an item from the set +func (w *WeightedRandomSelect) Remove(item WrsItem) { + w.setWeight(item, 0) +} + +// IsEmpty returns true if the set is empty +func (w *WeightedRandomSelect) IsEmpty() bool { + return w.root.sumWeight == 0 } // setWeight sets an item's weight to a specific value (removes it if zero) -func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) { +func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) { idx, ok := w.idx[item] if ok { w.root.setWeight(idx, weight) @@ -58,33 +74,22 @@ func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) { } } -// Update updates an item's weight, adds it if it was non-existent or removes it if -// the new weight is zero. Note that explicitly updating decreasing weights is not necessary. -func (w *WeightedRandomSelect) Update(item wrsItem) { - w.setWeight(item, item.Weight()) -} - -// Remove removes an item from the set -func (w *WeightedRandomSelect) Remove(item wrsItem) { - w.setWeight(item, 0) -} - // Choose randomly selects an item from the set, with a chance proportional to its // current weight. If the weight of the chosen element has been decreased since the // last stored value, returns it with a newWeight/oldWeight chance, otherwise just // updates its weight and selects another one -func (w *WeightedRandomSelect) Choose() wrsItem { +func (w *WeightedRandomSelect) Choose() WrsItem { for { if w.root.sumWeight == 0 { return nil } - val := rand.Int63n(w.root.sumWeight) + val := uint64(rand.Int63n(int64(w.root.sumWeight))) choice, lastWeight := w.root.choose(val) - weight := choice.Weight() + weight := w.wfn(choice) if weight != lastWeight { w.setWeight(choice, weight) } - if weight >= lastWeight || rand.Int63n(lastWeight) < weight { + if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight { return choice } } @@ -92,16 +97,16 @@ func (w *WeightedRandomSelect) Choose() wrsItem { const wrsBranches = 8 // max number of branches in the wrsNode tree -// wrsNode is a node of a tree structure that can store wrsItems or further wrsNodes. +// wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes. type wrsNode struct { items [wrsBranches]interface{} - weights [wrsBranches]int64 - sumWeight int64 + weights [wrsBranches]uint64 + sumWeight uint64 level, itemCnt, maxItems int } // insert recursively inserts a new item to the tree and returns the item index -func (n *wrsNode) insert(item wrsItem, weight int64) int { +func (n *wrsNode) insert(item WrsItem, weight uint64) int { branch := 0 for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) { branch++ @@ -129,7 +134,7 @@ func (n *wrsNode) insert(item wrsItem, weight int64) int { // setWeight updates the weight of a certain item (which should exist) and returns // the change of the last weight value stored in the tree -func (n *wrsNode) setWeight(idx int, weight int64) int64 { +func (n *wrsNode) setWeight(idx int, weight uint64) uint64 { if n.level == 0 { oldWeight := n.weights[idx] n.weights[idx] = weight @@ -152,12 +157,12 @@ func (n *wrsNode) setWeight(idx int, weight int64) int64 { return diff } -// Choose recursively selects an item from the tree and returns it along with its weight -func (n *wrsNode) choose(val int64) (wrsItem, int64) { +// choose recursively selects an item from the tree and returns it along with its weight +func (n *wrsNode) choose(val uint64) (WrsItem, uint64) { for i, w := range n.weights { if val < w { if n.level == 0 { - return n.items[i].(wrsItem), n.weights[i] + return n.items[i].(WrsItem), n.weights[i] } return n.items[i].(*wrsNode).choose(val) } diff --git a/les/utils/weighted_select_test.go b/les/utils/weighted_select_test.go index e1969e1a6..3e1c0ad98 100644 --- a/les/utils/weighted_select_test.go +++ b/les/utils/weighted_select_test.go @@ -26,17 +26,18 @@ type testWrsItem struct { widx *int } -func (t *testWrsItem) Weight() int64 { +func testWeight(i interface{}) uint64 { + t := i.(*testWrsItem) w := *t.widx if w == -1 || w == t.idx { - return int64(t.idx + 1) + return uint64(t.idx + 1) } return 0 } func TestWeightedRandomSelect(t *testing.T) { testFn := func(cnt int) { - s := NewWeightedRandomSelect() + s := NewWeightedRandomSelect(testWeight) w := -1 list := make([]testWrsItem, cnt) for i := range list { diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go new file mode 100644 index 000000000..7091281ae --- /dev/null +++ b/p2p/nodestate/nodestate.go @@ -0,0 +1,880 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package nodestate + +import ( + "errors" + "reflect" + "sync" + "time" + "unsafe" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" +) + +type ( + // NodeStateMachine connects different system components operating on subsets of + // network nodes. Node states are represented by 64 bit vectors with each bit assigned + // to a state flag. Each state flag has a descriptor structure and the mapping is + // created automatically. It is possible to subscribe to subsets of state flags and + // receive a callback if one of the nodes has a relevant state flag changed. + // Callbacks can also modify further flags of the same node or other nodes. State + // updates only return after all immediate effects throughout the system have happened + // (deadlocks should be avoided by design of the implemented state logic). The caller + // can also add timeouts assigned to a certain node and a subset of state flags. + // If the timeout elapses, the flags are reset. If all relevant flags are reset then + // the timer is dropped. State flags with no timeout are persisted in the database + // if the flag descriptor enables saving. If a node has no state flags set at any + // moment then it is discarded. + // + // Extra node fields can also be registered so system components can also store more + // complex state for each node that is relevant to them, without creating a custom + // peer set. Fields can be shared across multiple components if they all know the + // field ID. Subscription to fields is also possible. Persistent fields should have + // an encoder and a decoder function. + NodeStateMachine struct { + started, stopped bool + lock sync.Mutex + clock mclock.Clock + db ethdb.KeyValueStore + dbNodeKey []byte + nodes map[enode.ID]*nodeInfo + offlineCallbackList []offlineCallback + + // Registered state flags or fields. Modifications are allowed + // only when the node state machine has not been started. + setup *Setup + fields []*fieldInfo + saveFlags bitMask + + // Installed callbacks. Modifications are allowed only when the + // node state machine has not been started. + stateSubs []stateSub + + // Testing hooks, only for testing purposes. + saveNodeHook func(*nodeInfo) + } + + // Flags represents a set of flags from a certain setup + Flags struct { + mask bitMask + setup *Setup + } + + // Field represents a field from a certain setup + Field struct { + index int + setup *Setup + } + + // flagDefinition describes a node state flag. Each registered instance is automatically + // mapped to a bit of the 64 bit node states. + // If persistent is true then the node is saved when state machine is shutdown. + flagDefinition struct { + name string + persistent bool + } + + // fieldDefinition describes an optional node field of the given type. The contents + // of the field are only retained for each node as long as at least one of the + // state flags is set. + fieldDefinition struct { + name string + ftype reflect.Type + encode func(interface{}) ([]byte, error) + decode func([]byte) (interface{}, error) + } + + // stateSetup contains the list of flags and fields used by the application + Setup struct { + Version uint + flags []flagDefinition + fields []fieldDefinition + } + + // bitMask describes a node state or state mask. It represents a subset + // of node flags with each bit assigned to a flag index (LSB represents flag 0). + bitMask uint64 + + // StateCallback is a subscription callback which is called when one of the + // state flags that is included in the subscription state mask is changed. + // Note: oldState and newState are also masked with the subscription mask so only + // the relevant bits are included. + StateCallback func(n *enode.Node, oldState, newState Flags) + + // FieldCallback is a subscription callback which is called when the value of + // a specific field is changed. + FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{}) + + // nodeInfo contains node state, fields and state timeouts + nodeInfo struct { + node *enode.Node + state bitMask + timeouts []*nodeStateTimeout + fields []interface{} + db, dirty bool + } + + nodeInfoEnc struct { + Enr enr.Record + Version uint + State bitMask + Fields [][]byte + } + + stateSub struct { + mask bitMask + callback StateCallback + } + + nodeStateTimeout struct { + mask bitMask + timer mclock.Timer + } + + fieldInfo struct { + fieldDefinition + subs []FieldCallback + } + + offlineCallback struct { + node *enode.Node + state bitMask + fields []interface{} + } +) + +// offlineState is a special state that is assumed to be set before a node is loaded from +// the database and after it is shut down. +const offlineState = bitMask(1) + +// NewFlag creates a new node state flag +func (s *Setup) NewFlag(name string) Flags { + if s.flags == nil { + s.flags = []flagDefinition{{name: "offline"}} + } + f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s} + s.flags = append(s.flags, flagDefinition{name: name}) + return f +} + +// NewPersistentFlag creates a new persistent node state flag +func (s *Setup) NewPersistentFlag(name string) Flags { + if s.flags == nil { + s.flags = []flagDefinition{{name: "offline"}} + } + f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s} + s.flags = append(s.flags, flagDefinition{name: name, persistent: true}) + return f +} + +// OfflineFlag returns the system-defined offline flag belonging to the given setup +func (s *Setup) OfflineFlag() Flags { + return Flags{mask: offlineState, setup: s} +} + +// NewField creates a new node state field +func (s *Setup) NewField(name string, ftype reflect.Type) Field { + f := Field{index: len(s.fields), setup: s} + s.fields = append(s.fields, fieldDefinition{ + name: name, + ftype: ftype, + }) + return f +} + +// NewPersistentField creates a new persistent node field +func (s *Setup) NewPersistentField(name string, ftype reflect.Type, encode func(interface{}) ([]byte, error), decode func([]byte) (interface{}, error)) Field { + f := Field{index: len(s.fields), setup: s} + s.fields = append(s.fields, fieldDefinition{ + name: name, + ftype: ftype, + encode: encode, + decode: decode, + }) + return f +} + +// flagOp implements binary flag operations and also checks whether the operands belong to the same setup +func flagOp(a, b Flags, trueIfA, trueIfB, trueIfBoth bool) Flags { + if a.setup == nil { + if a.mask != 0 { + panic("Node state flags have no setup reference") + } + a.setup = b.setup + } + if b.setup == nil { + if b.mask != 0 { + panic("Node state flags have no setup reference") + } + b.setup = a.setup + } + if a.setup != b.setup { + panic("Node state flags belong to a different setup") + } + res := Flags{setup: a.setup} + if trueIfA { + res.mask |= a.mask & ^b.mask + } + if trueIfB { + res.mask |= b.mask & ^a.mask + } + if trueIfBoth { + res.mask |= a.mask & b.mask + } + return res +} + +// And returns the set of flags present in both a and b +func (a Flags) And(b Flags) Flags { return flagOp(a, b, false, false, true) } + +// AndNot returns the set of flags present in a but not in b +func (a Flags) AndNot(b Flags) Flags { return flagOp(a, b, true, false, false) } + +// Or returns the set of flags present in either a or b +func (a Flags) Or(b Flags) Flags { return flagOp(a, b, true, true, true) } + +// Xor returns the set of flags present in either a or b but not both +func (a Flags) Xor(b Flags) Flags { return flagOp(a, b, true, true, false) } + +// HasAll returns true if b is a subset of a +func (a Flags) HasAll(b Flags) bool { return flagOp(a, b, false, true, false).mask == 0 } + +// HasNone returns true if a and b have no shared flags +func (a Flags) HasNone(b Flags) bool { return flagOp(a, b, false, false, true).mask == 0 } + +// Equals returns true if a and b have the same flags set +func (a Flags) Equals(b Flags) bool { return flagOp(a, b, true, true, false).mask == 0 } + +// IsEmpty returns true if a has no flags set +func (a Flags) IsEmpty() bool { return a.mask == 0 } + +// MergeFlags merges multiple sets of state flags +func MergeFlags(list ...Flags) Flags { + if len(list) == 0 { + return Flags{} + } + res := list[0] + for i := 1; i < len(list); i++ { + res = res.Or(list[i]) + } + return res +} + +// String returns a list of the names of the flags specified in the bit mask +func (f Flags) String() string { + if f.mask == 0 { + return "[]" + } + s := "[" + comma := false + for index, flag := range f.setup.flags { + if f.mask&(bitMask(1)< 8*int(unsafe.Sizeof(bitMask(0))) { + panic("Too many node state flags") + } + ns := &NodeStateMachine{ + db: db, + dbNodeKey: dbKey, + clock: clock, + setup: setup, + nodes: make(map[enode.ID]*nodeInfo), + fields: make([]*fieldInfo, len(setup.fields)), + } + stateNameMap := make(map[string]int) + for index, flag := range setup.flags { + if _, ok := stateNameMap[flag.name]; ok { + panic("Node state flag name collision") + } + stateNameMap[flag.name] = index + if flag.persistent { + ns.saveFlags |= bitMask(1) << uint(index) + } + } + fieldNameMap := make(map[string]int) + for index, field := range setup.fields { + if _, ok := fieldNameMap[field.name]; ok { + panic("Node field name collision") + } + ns.fields[index] = &fieldInfo{fieldDefinition: field} + fieldNameMap[field.name] = index + } + return ns +} + +// stateMask checks whether the set of flags belongs to the same setup and returns its internal bit mask +func (ns *NodeStateMachine) stateMask(flags Flags) bitMask { + if flags.setup != ns.setup && flags.mask != 0 { + panic("Node state flags belong to a different setup") + } + return flags.mask +} + +// fieldIndex checks whether the field belongs to the same setup and returns its internal index +func (ns *NodeStateMachine) fieldIndex(field Field) int { + if field.setup != ns.setup { + panic("Node field belongs to a different setup") + } + return field.index +} + +// SubscribeState adds a node state subscription. The callback is called while the state +// machine mutex is not held and it is allowed to make further state updates. All immediate +// changes throughout the system are processed in the same thread/goroutine. It is the +// responsibility of the implemented state logic to avoid deadlocks caused by the callbacks, +// infinite toggling of flags or hazardous/non-deterministic state changes. +// State subscriptions should be installed before loading the node database or making the +// first state update. +func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) { + ns.lock.Lock() + defer ns.lock.Unlock() + + if ns.started { + panic("state machine already started") + } + ns.stateSubs = append(ns.stateSubs, stateSub{ns.stateMask(flags), callback}) +} + +// SubscribeField adds a node field subscription. Same rules apply as for SubscribeState. +func (ns *NodeStateMachine) SubscribeField(field Field, callback FieldCallback) { + ns.lock.Lock() + defer ns.lock.Unlock() + + if ns.started { + panic("state machine already started") + } + f := ns.fields[ns.fieldIndex(field)] + f.subs = append(f.subs, callback) +} + +// newNode creates a new nodeInfo +func (ns *NodeStateMachine) newNode(n *enode.Node) *nodeInfo { + return &nodeInfo{node: n, fields: make([]interface{}, len(ns.fields))} +} + +// checkStarted checks whether the state machine has already been started and panics otherwise. +func (ns *NodeStateMachine) checkStarted() { + if !ns.started { + panic("state machine not started yet") + } +} + +// Start starts the state machine, enabling state and field operations and disabling +// further subscriptions. +func (ns *NodeStateMachine) Start() { + ns.lock.Lock() + if ns.started { + panic("state machine already started") + } + ns.started = true + if ns.db != nil { + ns.loadFromDb() + } + ns.lock.Unlock() + ns.offlineCallbacks(true) +} + +// Stop stops the state machine and saves its state if a database was supplied +func (ns *NodeStateMachine) Stop() { + ns.lock.Lock() + for _, node := range ns.nodes { + fields := make([]interface{}, len(node.fields)) + copy(fields, node.fields) + ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) + } + ns.stopped = true + if ns.db != nil { + ns.saveToDb() + ns.lock.Unlock() + } else { + ns.lock.Unlock() + } + ns.offlineCallbacks(false) +} + +// loadFromDb loads persisted node states from the database +func (ns *NodeStateMachine) loadFromDb() { + it := ns.db.NewIterator(ns.dbNodeKey, nil) + for it.Next() { + var id enode.ID + if len(it.Key()) != len(ns.dbNodeKey)+len(id) { + log.Error("Node state db entry with invalid length", "found", len(it.Key()), "expected", len(ns.dbNodeKey)+len(id)) + continue + } + copy(id[:], it.Key()[len(ns.dbNodeKey):]) + ns.decodeNode(id, it.Value()) + } +} + +type dummyIdentity enode.ID + +func (id dummyIdentity) Verify(r *enr.Record, sig []byte) error { return nil } +func (id dummyIdentity) NodeAddr(r *enr.Record) []byte { return id[:] } + +// decodeNode decodes a node database entry and adds it to the node set if successful +func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) { + var enc nodeInfoEnc + if err := rlp.DecodeBytes(data, &enc); err != nil { + log.Error("Failed to decode node info", "id", id, "error", err) + return + } + n, _ := enode.New(dummyIdentity(id), &enc.Enr) + node := ns.newNode(n) + node.db = true + + if enc.Version != ns.setup.Version { + log.Debug("Removing stored node with unknown version", "current", ns.setup.Version, "stored", enc.Version) + ns.deleteNode(id) + return + } + if len(enc.Fields) > len(ns.setup.fields) { + log.Error("Invalid node field count", "id", id, "stored", len(enc.Fields)) + return + } + // Resolve persisted node fields + for i, encField := range enc.Fields { + if len(encField) == 0 { + continue + } + if decode := ns.fields[i].decode; decode != nil { + if field, err := decode(encField); err == nil { + node.fields[i] = field + } else { + log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err) + return + } + } else { + log.Error("Cannot decode node field", "id", id, "field name", ns.fields[i].name) + return + } + } + // It's a compatible node record, add it to set. + ns.nodes[id] = node + node.state = enc.State + fields := make([]interface{}, len(node.fields)) + copy(fields, node.fields) + ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) + log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) +} + +// saveNode saves the given node info to the database +func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error { + if ns.db == nil { + return nil + } + + storedState := node.state & ns.saveFlags + for _, t := range node.timeouts { + storedState &= ^t.mask + } + if storedState == 0 { + if node.db { + node.db = false + ns.deleteNode(id) + } + node.dirty = false + return nil + } + + enc := nodeInfoEnc{ + Enr: *node.node.Record(), + Version: ns.setup.Version, + State: storedState, + Fields: make([][]byte, len(ns.fields)), + } + log.Debug("Saved node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) + lastIndex := -1 + for i, f := range node.fields { + if f == nil { + continue + } + encode := ns.fields[i].encode + if encode == nil { + continue + } + blob, err := encode(f) + if err != nil { + return err + } + enc.Fields[i] = blob + lastIndex = i + } + enc.Fields = enc.Fields[:lastIndex+1] + data, err := rlp.EncodeToBytes(&enc) + if err != nil { + return err + } + if err := ns.db.Put(append(ns.dbNodeKey, id[:]...), data); err != nil { + return err + } + node.dirty, node.db = false, true + + if ns.saveNodeHook != nil { + ns.saveNodeHook(node) + } + return nil +} + +// deleteNode removes a node info from the database +func (ns *NodeStateMachine) deleteNode(id enode.ID) { + ns.db.Delete(append(ns.dbNodeKey, id[:]...)) +} + +// saveToDb saves the persistent flags and fields of all nodes that have been changed +func (ns *NodeStateMachine) saveToDb() { + for id, node := range ns.nodes { + if node.dirty { + err := ns.saveNode(id, node) + if err != nil { + log.Error("Failed to save node", "id", id, "error", err) + } + } + } +} + +// updateEnode updates the enode entry belonging to the given node if it already exists +func (ns *NodeStateMachine) updateEnode(n *enode.Node) (enode.ID, *nodeInfo) { + id := n.ID() + node := ns.nodes[id] + if node != nil && n.Seq() > node.node.Seq() { + node.node = n + } + return id, node +} + +// Persist saves the persistent state and fields of the given node immediately +func (ns *NodeStateMachine) Persist(n *enode.Node) error { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if id, node := ns.updateEnode(n); node != nil && node.dirty { + err := ns.saveNode(id, node) + if err != nil { + log.Error("Failed to save node", "id", id, "error", err) + } + return err + } + return nil +} + +// SetState updates the given node state flags and processes all resulting callbacks. +// It only returns after all subsequent immediate changes (including those changed by the +// callbacks) have been processed. If a flag with a timeout is set again, the operation +// removes or replaces the existing timeout. +func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { + ns.lock.Lock() + ns.checkStarted() + if ns.stopped { + ns.lock.Unlock() + return + } + + set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags) + id, node := ns.updateEnode(n) + if node == nil { + if set == 0 { + ns.lock.Unlock() + return + } + node = ns.newNode(n) + ns.nodes[id] = node + } + oldState := node.state + newState := (node.state & (^reset)) | set + changed := oldState ^ newState + node.state = newState + + // Remove the timeout callbacks for all reset and set flags, + // even they are not existent(it's noop). + ns.removeTimeouts(node, set|reset) + + // Register the timeout callback if the new state is not empty + // and timeout itself is required. + if timeout != 0 && newState != 0 { + ns.addTimeout(n, set, timeout) + } + if newState == oldState { + ns.lock.Unlock() + return + } + if newState == 0 { + delete(ns.nodes, id) + if node.db { + ns.deleteNode(id) + } + } else { + if changed&ns.saveFlags != 0 { + node.dirty = true + } + } + ns.lock.Unlock() + // call state update subscription callbacks without holding the mutex + for _, sub := range ns.stateSubs { + if changed&sub.mask != 0 { + sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) + } + } + if newState == 0 { + // call field subscriptions for discarded fields + for i, v := range node.fields { + if v != nil { + f := ns.fields[i] + if len(f.subs) > 0 { + for _, cb := range f.subs { + cb(n, Flags{setup: ns.setup}, v, nil) + } + } + } + } + } +} + +// offlineCallbacks calls state update callbacks at startup or shutdown +func (ns *NodeStateMachine) offlineCallbacks(start bool) { + for _, cb := range ns.offlineCallbackList { + for _, sub := range ns.stateSubs { + offState := offlineState & sub.mask + onState := cb.state & sub.mask + if offState != onState { + if start { + sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) + } else { + sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) + } + } + } + for i, f := range cb.fields { + if f != nil && ns.fields[i].subs != nil { + for _, fsub := range ns.fields[i].subs { + if start { + fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) + } else { + fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) + } + } + } + } + } + ns.offlineCallbackList = nil +} + +// AddTimeout adds a node state timeout associated to the given state flag(s). +// After the specified time interval, the relevant states will be reset. +func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if ns.stopped { + return + } + ns.addTimeout(n, ns.stateMask(flags), timeout) +} + +// addTimeout adds a node state timeout associated to the given state flag(s). +func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time.Duration) { + _, node := ns.updateEnode(n) + if node == nil { + return + } + mask &= node.state + if mask == 0 { + return + } + ns.removeTimeouts(node, mask) + t := &nodeStateTimeout{mask: mask} + t.timer = ns.clock.AfterFunc(timeout, func() { + ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) + }) + node.timeouts = append(node.timeouts, t) + if mask&ns.saveFlags != 0 { + node.dirty = true + } +} + +// removeTimeout removes node state timeouts associated to the given state flag(s). +// If a timeout was associated to multiple flags which are not all included in the +// specified remove mask then only the included flags are de-associated and the timer +// stays active. +func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) { + for i := 0; i < len(node.timeouts); i++ { + t := node.timeouts[i] + match := t.mask & mask + if match == 0 { + continue + } + t.mask -= match + if t.mask != 0 { + continue + } + t.timer.Stop() + node.timeouts[i] = node.timeouts[len(node.timeouts)-1] + node.timeouts = node.timeouts[:len(node.timeouts)-1] + i-- + if match&ns.saveFlags != 0 { + node.dirty = true + } + } +} + +// GetField retrieves the given field of the given node +func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if ns.stopped { + return nil + } + if _, node := ns.updateEnode(n); node != nil { + return node.fields[ns.fieldIndex(field)] + } + return nil +} + +// SetField sets the given field of the given node +func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error { + ns.lock.Lock() + ns.checkStarted() + if ns.stopped { + ns.lock.Unlock() + return nil + } + _, node := ns.updateEnode(n) + if node == nil { + ns.lock.Unlock() + return nil + } + fieldIndex := ns.fieldIndex(field) + f := ns.fields[fieldIndex] + if value != nil && reflect.TypeOf(value) != f.ftype { + log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype) + ns.lock.Unlock() + return errors.New("invalid field type") + } + oldValue := node.fields[fieldIndex] + if value == oldValue { + ns.lock.Unlock() + return nil + } + node.fields[fieldIndex] = value + if f.encode != nil { + node.dirty = true + } + + state := node.state + ns.lock.Unlock() + if len(f.subs) > 0 { + for _, cb := range f.subs { + cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) + } + } + return nil +} + +// ForEach calls the callback for each node having all of the required and none of the +// disabled flags set +func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) { + ns.lock.Lock() + ns.checkStarted() + type callback struct { + node *enode.Node + state bitMask + } + require, disable := ns.stateMask(requireFlags), ns.stateMask(disableFlags) + var callbacks []callback + for _, node := range ns.nodes { + if node.state&require == require && node.state&disable == 0 { + callbacks = append(callbacks, callback{node.node, node.state & (require | disable)}) + } + } + ns.lock.Unlock() + for _, c := range callbacks { + cb(c.node, Flags{mask: c.state, setup: ns.setup}) + } +} + +// GetNode returns the enode currently associated with the given ID +func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if node := ns.nodes[id]; node != nil { + return node.node + } + return nil +} + +// AddLogMetrics adds logging and/or metrics for nodes entering, exiting and currently +// being in a given set specified by required and disabled state flags +func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) { + var count int64 + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) { + oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) + newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) + if newMatch == oldMatch { + return + } + + if newMatch { + count++ + if name != "" { + log.Debug("Node entered", "set", name, "id", n.ID(), "count", count) + } + if inMeter != nil { + inMeter.Mark(1) + } + } else { + count-- + if name != "" { + log.Debug("Node left", "set", name, "id", n.ID(), "count", count) + } + if outMeter != nil { + outMeter.Mark(1) + } + } + if gauge != nil { + gauge.Update(count) + } + }) +} diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go new file mode 100644 index 000000000..f6ff3ffc0 --- /dev/null +++ b/p2p/nodestate/nodestate_test.go @@ -0,0 +1,389 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package nodestate + +import ( + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" +) + +func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) { + setup := &Setup{} + flags := make([]Flags, len(flagPersist)) + for i, persist := range flagPersist { + if persist { + flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i)) + } else { + flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i)) + } + } + fields := make([]Field, len(fieldType)) + for i, ftype := range fieldType { + switch ftype { + case reflect.TypeOf(uint64(0)): + fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec) + case reflect.TypeOf(""): + fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec) + default: + fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype) + } + } + return setup, flags, fields +} + +func testNode(b byte) *enode.Node { + r := &enr.Record{} + r.SetSig(dummyIdentity{b}, []byte{42}) + n, _ := enode.New(dummyIdentity{b}, r) + return n +} + +func TestCallback(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{false, false, false}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + set0 := make(chan struct{}, 1) + set1 := make(chan struct{}, 1) + set2 := make(chan struct{}, 1) + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} }) + ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} }) + ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} }) + + ns.Start() + + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetState(testNode(1), flags[1], Flags{}, time.Second) + ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second) + + for i := 0; i < 3; i++ { + select { + case <-set0: + case <-set1: + case <-set2: + case <-time.After(time.Second): + t.Fatalf("failed to invoke callback") + } + } +} + +func TestPersistentFlags(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{true, true, true, false}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + saveNode := make(chan *nodeInfo, 5) + ns.saveNodeHook = func(node *nodeInfo) { + saveNode <- node + } + + ns.Start() + + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved + ns.SetState(testNode(2), flags[1], Flags{}, 0) + ns.SetState(testNode(3), flags[2], Flags{}, 0) + ns.SetState(testNode(4), flags[3], Flags{}, 0) + ns.SetState(testNode(5), flags[0], Flags{}, 0) + ns.Persist(testNode(5)) + select { + case <-saveNode: + case <-time.After(time.Second): + t.Fatalf("Timeout") + } + ns.Stop() + + for i := 0; i < 2; i++ { + select { + case <-saveNode: + case <-time.After(time.Second): + t.Fatalf("Timeout") + } + } + select { + case <-saveNode: + t.Fatalf("Unexpected saveNode") + case <-time.After(time.Millisecond * 100): + } +} + +func TestSetField(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + saveNode := make(chan *nodeInfo, 1) + ns.saveNodeHook = func(node *nodeInfo) { + saveNode <- node + } + + ns.Start() + + // Set field before setting state + ns.SetField(testNode(1), fields[0], "hello world") + field := ns.GetField(testNode(1), fields[0]) + if field != nil { + t.Fatalf("Field shouldn't be set before setting states") + } + // Set field after setting state + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], "hello world") + field = ns.GetField(testNode(1), fields[0]) + if field == nil { + t.Fatalf("Field should be set after setting states") + } + if err := ns.SetField(testNode(1), fields[0], 123); err == nil { + t.Fatalf("Invalid field should be rejected") + } + // Dirty node should be written back + ns.Stop() + select { + case <-saveNode: + case <-time.After(time.Second): + t.Fatalf("Timeout") + } +} + +func TestUnsetField(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + ns.Start() + + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) + ns.SetField(testNode(1), fields[0], "hello world") + + ns.SetState(testNode(1), Flags{}, flags[0], 0) + if field := ns.GetField(testNode(1), fields[0]); field != nil { + t.Fatalf("Field should be unset") + } +} + +func TestSetState(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{false, false, false}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + type change struct{ old, new Flags } + set := make(chan change, 1) + ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) { + set <- change{ + old: oldState, + new: newState, + } + }) + + ns.Start() + + check := func(expectOld, expectNew Flags, expectChange bool) { + if expectChange { + select { + case c := <-set: + if !c.old.Equals(expectOld) { + t.Fatalf("Old state mismatch") + } + if !c.new.Equals(expectNew) { + t.Fatalf("New state mismatch") + } + case <-time.After(time.Second): + } + return + } + select { + case <-set: + t.Fatalf("Unexpected change") + case <-time.After(time.Millisecond * 100): + return + } + } + ns.SetState(testNode(1), flags[0], Flags{}, 0) + check(Flags{}, flags[0], true) + + ns.SetState(testNode(1), flags[1], Flags{}, 0) + check(flags[0], flags[0].Or(flags[1]), true) + + ns.SetState(testNode(1), flags[2], Flags{}, 0) + check(Flags{}, Flags{}, false) + + ns.SetState(testNode(1), Flags{}, flags[0], 0) + check(flags[0].Or(flags[1]), flags[1], true) + + ns.SetState(testNode(1), Flags{}, flags[1], 0) + check(flags[1], Flags{}, true) + + ns.SetState(testNode(1), Flags{}, flags[2], 0) + check(Flags{}, Flags{}, false) + + ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second) + check(Flags{}, flags[0].Or(flags[1]), true) + clock.Run(time.Second) + check(flags[0].Or(flags[1]), Flags{}, true) +} + +func uint64FieldEnc(field interface{}) ([]byte, error) { + if u, ok := field.(uint64); ok { + enc, err := rlp.EncodeToBytes(&u) + return enc, err + } else { + return nil, errors.New("invalid field type") + } +} + +func uint64FieldDec(enc []byte) (interface{}, error) { + var u uint64 + err := rlp.DecodeBytes(enc, &u) + return u, err +} + +func stringFieldEnc(field interface{}) ([]byte, error) { + if s, ok := field.(string); ok { + return []byte(s), nil + } else { + return nil, errors.New("invalid field type") + } +} + +func stringFieldDec(enc []byte) (interface{}, error) { + return string(enc), nil +} + +func TestPersistentFields(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + ns.Start() + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], uint64(100)) + ns.SetField(testNode(1), fields[1], "hello world") + ns.Stop() + + ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + ns2.Start() + field0 := ns2.GetField(testNode(1), fields[0]) + if !reflect.DeepEqual(field0, uint64(100)) { + t.Fatalf("Field changed") + } + field1 := ns2.GetField(testNode(1), fields[1]) + if !reflect.DeepEqual(field1, "hello world") { + t.Fatalf("Field changed") + } + + s.Version++ + ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + ns3.Start() + if ns3.GetField(testNode(1), fields[0]) != nil { + t.Fatalf("Old field version should have been discarded") + } +} + +func TestFieldSub(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + var ( + lastState Flags + lastOldValue, lastNewValue interface{} + ) + ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { + lastState, lastOldValue, lastNewValue = state, oldValue, newValue + }) + check := func(state Flags, oldValue, newValue interface{}) { + if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue { + t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue) + } + } + ns.Start() + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], uint64(100)) + check(flags[0], nil, uint64(100)) + ns.Stop() + check(s.OfflineFlag(), uint64(100), nil) + + ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { + lastState, lastOldValue, lastNewValue = state, oldValue, newValue + }) + ns2.Start() + check(s.OfflineFlag(), nil, uint64(100)) + ns2.SetState(testNode(1), Flags{}, flags[0], 0) + check(Flags{}, uint64(100), nil) + ns2.Stop() +} + +func TestDuplicatedFlags(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{true}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + type change struct{ old, new Flags } + set := make(chan change, 1) + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { + set <- change{oldState, newState} + }) + + ns.Start() + defer ns.Stop() + + check := func(expectOld, expectNew Flags, expectChange bool) { + if expectChange { + select { + case c := <-set: + if !c.old.Equals(expectOld) { + t.Fatalf("Old state mismatch") + } + if !c.new.Equals(expectNew) { + t.Fatalf("New state mismatch") + } + case <-time.After(time.Second): + } + return + } + select { + case <-set: + t.Fatalf("Unexpected change") + case <-time.After(time.Millisecond * 100): + return + } + } + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) + check(Flags{}, flags[0], true) + ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s + check(Flags{}, flags[0], false) + + clock.Run(2 * time.Second) + check(flags[0], Flags{}, true) +} diff --git a/params/bootnodes.go b/params/bootnodes.go index 0d72321b0..e1898d762 100644 --- a/params/bootnodes.go +++ b/params/bootnodes.go @@ -65,11 +65,22 @@ var GoerliBootnodes = []string{ const dnsPrefix = "enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@" -// These DNS names provide bootstrap connectivity for public testnets and the mainnet. -// See https://github.com/ethereum/discv4-dns-lists for more information. -var KnownDNSNetworks = map[common.Hash]string{ - MainnetGenesisHash: dnsPrefix + "all.mainnet.ethdisco.net", - RopstenGenesisHash: dnsPrefix + "all.ropsten.ethdisco.net", - RinkebyGenesisHash: dnsPrefix + "all.rinkeby.ethdisco.net", - GoerliGenesisHash: dnsPrefix + "all.goerli.ethdisco.net", +// KnownDNSNetwork returns the address of a public DNS-based node list for the given +// genesis hash and protocol. See https://github.com/ethereum/discv4-dns-lists for more +// information. +func KnownDNSNetwork(genesis common.Hash, protocol string) string { + var net string + switch genesis { + case MainnetGenesisHash: + net = "mainnet" + case RopstenGenesisHash: + net = "ropsten" + case RinkebyGenesisHash: + net = "rinkeby" + case GoerliGenesisHash: + net = "goerli" + default: + return "" + } + return dnsPrefix + protocol + "." + net + ".ethdisco.net" }