diff --git a/les/client_handler.go b/les/client_handler.go
index cfeec7a03..77a0ea5c6 100644
--- a/les/client_handler.go
+++ b/les/client_handler.go
@@ -102,13 +102,7 @@ func (h *clientHandler) handle(p *serverPeer) error {
 	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
 
 	// Execute the LES handshake
-	var (
-		head   = h.backend.blockchain.CurrentHeader()
-		hash   = head.Hash()
-		number = head.Number.Uint64()
-		td     = h.backend.blockchain.GetTd(hash, number)
-	)
-	if err := p.Handshake(td, hash, number, h.backend.blockchain.Genesis().Hash(), nil); err != nil {
+	if err := p.Handshake(h.backend.blockchain.Genesis().Hash()); err != nil {
 		p.Log().Debug("Light Ethereum handshake failed", "err", err)
 		return err
 	}
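The client-side handshake no longer gathers the local chain head; only the genesis hash is needed. A rough in-package sketch of the resulting flow (les internals assumed, not runnable standalone) — the server's head is now learned from the handshake response instead of being shared by the client:

```go
// Sketch only: p is a *serverPeer inside package les.
if err := p.Handshake(h.backend.blockchain.Genesis().Hash()); err != nil {
	return err
}
// After a successful handshake, p.headInfo holds the head the server
// announced via the "headTd"/"headHash"/"headNum" handshake fields
// (see the les/peer.go changes below).
p.Log().Debug("Server head", "number", p.headInfo.Number, "hash", p.headInfo.Hash)
```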
diff --git a/les/clientpool.go b/les/clientpool.go
index 4f6e3fafe..da0db6e62 100644
--- a/les/clientpool.go
+++ b/les/clientpool.go
@@ -18,7 +18,6 @@ package les
 
 import (
 	"fmt"
-	"reflect"
 	"sync"
 	"time"
 
@@ -46,19 +45,6 @@ const (
 	inactiveTimeout = time.Second * 10
 )
 
-var (
-	clientPoolSetup     = &nodestate.Setup{}
-	clientField         = clientPoolSetup.NewField("clientInfo", reflect.TypeOf(&clientInfo{}))
-	connAddressField    = clientPoolSetup.NewField("connAddr", reflect.TypeOf(""))
-	balanceTrackerSetup = lps.NewBalanceTrackerSetup(clientPoolSetup)
-	priorityPoolSetup   = lps.NewPriorityPoolSetup(clientPoolSetup)
-)
-
-func init() {
-	balanceTrackerSetup.Connect(connAddressField, priorityPoolSetup.CapacityField)
-	priorityPoolSetup.Connect(balanceTrackerSetup.BalanceField, balanceTrackerSetup.UpdateFlag) // NodeBalance implements nodePriority
-}
-
 // clientPool implements a client database that assigns a priority to each client
 // based on a positive and negative balance. Positive balance is externally assigned
 // to prioritized clients and is decreased with connection time and processed
@@ -119,8 +105,7 @@ type clientInfo struct {
 }
 
 // newClientPool creates a new client pool
-func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool {
-	ns := nodestate.NewNodeStateMachine(nil, nil, clock, clientPoolSetup)
+func newClientPool(ns *nodestate.NodeStateMachine, lespayDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool {
 	pool := &clientPool{
 		ns:                  ns,
 		BalanceTrackerSetup: balanceTrackerSetup,
@@ -147,7 +132,7 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du
 	})
 
 	ns.SubscribeState(pool.ActiveFlag.Or(pool.PriorityFlag), func(node *enode.Node, oldState, newState nodestate.Flags) {
-		c, _ := ns.GetField(node, clientField).(*clientInfo)
+		c, _ := ns.GetField(node, clientInfoField).(*clientInfo)
 		if c == nil {
 			return
 		}
@@ -172,7 +157,7 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du
 		if oldState.Equals(pool.ActiveFlag) && newState.Equals(pool.InactiveFlag) {
 			clientDeactivatedMeter.Mark(1)
 			log.Debug("Client deactivated", "id", node.ID())
-			c, _ := ns.GetField(node, clientField).(*clientInfo)
+			c, _ := ns.GetField(node, clientInfoField).(*clientInfo)
 			if c == nil || !c.peer.allowInactive() {
 				pool.removePeer(node.ID())
 			}
@@ -190,13 +175,11 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du
 		newCap, _ := newValue.(uint64)
 		totalConnected += newCap - oldCap
 		totalConnectedGauge.Update(int64(totalConnected))
-		c, _ := ns.GetField(node, clientField).(*clientInfo)
+		c, _ := ns.GetField(node, clientInfoField).(*clientInfo)
 		if c != nil {
 			c.peer.updateCapacity(newCap)
 		}
 	})
-
-	ns.Start()
 	return pool
 }
 
@@ -210,7 +193,6 @@ func (f *clientPool) stop() {
 		f.disconnectNode(node)
 	})
 	f.bt.Stop()
-	f.ns.Stop()
 }
 
 // connect should be called after a successful handshake. If the connection was
@@ -225,7 +207,7 @@ func (f *clientPool) connect(peer clientPoolPeer) (uint64, error) {
 	}
 	// Dedup connected peers.
 	node, freeID := peer.Node(), peer.freeClientId()
-	if f.ns.GetField(node, clientField) != nil {
+	if f.ns.GetField(node, clientInfoField) != nil {
 		log.Debug("Client already connected", "address", freeID, "id", node.ID().String())
 		return 0, fmt.Errorf("Client already connected address=%s id=%s", freeID, node.ID().String())
 	}
@@ -237,7 +219,7 @@ func (f *clientPool) connect(peer clientPoolPeer) (uint64, error) {
 		connected:   true,
 		connectedAt: now,
 	}
-	f.ns.SetField(node, clientField, c)
+	f.ns.SetField(node, clientInfoField, c)
 	f.ns.SetField(node, connAddressField, freeID)
 	if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance == nil {
 		f.disconnect(peer)
@@ -280,7 +262,7 @@ func (f *clientPool) disconnect(p clientPoolPeer) {
 // disconnectNode removes node fields and flags related to connected status
 func (f *clientPool) disconnectNode(node *enode.Node) {
 	f.ns.SetField(node, connAddressField, nil)
-	f.ns.SetField(node, clientField, nil)
+	f.ns.SetField(node, clientInfoField, nil)
 }
 
 // setDefaultFactors sets the default price factors applied to subsequently connected clients
@@ -299,7 +281,8 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) {
 	defer f.lock.Unlock()
 
 	// total priority active cap will be supported when the token issuer module is added
-	return f.capLimit, f.pp.ActiveCapacity(), 0
+	_, activeCap := f.pp.Active()
+	return f.capLimit, activeCap, 0
 }
 
 // setLimits sets the maximum number and total capacity of connected clients,
@@ -314,13 +297,13 @@ func (f *clientPool) setLimits(totalConn int, totalCap uint64) {
 
 // setCapacity sets the assigned capacity of a connected client
 func (f *clientPool) setCapacity(node *enode.Node, freeID string, capacity uint64, bias time.Duration, setCap bool) (uint64, error) {
-	c, _ := f.ns.GetField(node, clientField).(*clientInfo)
+	c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo)
 	if c == nil {
 		if setCap {
 			return 0, fmt.Errorf("client %064x is not connected", node.ID())
 		}
 		c = &clientInfo{node: node}
-		f.ns.SetField(node, clientField, c)
+		f.ns.SetField(node, clientInfoField, c)
 		f.ns.SetField(node, connAddressField, freeID)
 		if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance == nil {
 			log.Error("BalanceField is missing", "node", node.ID())
@@ -328,7 +311,7 @@ func (f *clientPool) setCapacity(node *enode.Node, freeID string, capacity uint6
 		}
 		defer func() {
 			f.ns.SetField(node, connAddressField, nil)
-			f.ns.SetField(node, clientField, nil)
+			f.ns.SetField(node, clientInfoField, nil)
 		}()
 	}
 	var (
@@ -370,7 +353,7 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) {
 
 	if len(ids) == 0 {
 		f.ns.ForEach(nodestate.Flags{}, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
-			c, _ := f.ns.GetField(node, clientField).(*clientInfo)
+			c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo)
 			if c != nil {
 				cb(c)
 			}
@@ -381,12 +364,12 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) {
 			if node == nil {
 				node = enode.SignNull(&enr.Record{}, id)
 			}
-			c, _ := f.ns.GetField(node, clientField).(*clientInfo)
+			c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo)
 			if c != nil {
 				cb(c)
 			} else {
 				c = &clientInfo{node: node}
-				f.ns.SetField(node, clientField, c)
+				f.ns.SetField(node, clientInfoField, c)
 				f.ns.SetField(node, connAddressField, "")
 				if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance != nil {
 					cb(c)
@@ -394,7 +377,7 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) {
 					log.Error("BalanceField is missing")
 				}
 				f.ns.SetField(node, connAddressField, nil)
-				f.ns.SetField(node, clientField, nil)
+				f.ns.SetField(node, clientInfoField, nil)
 			}
 		}
 	}
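The pool now borrows a `NodeStateMachine` instead of owning one, so start/stop ordering moves to the caller. A rough in-package sketch of the new lifecycle (internal les identifiers, not runnable standalone):

```go
// The caller owns the state machine: it is shared with the priority pool,
// balance tracker and broadcaster, so the client pool must not start or
// stop it on its own.
ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup)
pool := newClientPool(ns, db, minCapacity, defaultConnectedBias, mclock.System{}, dropClient)
ns.Start()
// ... serve clients ...
pool.stop() // unregisters clients and stops the balance tracker only
ns.Stop()   // shuts down the shared state machine last
```

This is why `les/server.go` below calls `clientPool.stop()` before `ns.Stop()` during shutdown.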
diff --git a/les/clientpool_test.go b/les/clientpool_test.go
index cfd1486b4..b1c38d374 100644
--- a/les/clientpool_test.go
+++ b/les/clientpool_test.go
@@ -64,6 +64,11 @@ type poolTestPeer struct {
 	inactiveAllowed bool
 }
 
+func testStateMachine() *nodestate.NodeStateMachine {
+	return nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup)
+
+}
+
 func newPoolTestPeer(i int, disconnCh chan int) *poolTestPeer {
 	return &poolTestPeer{
 		index:     i,
@@ -91,7 +96,7 @@ func (i *poolTestPeer) allowInactive() bool {
 }
 
 func getBalance(pool *clientPool, p *poolTestPeer) (pos, neg uint64) {
-	temp := pool.ns.GetField(p.node, clientField) == nil
+	temp := pool.ns.GetField(p.node, clientInfoField) == nil
 	if temp {
 		pool.ns.SetField(p.node, connAddressField, p.freeClientId())
 	}
@@ -128,8 +133,9 @@ func testClientPool(t *testing.T, activeLimit, clientCount, paidCount int, rando
 		disconnFn = func(id enode.ID) {
 			disconnCh <- int(id[0]) + int(id[1])<<8
 		}
-		pool = newClientPool(db, 1, 0, &clock, disconnFn)
+		pool = newClientPool(testStateMachine(), db, 1, 0, &clock, disconnFn)
 	)
+	pool.ns.Start()
 	pool.setLimits(activeLimit, uint64(activeLimit))
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
 
@@ -233,7 +239,8 @@ func TestConnectPaidClient(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10))
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -248,7 +255,8 @@ func TestConnectPaidClientToSmallPool(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -266,7 +274,8 @@ func TestConnectPaidClientToFullPool(t *testing.T) {
 		db    = rawdb.NewMemoryDatabase()
 	)
 	removeFn := func(enode.ID) {} // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -295,7 +304,8 @@ func TestPaidClientKickedOut(t *testing.T) {
 	removeFn := func(id enode.ID) {
 		kickedCh <- int(id[0])
 	}
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	pool.bt.SetExpirationTCs(0, 0)
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
@@ -325,7 +335,8 @@ func TestConnectFreeClient(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10))
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -341,7 +352,8 @@ func TestConnectFreeClientToFullPool(t *testing.T) {
 		db    = rawdb.NewMemoryDatabase()
 	)
 	removeFn := func(enode.ID) {} // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -370,7 +382,8 @@ func TestFreeClientKickedOut(t *testing.T) {
 		kicked = make(chan int, 100)
 	)
 	removeFn := func(id enode.ID) { kicked <- int(id[0]) }
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -411,7 +424,8 @@ func TestPositiveBalanceCalculation(t *testing.T) {
 		kicked = make(chan int, 10)
 	)
 	removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -434,7 +448,8 @@ func TestDowngradePriorityClient(t *testing.T) {
 		kicked = make(chan int, 10)
 	)
 	removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -468,7 +483,8 @@ func TestNegativeBalanceCalculation(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1e-3, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1e-3, CapacityFactor: 0, RequestFactor: 1})
@@ -503,7 +519,8 @@ func TestInactiveClient(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(2, uint64(2))
diff --git a/les/enr_entry.go b/les/enr_entry.go
index 65d0d1fdb..11e6273be 100644
--- a/les/enr_entry.go
+++ b/les/enr_entry.go
@@ -36,7 +36,7 @@ func (e lesEntry) ENRKey() string {
 
 // setupDiscovery creates the node discovery source for the eth protocol.
 func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) {
-	if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 {
+	if cfg.NoDiscovery || len(eth.config.DiscoveryURLs) == 0 {
 		return nil, nil
 	}
 	client := dnsdisc.NewClient(dnsdisc.Config{})
diff --git a/les/lespay/server/prioritypool.go b/les/lespay/server/prioritypool.go
index 52224e093..c0c33840c 100644
--- a/les/lespay/server/prioritypool.go
+++ b/les/lespay/server/prioritypool.go
@@ -253,12 +253,12 @@ func (pp *PriorityPool) SetActiveBias(bias time.Duration) {
 	pp.tryActivate()
 }
 
-// ActiveCapacity returns the total capacity of currently active nodes
-func (pp *PriorityPool) ActiveCapacity() uint64 {
+// Active returns the number and total capacity of currently active nodes
+func (pp *PriorityPool) Active() (uint64, uint64) {
 	pp.lock.Lock()
 	defer pp.lock.Unlock()
 
-	return pp.activeCap
+	return pp.activeCount, pp.activeCap
 }
 
 // inactiveSetIndex callback updates ppNodeInfo item index in inactiveQueue
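`Active()` now reports both the size and the capacity of the active set, because `les/server_handler.go` below drives `clientConnectionGauge` from the count while `capacityInfo` still needs the capacity. A minimal hedged sketch, assuming an already-configured `*server.PriorityPool`:

```go
// pp is assumed to be an existing, configured *server.PriorityPool.
activeCount, activeCap := pp.Active()
clientConnectionGauge.Update(int64(activeCount)) // connection metrics want the count
_ = activeCap                                    // capacity queries want the total
```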
diff --git a/les/peer.go b/les/peer.go
index 0549daf9a..2b0117bed 100644
--- a/les/peer.go
+++ b/les/peer.go
@@ -126,7 +126,7 @@ type peerCommons struct {
 	frozen       uint32 // Flag whether the peer is frozen.
 	announceType uint64 // New block announcement type.
 	serving      uint32 // The status indicates the peer is served.
-	headInfo     blockInfo // Latest block information.
+	headInfo     blockInfo // Last announced block information.
 
 	// Background task queue for caching peer tasks and executing in order.
 	sendQueue *utils.ExecQueue
@@ -255,6 +255,8 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g
 	// Add some basic handshake fields
 	send = send.add("protocolVersion", uint64(p.version))
 	send = send.add("networkId", p.network)
+	// Note: the head info announced during the handshake is only used by server peers,
+	// but clients still announce dummy values for compatibility with older servers.
 	send = send.add("headTd", td)
 	send = send.add("headHash", head)
 	send = send.add("headNum", headNum)
@@ -273,24 +275,14 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g
 	if size > allowedUpdateBytes {
 		return errResp(ErrRequestRejected, "")
 	}
-	var rGenesis, rHash common.Hash
-	var rVersion, rNetwork, rNum uint64
-	var rTd *big.Int
+	var rGenesis common.Hash
+	var rVersion, rNetwork uint64
 	if err := recv.get("protocolVersion", &rVersion); err != nil {
 		return err
 	}
 	if err := recv.get("networkId", &rNetwork); err != nil {
 		return err
 	}
-	if err := recv.get("headTd", &rTd); err != nil {
-		return err
-	}
-	if err := recv.get("headHash", &rHash); err != nil {
-		return err
-	}
-	if err := recv.get("headNum", &rNum); err != nil {
-		return err
-	}
 	if err := recv.get("genesisHash", &rGenesis); err != nil {
 		return err
 	}
@@ -303,7 +295,6 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g
 	if int(rVersion) != p.version {
 		return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", rVersion, p.version)
 	}
-	p.headInfo = blockInfo{Hash: rHash, Number: rNum, Td: rTd}
 	if recvCallback != nil {
 		return recvCallback(recv)
 	}
@@ -569,9 +560,11 @@ func (p *serverPeer) updateHead(hash common.Hash, number uint64, td *big.Int) {
 }
 
 // Handshake executes the les protocol handshake, negotiating version number,
-// network IDs, difficulties, head and genesis blocks.
-func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error {
-	return p.handshake(td, head, headNum, genesis, func(lists *keyValueList) {
+// network IDs and genesis block.
+func (p *serverPeer) Handshake(genesis common.Hash) error {
+	// Note: there is no need to share the local head with a server, but older servers
+	// still require these fields, so we announce zero values.
+	return p.handshake(common.Big0, common.Hash{}, 0, genesis, func(lists *keyValueList) {
 		// Add some client-specific handshake fields
 		//
 		// Enable signed announcement randomly even the server is not trusted.
@@ -581,6 +574,21 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
 		}
 		*lists = (*lists).add("announceType", p.announceType)
 	}, func(recv keyValueMap) error {
+		var (
+			rHash common.Hash
+			rNum  uint64
+			rTd   *big.Int
+		)
+		if err := recv.get("headTd", &rTd); err != nil {
+			return err
+		}
+		if err := recv.get("headHash", &rHash); err != nil {
+			return err
+		}
+		if err := recv.get("headNum", &rNum); err != nil {
+			return err
+		}
+		p.headInfo = blockInfo{Hash: rHash, Number: rNum, Td: rTd}
 		if recv.get("serveChainSince", &p.chainSince) != nil {
 			p.onlyAnnounce = true
 		}
@@ -937,6 +945,9 @@ func (p *clientPeer) freezeClient() {
 // Handshake executes the les protocol handshake, negotiating version number,
 // network IDs, difficulties, head and genesis blocks.
 func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error {
+	// Note: clientPeer.headInfo should contain the last head announced to the client by us.
+	// The head info announced by the client in the handshake is a dummy value and should be ignored.
+	p.headInfo = blockInfo{Hash: head, Number: headNum, Td: td}
 	return p.handshake(td, head, headNum, genesis, func(lists *keyValueList) {
 		// Add some information which services server can offer.
 		if !server.config.UltraLightOnlyAnnounce {
@@ -1009,145 +1020,6 @@ type serverPeerSubscriber interface {
 	unregisterPeer(*serverPeer)
 }
 
-// clientPeerSubscriber is an interface to notify services about added or
-// removed client peers
-type clientPeerSubscriber interface {
-	registerPeer(*clientPeer)
-	unregisterPeer(*clientPeer)
-}
-
-// clientPeerSet represents the set of active client peers currently
-// participating in the Light Ethereum sub-protocol.
-type clientPeerSet struct {
-	peers map[string]*clientPeer
-	// subscribers is a batch of subscribers and peerset will notify
-	// these subscribers when the peerset changes(new client peer is
-	// added or removed)
-	subscribers []clientPeerSubscriber
-	closed      bool
-	lock        sync.RWMutex
-}
-
-// newClientPeerSet creates a new peer set to track the client peers.
-func newClientPeerSet() *clientPeerSet {
-	return &clientPeerSet{peers: make(map[string]*clientPeer)}
-}
-
-// subscribe adds a service to be notified about added or removed
-// peers and also register all active peers into the given service.
-func (ps *clientPeerSet) subscribe(sub clientPeerSubscriber) {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	ps.subscribers = append(ps.subscribers, sub)
-	for _, p := range ps.peers {
-		sub.registerPeer(p)
-	}
-}
-
-// unSubscribe removes the specified service from the subscriber pool.
-func (ps *clientPeerSet) unSubscribe(sub clientPeerSubscriber) {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	for i, s := range ps.subscribers {
-		if s == sub {
-			ps.subscribers = append(ps.subscribers[:i], ps.subscribers[i+1:]...)
-			return
-		}
-	}
-}
-
-// register adds a new peer into the peer set, or returns an error if the
-// peer is already known.
-func (ps *clientPeerSet) register(peer *clientPeer) error {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	if ps.closed {
-		return errClosed
-	}
-	if _, exist := ps.peers[peer.id]; exist {
-		return errAlreadyRegistered
-	}
-	ps.peers[peer.id] = peer
-	for _, sub := range ps.subscribers {
-		sub.registerPeer(peer)
-	}
-	return nil
-}
-
-// unregister removes a remote peer from the peer set, disabling any further
-// actions to/from that particular entity. It also initiates disconnection
-// at the networking layer.
-func (ps *clientPeerSet) unregister(id string) error {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	p, ok := ps.peers[id]
-	if !ok {
-		return errNotRegistered
-	}
-	delete(ps.peers, id)
-	for _, sub := range ps.subscribers {
-		sub.unregisterPeer(p)
-	}
-	p.Peer.Disconnect(p2p.DiscRequested)
-	return nil
-}
-
-// ids returns a list of all registered peer IDs
-func (ps *clientPeerSet) ids() []string {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	var ids []string
-	for id := range ps.peers {
-		ids = append(ids, id)
-	}
-	return ids
-}
-
-// peer retrieves the registered peer with the given id.
-func (ps *clientPeerSet) peer(id string) *clientPeer {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	return ps.peers[id]
-}
-
-// len returns if the current number of peers in the set.
-func (ps *clientPeerSet) len() int {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	return len(ps.peers)
-}
-
-// allClientPeers returns all client peers in a list.
-func (ps *clientPeerSet) allPeers() []*clientPeer {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	list := make([]*clientPeer, 0, len(ps.peers))
-	for _, p := range ps.peers {
-		list = append(list, p)
-	}
-	return list
-}
-
-// close disconnects all peers. No new peers can be registered
-// after close has returned.
-func (ps *clientPeerSet) close() {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	for _, p := range ps.peers {
-		p.Disconnect(p2p.DiscQuitting)
-	}
-	ps.closed = true
-}
-
 // serverPeerSet represents the set of active server peers currently
 // participating in the Light Ethereum sub-protocol.
 type serverPeerSet struct {
@@ -1298,42 +1170,3 @@ func (ps *serverPeerSet) close() {
 	}
 	ps.closed = true
 }
-
-// serverSet is a special set which contains all connected les servers.
-// Les servers will also be discovered by discovery protocol because they
-// also run the LES protocol. We can't drop them although they are useless
-// for us(server) but for other protocols(e.g. ETH) upon the devp2p they
-// may be useful.
-type serverSet struct {
-	lock   sync.Mutex
-	set    map[string]*clientPeer
-	closed bool
-}
-
-func newServerSet() *serverSet {
-	return &serverSet{set: make(map[string]*clientPeer)}
-}
-
-func (s *serverSet) register(peer *clientPeer) error {
-	s.lock.Lock()
-	defer s.lock.Unlock()
-
-	if s.closed {
-		return errClosed
-	}
-	if _, exist := s.set[peer.id]; exist {
-		return errAlreadyRegistered
-	}
-	s.set[peer.id] = peer
-	return nil
-}
-
-func (s *serverSet) close() {
-	s.lock.Lock()
-	defer s.lock.Unlock()
-
-	for _, p := range s.set {
-		p.Disconnect(p2p.DiscQuitting)
-	}
-	s.closed = true
-}
diff --git a/les/protocol.go b/les/protocol.go
index 4fd19f9be..19a9561ce 100644
--- a/les/protocol.go
+++ b/les/protocol.go
@@ -174,12 +174,6 @@ var errorToString = map[int]string{
 	ErrMissingKey: "Key missing from list",
 }
 
-type announceBlock struct {
-	Hash   common.Hash // Hash of one particular block being announced
-	Number uint64      // Number of one particular block being announced
-	Td     *big.Int    // Total difficulty of one particular block being announced
-}
-
 // announceData is the network packet for the block announcements.
 type announceData struct {
 	Hash   common.Hash // Hash of one particular block being announced
@@ -199,7 +193,7 @@ func (a *announceData) sanityCheck() error {
 
 // sign adds a signature to the block announcement by the given privKey
 func (a *announceData) sign(privKey *ecdsa.PrivateKey) {
-	rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td})
+	rlp, _ := rlp.EncodeToBytes(blockInfo{a.Hash, a.Number, a.Td})
 	sig, _ := crypto.Sign(crypto.Keccak256(rlp), privKey)
 	a.Update = a.Update.add("sign", sig)
 }
@@ -210,7 +204,7 @@ func (a *announceData) checkSignature(id enode.ID, update keyValueMap) error {
 	if err := update.get("sign", &sig); err != nil {
 		return err
 	}
-	rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td})
+	rlp, _ := rlp.EncodeToBytes(blockInfo{a.Hash, a.Number, a.Td})
 	recPubkey, err := crypto.SigToPub(crypto.Keccak256(rlp), sig)
 	if err != nil {
 		return err
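`sign` and `checkSignature` now reuse `blockInfo` for the RLP payload, since `announceBlock` was a field-for-field duplicate. A self-contained sketch of the sign/recover round trip these methods perform; the `blockInfo` struct below is a local stand-in for the unexported les type, and identity comparison via addresses is a simplification of the `enode.ID` check in `checkSignature`:

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/rlp"
)

// blockInfo mirrors the unexported les struct: (hash, number, td).
type blockInfo struct {
	Hash   common.Hash
	Number uint64
	Td     *big.Int
}

func main() {
	key, _ := crypto.GenerateKey()

	// Sign the keccak256 hash of the RLP-encoded head info, as announceData.sign does.
	enc, _ := rlp.EncodeToBytes(blockInfo{common.Hash{0x01}, 100, big.NewInt(42)})
	sig, _ := crypto.Sign(crypto.Keccak256(enc), key)

	// Recover the signer, as announceData.checkSignature does, and compare identities.
	pub, _ := crypto.SigToPub(crypto.Keccak256(enc), sig)
	fmt.Println("signer matches:", crypto.PubkeyToAddress(*pub) == crypto.PubkeyToAddress(key.PublicKey))
}
```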
diff --git a/les/server.go b/les/server.go
index 225a7ad1f..cbedce136 100644
--- a/les/server.go
+++ b/les/server.go
@@ -18,6 +18,7 @@ package les
 
 import (
 	"crypto/ecdsa"
+	"reflect"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common/mclock"
@@ -31,17 +32,32 @@ import (
 	"github.com/ethereum/go-ethereum/p2p/discv5"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/p2p/nodestate"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/ethereum/go-ethereum/rpc"
 )
 
+var (
+	serverSetup         = &nodestate.Setup{}
+	clientPeerField     = serverSetup.NewField("clientPeer", reflect.TypeOf(&clientPeer{}))
+	clientInfoField     = serverSetup.NewField("clientInfo", reflect.TypeOf(&clientInfo{}))
+	connAddressField    = serverSetup.NewField("connAddr", reflect.TypeOf(""))
+	balanceTrackerSetup = lps.NewBalanceTrackerSetup(serverSetup)
+	priorityPoolSetup   = lps.NewPriorityPoolSetup(serverSetup)
+)
+
+func init() {
+	balanceTrackerSetup.Connect(connAddressField, priorityPoolSetup.CapacityField)
+	priorityPoolSetup.Connect(balanceTrackerSetup.BalanceField, balanceTrackerSetup.UpdateFlag) // NodeBalance implements nodePriority
+}
+
 type LesServer struct {
 	lesCommons
 
+	ns          *nodestate.NodeStateMachine
 	archiveMode bool // Flag whether the ethereum node runs in archive mode.
-	peers       *clientPeerSet
-	serverset   *serverSet
 	handler     *serverHandler
+	broadcaster *broadcaster
 	lesTopics   []discv5.Topic
 	privateKey  *ecdsa.PrivateKey
@@ -60,6 +76,7 @@ type LesServer struct {
 }
 
 func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
+	ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup)
 	// Collect les protocol version information supported by local node.
 	lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions))
 	for i, pv := range AdvertiseProtocolVersions {
@@ -83,9 +100,9 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer
 			bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency, true),
 			closeCh:          make(chan struct{}),
 		},
+		ns:           ns,
 		archiveMode:  e.ArchiveMode(),
-		peers:        newClientPeerSet(),
-		serverset:    newServerSet(),
+		broadcaster:  newBroadcaster(ns),
 		lesTopics:    lesTopics,
 		fcManager:    flowcontrol.NewClientManager(nil, &mclock.System{}),
 		servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100),
@@ -116,7 +133,7 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer
 		srv.maxCapacity = totalRecharge
 	}
 	srv.fcManager.SetCapacityLimits(srv.minCapacity, srv.maxCapacity, srv.minCapacity*2)
-	srv.clientPool = newClientPool(srv.chainDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, func(id enode.ID) { go srv.peers.unregister(id.String()) })
+	srv.clientPool = newClientPool(ns, srv.chainDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, srv.dropClient)
 	srv.clientPool.setDefaultFactors(lps.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1})
 
 	checkpoint := srv.latestLocalCheckpoint()
@@ -130,6 +147,13 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer
 	node.RegisterAPIs(srv.APIs())
 	node.RegisterLifecycle(srv)
 
+	// disconnect all peers at nodestate machine shutdown
+	ns.SubscribeField(clientPeerField, func(node *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
+		if state.Equals(serverSetup.OfflineFlag()) && oldValue != nil {
+			oldValue.(*clientPeer).Peer.Disconnect(p2p.DiscRequested)
+		}
+	})
+	ns.Start()
 	return srv, nil
 }
 
@@ -158,7 +182,7 @@ func (s *LesServer) APIs() []rpc.API {
 
 func (s *LesServer) Protocols() []p2p.Protocol {
 	ps := s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
-		if p := s.peers.peer(id.String()); p != nil {
+		if p := s.getClient(id); p != nil {
 			return p.Info()
 		}
 		return nil
@@ -173,6 +197,7 @@ func (s *LesServer) Protocols() []p2p.Protocol {
 // Start starts the LES server
 func (s *LesServer) Start() error {
 	s.privateKey = s.p2pSrv.PrivateKey
+	s.broadcaster.setSignerKey(s.privateKey)
 	s.handler.start()
 
 	s.wg.Add(1)
@@ -198,19 +223,11 @@ func (s *LesServer) Start() error {
 func (s *LesServer) Stop() error {
 	close(s.closeCh)
 
-	// Disconnect existing connections with other LES servers.
-	s.serverset.close()
-
-	// Disconnect existing sessions.
-	// This also closes the gate for any new registrations on the peer set.
-	// sessions which are already established but not added to pm.peers yet
-	// will exit when they try to register.
-	s.peers.close()
-
+	s.clientPool.stop()
+	s.ns.Stop()
 	s.fcManager.Stop()
 	s.costTracker.stop()
 	s.handler.stop()
-	s.clientPool.stop() // client pool should be closed after handler.
 	s.servingQueue.stop()
 
 	// Note, bloom trie indexer is closed by parent bloombits indexer.
@@ -279,3 +296,18 @@ func (s *LesServer) capacityManagement() {
 		}
 	}
 }
+
+func (s *LesServer) getClient(id enode.ID) *clientPeer {
+	if node := s.ns.GetNode(id); node != nil {
+		if p, ok := s.ns.GetField(node, clientPeerField).(*clientPeer); ok {
+			return p
+		}
+	}
+	return nil
+}
+
+func (s *LesServer) dropClient(id enode.ID) {
+	if p := s.getClient(id); p != nil {
+		p.Peer.Disconnect(p2p.DiscRequested)
+	}
+}
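`serverSetup` now owns every nodestate flag and field used on the server side: peer objects, client info, connection address, balance tracking and priority status all hang off one shared `NodeStateMachine`. A self-contained sketch of that pattern using the public `nodestate` API; the flag and field names here are illustrative, not from the PR:

```go
package main

import (
	"fmt"
	"reflect"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
	// All flags and fields are declared on a single Setup before the machine
	// is created, mirroring serverSetup above.
	setup := &nodestate.Setup{}
	activeFlag := setup.NewFlag("active")
	nameField := setup.NewField("name", reflect.TypeOf(""))

	// nil db: state is kept in memory only, as in NewLesServer.
	ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, setup)
	ns.SubscribeState(activeFlag, func(n *enode.Node, oldState, newState nodestate.Flags) {
		name, _ := ns.GetField(n, nameField).(string)
		fmt.Println("state changed for", name)
	})
	ns.Start()

	node := enode.SignNull(&enr.Record{}, enode.ID{0x01})
	ns.SetField(node, nameField, "example")
	ns.SetState(node, activeFlag, nodestate.Flags{}, 0) // set "active", reset nothing, no timeout
	ns.Stop()
}
```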
diff --git a/les/server_handler.go b/les/server_handler.go
index 583df9600..c657d37f1 100644
--- a/les/server_handler.go
+++ b/les/server_handler.go
@@ -17,6 +17,7 @@
 package les
 
 import (
+	"crypto/ecdsa"
 	"encoding/binary"
 	"encoding/json"
 	"errors"
@@ -36,6 +37,8 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/metrics"
 	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/nodestate"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/trie"
 )
@@ -91,7 +94,7 @@ func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb et
 // start starts the server handler.
 func (h *serverHandler) start() {
 	h.wg.Add(1)
-	go h.broadcastHeaders()
+	go h.broadcastLoop()
 }
 
 // stop stops the server handler.
@@ -123,47 +126,58 @@ func (h *serverHandler) handle(p *clientPeer) error {
 		p.Log().Debug("Light Ethereum handshake failed", "err", err)
 		return err
 	}
-	if p.server {
-		if err := h.server.serverset.register(p); err != nil {
-			return err
-		}
-		// connected to another server, no messages expected, just wait for disconnection
-		_, err := p.rw.ReadMsg()
-		return err
-	}
 	// Reject light clients if server is not synced.
 	if !h.synced() {
 		p.Log().Debug("Light server not synced, rejecting peer")
 		return p2p.DiscRequested
 	}
-	defer p.fcClient.Disconnect()
+	var registered bool
+	if err := h.server.ns.Operation(func() {
+		if h.server.ns.GetField(p.Node(), clientPeerField) != nil {
+			registered = true
+		} else {
+			h.server.ns.SetFieldSub(p.Node(), clientPeerField, p)
+		}
+	}); err != nil {
+		return err
+	}
+	if registered {
+		return errAlreadyRegistered
+	}
+
+	defer func() {
+		h.server.ns.SetField(p.Node(), clientPeerField, nil)
+		if p.fcClient != nil { // is nil when connecting another server
+			p.fcClient.Disconnect()
+		}
+	}()
+	if p.server {
+		// connected to another server, no messages expected, just wait for disconnection
+		_, err := p.rw.ReadMsg()
+		return err
+	}
 
 	// Disconnect the inbound peer if it's rejected by clientPool
 	if cap, err := h.server.clientPool.connect(p); cap != p.fcParams.MinRecharge || err != nil {
 		p.Log().Debug("Light Ethereum peer rejected", "err", errFullClientPool)
 		return errFullClientPool
 	}
-	p.balance, _ = h.server.clientPool.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance)
+	p.balance, _ = h.server.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance)
 	if p.balance == nil {
 		return p2p.DiscRequested
 	}
-	// Register the peer locally
-	if err := h.server.peers.register(p); err != nil {
-		h.server.clientPool.disconnect(p)
-		p.Log().Error("Light Ethereum peer registration failed", "err", err)
-		return err
-	}
-	clientConnectionGauge.Update(int64(h.server.peers.len()))
+	activeCount, _ := h.server.clientPool.pp.Active()
+	clientConnectionGauge.Update(int64(activeCount))
 
 	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
 
 	connectedAt := mclock.Now()
 	defer func() {
 		wg.Wait() // Ensure all background task routines have exited.
-		h.server.peers.unregister(p.id)
 		h.server.clientPool.disconnect(p)
 		p.balance = nil
-		clientConnectionGauge.Update(int64(h.server.peers.len()))
+		activeCount, _ := h.server.clientPool.pp.Active()
+		clientConnectionGauge.Update(int64(activeCount))
 		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
 	}()
 	// Mark the peer starts to be served.
@@ -911,11 +925,11 @@ func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
 	return stat
 }
 
-// broadcastHeaders broadcasts new block information to all connected light
+// broadcastLoop broadcasts new block information to all connected light
 // clients. According to the agreement between client and server, server should
 // only broadcast new announcement if the total difficulty is higher than the
 // last one. Besides server will add the signature if client requires.
-func (h *serverHandler) broadcastHeaders() {
+func (h *serverHandler) broadcastLoop() {
 	defer h.wg.Done()
 
 	headCh := make(chan core.ChainHeadEvent, 10)
@@ -929,10 +943,6 @@ func (h *serverHandler) broadcastLoop() {
 	for {
 		select {
 		case ev := <-headCh:
-			peers := h.server.peers.allPeers()
-			if len(peers) == 0 {
-				continue
-			}
 			header := ev.Block.Header()
 			hash, number := header.Hash(), header.Number.Uint64()
 			td := h.blockchain.GetTd(hash, number)
@@ -944,33 +954,73 @@ func (h *serverHandler) broadcastLoop() {
 				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
 			}
 			lastHead, lastTd = header, td
-			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
-
-			var (
-				signed         bool
-				signedAnnounce announceData
-			)
-			announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
-			for _, p := range peers {
-				p := p
-				switch p.announceType {
-				case announceTypeSimple:
-					if !p.queueSend(func() { p.sendAnnounce(announce) }) {
-						log.Debug("Drop announcement because queue is full", "number", number, "hash", hash)
-					}
-				case announceTypeSigned:
-					if !signed {
-						signedAnnounce = announce
-						signedAnnounce.sign(h.server.privateKey)
-						signed = true
-					}
-					if !p.queueSend(func() { p.sendAnnounce(signedAnnounce) }) {
-						log.Debug("Drop announcement because queue is full", "number", number, "hash", hash)
-					}
-				}
-			}
+			h.server.broadcaster.broadcast(announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg})
 		case <-h.closeCh:
 			return
 		}
 	}
 }
+
+// broadcaster sends new header announcements to active client peers
+type broadcaster struct {
+	ns                           *nodestate.NodeStateMachine
+	privateKey                   *ecdsa.PrivateKey
+	lastAnnounce, signedAnnounce announceData
+}
+
+// newBroadcaster creates a new broadcaster
+func newBroadcaster(ns *nodestate.NodeStateMachine) *broadcaster {
+	b := &broadcaster{ns: ns}
+	ns.SubscribeState(priorityPoolSetup.ActiveFlag, func(node *enode.Node, oldState, newState nodestate.Flags) {
+		if newState.Equals(priorityPoolSetup.ActiveFlag) {
+			// send last announcement to activated peers
+			b.sendTo(node)
+		}
+	})
+	return b
+}
+
+// setSignerKey sets the signer key for signed announcements. Should be called before
+// starting the protocol handler.
+func (b *broadcaster) setSignerKey(privateKey *ecdsa.PrivateKey) {
+	b.privateKey = privateKey
+}
+
+// broadcast sends the given announcement to all active peers
+func (b *broadcaster) broadcast(announce announceData) {
+	b.ns.Operation(func() {
+		// iterate in an Operation to ensure that the active set does not change while iterating
+		b.lastAnnounce = announce
+		b.ns.ForEach(priorityPoolSetup.ActiveFlag, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
+			b.sendTo(node)
+		})
+	})
+}
+
+// sendTo sends the most recent announcement to the given node unless the same or higher Td
+// announcement has already been sent.
+func (b *broadcaster) sendTo(node *enode.Node) {
+	if b.lastAnnounce.Td == nil {
+		return
+	}
+	if p, _ := b.ns.GetField(node, clientPeerField).(*clientPeer); p != nil {
+		if p.headInfo.Td == nil || b.lastAnnounce.Td.Cmp(p.headInfo.Td) > 0 {
+			switch p.announceType {
+			case announceTypeSimple:
+				if !p.queueSend(func() { p.sendAnnounce(b.lastAnnounce) }) {
+					log.Debug("Drop announcement because queue is full", "number", b.lastAnnounce.Number, "hash", b.lastAnnounce.Hash)
+				}
+			case announceTypeSigned:
+				if b.signedAnnounce.Hash != b.lastAnnounce.Hash {
+					b.signedAnnounce = b.lastAnnounce
+					b.signedAnnounce.sign(b.privateKey)
+				}
+				if !p.queueSend(func() { p.sendAnnounce(b.signedAnnounce) }) {
+					log.Debug("Drop announcement because queue is full", "number", b.lastAnnounce.Number, "hash", b.lastAnnounce.Hash)
+				}
+			}
+			p.headInfo = blockInfo{b.lastAnnounce.Hash, b.lastAnnounce.Number, b.lastAnnounce.Td}
+		}
+	}
+}
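The broadcaster replaces the per-peer loop in `broadcastHeaders`: the activation callback catches up newly activated peers with the last announcement, and `Operation` guards iteration over the active set. A rough in-package wiring sketch (les internals assumed), mirroring what `LesServer` does:

```go
// ns is the shared *nodestate.NodeStateMachine, key is the server's private key.
b := newBroadcaster(ns) // subscribes to ActiveFlag so newly activated peers
                        // immediately receive the last announcement
b.setSignerKey(key)     // must happen before the handlers start serving
ns.Start()

// On every chain head event (see broadcastLoop above):
b.broadcast(announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg})
```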
diff --git a/les/test_helper.go b/les/test_helper.go
index 9f9b28721..5a8d64f76 100644
--- a/les/test_helper.go
+++ b/les/test_helper.go
@@ -46,6 +46,7 @@ import (
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/nodestate"
 	"github.com/ethereum/go-ethereum/params"
 )
 
@@ -227,7 +228,7 @@ func newTestClientHandler(backend *backends.SimulatedBackend, odr *LesOdr, index
 	return client.handler
 }
 
-func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, peers *clientPeerSet, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) {
+func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) {
 	var (
 		gspec = core.Genesis{
 			Config: params.AllEthashProtocolChanges,
@@ -263,6 +264,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
 		}
 		oracle = checkpointoracle.New(checkpointConfig, getLocal)
 	}
+	ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup)
 	server := &LesServer{
 		lesCommons: lesCommons{
 			genesis:          genesis.Hash(),
@@ -274,7 +276,8 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
 			oracle:           oracle,
 			closeCh:          make(chan struct{}),
 		},
-		peers:        peers,
+		ns:           ns,
+		broadcaster:  newBroadcaster(ns),
 		servingQueue: newServingQueue(int64(time.Millisecond*10), 1),
 		defParams: flowcontrol.ServerParams{
 			BufLimit:    testBufLimit,
@@ -284,13 +287,14 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
 	}
 	server.costTracker, server.minCapacity = newCostTracker(db, server.config)
 	server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism.
-	server.clientPool = newClientPool(db, testBufRecharge, defaultConnectedBias, clock, func(id enode.ID) {})
+	server.clientPool = newClientPool(ns, db, testBufRecharge, defaultConnectedBias, clock, func(id enode.ID) {})
 	server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool
 	server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true })
 	if server.oracle != nil {
 		server.oracle.Start(simulation)
 	}
 	server.servingQueue.setThreads(4)
+	ns.Start()
 	server.handler.start()
 	return server.handler, simulation
 }
@@ -463,7 +467,7 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba
 	if simClock {
 		clock = &mclock.Simulated{}
 	}
-	handler, b := newTestServerHandler(blocks, indexers, db, newClientPeerSet(), clock)
+	handler, b := newTestServerHandler(blocks, indexers, db, clock)
 
 	var peer *testPeer
 	if newPeer {
@@ -502,7 +506,7 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba
 
 func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, ulcServers []string, ulcFraction int, simClock bool, connect bool, disablePruning bool) (*testServer, *testClient, func()) {
 	sdb, cdb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase()
-	speers, cpeers := newServerPeerSet(), newClientPeerSet()
+	speers := newServerPeerSet()
 
 	var clock mclock.Clock = &mclock.System{}
 	if simClock {
@@ -519,7 +523,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer
 	ccIndexer, cbIndexer, cbtIndexer := cIndexers[0], cIndexers[1], cIndexers[2]
 	odr.SetIndexers(ccIndexer, cbIndexer, cbtIndexer)
 
-	server, b := newTestServerHandler(blocks, sindexers, sdb, cpeers, clock)
+	server, b := newTestServerHandler(blocks, sindexers, sdb, clock)
 	client := newTestClientHandler(b, odr, cIndexers, cdb, speers, ulcServers, ulcFraction)
 
 	scIndexer.Start(server.blockchain)