les: create utilities as common package (#20509)

* les: move execqueue into utilities package

execqueue is a util for executing queued functions
in a serial order which is used by both les server
and les client. Move it to common package.

* les: move randselect to utilities package

weighted_random_selector is a helpful tool for randomly select
items maintained in a set but based on the item weight.

It's used anywhere is LES package, mainly by les client but will
be used in les server with very high chance. So move it into a
common package as the second step for les separation.

* les: rename to utils
This commit is contained in:
gary rong 2020-03-31 23:17:24 +08:00 committed by GitHub
parent 32d31c31af
commit f78ffc0545
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 90 additions and 90 deletions

View File

@ -22,6 +22,7 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/les/utils"
) )
// requestDistributor implements a mechanism that distributes requests to // requestDistributor implements a mechanism that distributes requests to
@ -194,7 +195,7 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
elem := d.reqQueue.Front() elem := d.reqQueue.Front()
var ( var (
bestWait time.Duration bestWait time.Duration
sel *weightedRandomSelect sel *utils.WeightedRandomSelect
) )
d.peerLock.RLock() d.peerLock.RLock()
@ -219,9 +220,9 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
wait, bufRemain := peer.waitBefore(cost) wait, bufRemain := peer.waitBefore(cost)
if wait == 0 { if wait == 0 {
if sel == nil { if sel == nil {
sel = newWeightedRandomSelect() sel = utils.NewWeightedRandomSelect()
} }
sel.update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1}) sel.Update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1})
} else { } else {
if bestWait == 0 || wait < bestWait { if bestWait == 0 || wait < bestWait {
bestWait = wait bestWait = wait
@ -239,7 +240,7 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
} }
if sel != nil { if sel != nil {
c := sel.choose().(selectPeerItem) c := sel.Choose().(selectPeerItem)
return c.peer, c.req, 0 return c.peer, c.req, 0
} }
return nil, nil, bestWait return nil, nil, bestWait

View File

@ -32,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/les/utils"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
@ -135,7 +136,7 @@ type peerCommons struct {
headInfo blockInfo // Latest block information. headInfo blockInfo // Latest block information.
// Background task queue for caching peer tasks and executing in order. // Background task queue for caching peer tasks and executing in order.
sendQueue *execQueue sendQueue *utils.ExecQueue
// Flow control agreement. // Flow control agreement.
fcParams flowcontrol.ServerParams // The config for token bucket. fcParams flowcontrol.ServerParams // The config for token bucket.
@ -153,13 +154,13 @@ func (p *peerCommons) isFrozen() bool {
// canQueue returns an indicator whether the peer can queue a operation. // canQueue returns an indicator whether the peer can queue a operation.
func (p *peerCommons) canQueue() bool { func (p *peerCommons) canQueue() bool {
return p.sendQueue.canQueue() && !p.isFrozen() return p.sendQueue.CanQueue() && !p.isFrozen()
} }
// queueSend caches a peer operation in the background task queue. // queueSend caches a peer operation in the background task queue.
// Please ensure to check `canQueue` before call this function // Please ensure to check `canQueue` before call this function
func (p *peerCommons) queueSend(f func()) bool { func (p *peerCommons) queueSend(f func()) bool {
return p.sendQueue.queue(f) return p.sendQueue.Queue(f)
} }
// mustQueueSend starts a for loop and retry the caching if failed. // mustQueueSend starts a for loop and retry the caching if failed.
@ -337,7 +338,7 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g
// close closes the channel and notifies all background routines to exit. // close closes the channel and notifies all background routines to exit.
func (p *peerCommons) close() { func (p *peerCommons) close() {
close(p.closeCh) close(p.closeCh)
p.sendQueue.quit() p.sendQueue.Quit()
} }
// serverPeer represents each node to which the client is connected. // serverPeer represents each node to which the client is connected.
@ -375,7 +376,7 @@ func newServerPeer(version int, network uint64, trusted bool, p *p2p.Peer, rw p2
id: peerIdToString(p.ID()), id: peerIdToString(p.ID()),
version: version, version: version,
network: network, network: network,
sendQueue: newExecQueue(100), sendQueue: utils.NewExecQueue(100),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
}, },
trusted: trusted, trusted: trusted,
@ -407,7 +408,7 @@ func (p *serverPeer) rejectUpdate(size uint64) bool {
// frozen. // frozen.
func (p *serverPeer) freeze() { func (p *serverPeer) freeze() {
if atomic.CompareAndSwapUint32(&p.frozen, 0, 1) { if atomic.CompareAndSwapUint32(&p.frozen, 0, 1) {
p.sendQueue.clear() p.sendQueue.Clear()
} }
} }
@ -652,7 +653,7 @@ func newClientPeer(version int, network uint64, p *p2p.Peer, rw p2p.MsgReadWrite
id: peerIdToString(p.ID()), id: peerIdToString(p.ID()),
version: version, version: version,
network: network, network: network,
sendQueue: newExecQueue(100), sendQueue: utils.NewExecQueue(100),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
}, },
errCh: make(chan error, 1), errCh: make(chan error, 1),

View File

@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/les/utils"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
@ -129,7 +130,7 @@ type serverPool struct {
adjustStats chan poolStatAdjust adjustStats chan poolStatAdjust
knownQueue, newQueue poolEntryQueue knownQueue, newQueue poolEntryQueue
knownSelect, newSelect *weightedRandomSelect knownSelect, newSelect *utils.WeightedRandomSelect
knownSelected, newSelected int knownSelected, newSelected int
fastDiscover bool fastDiscover bool
connCh chan *connReq connCh chan *connReq
@ -152,8 +153,8 @@ func newServerPool(db ethdb.Database, ulcServers []string) *serverPool {
disconnCh: make(chan *disconnReq), disconnCh: make(chan *disconnReq),
registerCh: make(chan *registerReq), registerCh: make(chan *registerReq),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
knownSelect: newWeightedRandomSelect(), knownSelect: utils.NewWeightedRandomSelect(),
newSelect: newWeightedRandomSelect(), newSelect: utils.NewWeightedRandomSelect(),
fastDiscover: true, fastDiscover: true,
trustedNodes: parseTrustedNodes(ulcServers), trustedNodes: parseTrustedNodes(ulcServers),
} }
@ -402,8 +403,8 @@ func (pool *serverPool) eventLoop() {
entry.lastConnected = addr entry.lastConnected = addr
entry.addr = make(map[string]*poolEntryAddress) entry.addr = make(map[string]*poolEntryAddress)
entry.addr[addr.strKey()] = addr entry.addr[addr.strKey()] = addr
entry.addrSelect = *newWeightedRandomSelect() entry.addrSelect = *utils.NewWeightedRandomSelect()
entry.addrSelect.update(addr) entry.addrSelect.Update(addr)
req.result <- entry req.result <- entry
} }
@ -459,7 +460,7 @@ func (pool *serverPool) findOrNewNode(node *enode.Node) *poolEntry {
entry = &poolEntry{ entry = &poolEntry{
node: node, node: node,
addr: make(map[string]*poolEntryAddress), addr: make(map[string]*poolEntryAddress),
addrSelect: *newWeightedRandomSelect(), addrSelect: *utils.NewWeightedRandomSelect(),
shortRetry: shortRetryCnt, shortRetry: shortRetryCnt,
} }
pool.entries[node.ID()] = entry pool.entries[node.ID()] = entry
@ -477,7 +478,7 @@ func (pool *serverPool) findOrNewNode(node *enode.Node) *poolEntry {
entry.addr[addr.strKey()] = addr entry.addr[addr.strKey()] = addr
} }
addr.lastSeen = now addr.lastSeen = now
entry.addrSelect.update(addr) entry.addrSelect.Update(addr)
if !entry.known { if !entry.known {
pool.newQueue.setLatest(entry) pool.newQueue.setLatest(entry)
} }
@ -505,7 +506,7 @@ func (pool *serverPool) loadNodes() {
pool.entries[e.node.ID()] = e pool.entries[e.node.ID()] = e
if pool.trustedNodes[e.node.ID()] == nil { if pool.trustedNodes[e.node.ID()] == nil {
pool.knownQueue.setLatest(e) pool.knownQueue.setLatest(e)
pool.knownSelect.update((*knownEntry)(e)) pool.knownSelect.Update((*knownEntry)(e))
} }
} }
} }
@ -556,8 +557,8 @@ func (pool *serverPool) saveNodes() {
// Note that it is called by the new/known queues from which the entry has already // Note that it is called by the new/known queues from which the entry has already
// been removed so removing it from the queues is not necessary. // been removed so removing it from the queues is not necessary.
func (pool *serverPool) removeEntry(entry *poolEntry) { func (pool *serverPool) removeEntry(entry *poolEntry) {
pool.newSelect.remove((*discoveredEntry)(entry)) pool.newSelect.Remove((*discoveredEntry)(entry))
pool.knownSelect.remove((*knownEntry)(entry)) pool.knownSelect.Remove((*knownEntry)(entry))
entry.removed = true entry.removed = true
delete(pool.entries, entry.node.ID()) delete(pool.entries, entry.node.ID())
} }
@ -586,8 +587,8 @@ func (pool *serverPool) setRetryDial(entry *poolEntry) {
// updateCheckDial is called when an entry can potentially be dialed again. It updates // updateCheckDial is called when an entry can potentially be dialed again. It updates
// its selection weights and checks if new dials can/should be made. // its selection weights and checks if new dials can/should be made.
func (pool *serverPool) updateCheckDial(entry *poolEntry) { func (pool *serverPool) updateCheckDial(entry *poolEntry) {
pool.newSelect.update((*discoveredEntry)(entry)) pool.newSelect.Update((*discoveredEntry)(entry))
pool.knownSelect.update((*knownEntry)(entry)) pool.knownSelect.Update((*knownEntry)(entry))
pool.checkDial() pool.checkDial()
} }
@ -596,7 +597,7 @@ func (pool *serverPool) updateCheckDial(entry *poolEntry) {
func (pool *serverPool) checkDial() { func (pool *serverPool) checkDial() {
fillWithKnownSelects := !pool.fastDiscover fillWithKnownSelects := !pool.fastDiscover
for pool.knownSelected < targetKnownSelect { for pool.knownSelected < targetKnownSelect {
entry := pool.knownSelect.choose() entry := pool.knownSelect.Choose()
if entry == nil { if entry == nil {
fillWithKnownSelects = false fillWithKnownSelects = false
break break
@ -604,7 +605,7 @@ func (pool *serverPool) checkDial() {
pool.dial((*poolEntry)(entry.(*knownEntry)), true) pool.dial((*poolEntry)(entry.(*knownEntry)), true)
} }
for pool.knownSelected+pool.newSelected < targetServerCount { for pool.knownSelected+pool.newSelected < targetServerCount {
entry := pool.newSelect.choose() entry := pool.newSelect.Choose()
if entry == nil { if entry == nil {
break break
} }
@ -615,7 +616,7 @@ func (pool *serverPool) checkDial() {
// is over, we probably won't find more in the near future so select more // is over, we probably won't find more in the near future so select more
// known entries if possible // known entries if possible
for pool.knownSelected < targetServerCount { for pool.knownSelected < targetServerCount {
entry := pool.knownSelect.choose() entry := pool.knownSelect.Choose()
if entry == nil { if entry == nil {
break break
} }
@ -636,7 +637,7 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) {
} else { } else {
pool.newSelected++ pool.newSelected++
} }
addr := entry.addrSelect.choose().(*poolEntryAddress) addr := entry.addrSelect.Choose().(*poolEntryAddress)
log.Debug("Dialing new peer", "lesaddr", entry.node.ID().String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected) log.Debug("Dialing new peer", "lesaddr", entry.node.ID().String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected)
entry.dialed = addr entry.dialed = addr
go func() { go func() {
@ -684,7 +685,7 @@ type poolEntry struct {
addr map[string]*poolEntryAddress addr map[string]*poolEntryAddress
node *enode.Node node *enode.Node
lastConnected, dialed *poolEntryAddress lastConnected, dialed *poolEntryAddress
addrSelect weightedRandomSelect addrSelect utils.WeightedRandomSelect
lastDiscovered mclock.AbsTime lastDiscovered mclock.AbsTime
known, knownSelected, trusted bool known, knownSelected, trusted bool
@ -734,8 +735,8 @@ func (e *poolEntry) DecodeRLP(s *rlp.Stream) error {
e.node = enode.NewV4(pubkey, entry.IP, int(entry.Port), int(entry.Port)) e.node = enode.NewV4(pubkey, entry.IP, int(entry.Port), int(entry.Port))
e.addr = make(map[string]*poolEntryAddress) e.addr = make(map[string]*poolEntryAddress)
e.addr[addr.strKey()] = addr e.addr[addr.strKey()] = addr
e.addrSelect = *newWeightedRandomSelect() e.addrSelect = *utils.NewWeightedRandomSelect()
e.addrSelect.update(addr) e.addrSelect.Update(addr)
e.lastConnected = addr e.lastConnected = addr
e.connectStats = entry.CStat e.connectStats = entry.CStat
e.delayStats = entry.DStat e.delayStats = entry.DStat

View File

@ -14,35 +14,35 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les package utils
import "sync" import "sync"
// execQueue implements a queue that executes function calls in a single thread, // ExecQueue implements a queue that executes function calls in a single thread,
// in the same order as they have been queued. // in the same order as they have been queued.
type execQueue struct { type ExecQueue struct {
mu sync.Mutex mu sync.Mutex
cond *sync.Cond cond *sync.Cond
funcs []func() funcs []func()
closeWait chan struct{} closeWait chan struct{}
} }
// newExecQueue creates a new execution queue. // NewExecQueue creates a new execution Queue.
func newExecQueue(capacity int) *execQueue { func NewExecQueue(capacity int) *ExecQueue {
q := &execQueue{funcs: make([]func(), 0, capacity)} q := &ExecQueue{funcs: make([]func(), 0, capacity)}
q.cond = sync.NewCond(&q.mu) q.cond = sync.NewCond(&q.mu)
go q.loop() go q.loop()
return q return q
} }
func (q *execQueue) loop() { func (q *ExecQueue) loop() {
for f := q.waitNext(false); f != nil; f = q.waitNext(true) { for f := q.waitNext(false); f != nil; f = q.waitNext(true) {
f() f()
} }
close(q.closeWait) close(q.closeWait)
} }
func (q *execQueue) waitNext(drop bool) (f func()) { func (q *ExecQueue) waitNext(drop bool) (f func()) {
q.mu.Lock() q.mu.Lock()
if drop && len(q.funcs) > 0 { if drop && len(q.funcs) > 0 {
// Remove the function that just executed. We do this here instead of when // Remove the function that just executed. We do this here instead of when
@ -60,20 +60,20 @@ func (q *execQueue) waitNext(drop bool) (f func()) {
return f return f
} }
func (q *execQueue) isClosed() bool { func (q *ExecQueue) isClosed() bool {
return q.closeWait != nil return q.closeWait != nil
} }
// canQueue returns true if more function calls can be added to the execution queue. // CanQueue returns true if more function calls can be added to the execution Queue.
func (q *execQueue) canQueue() bool { func (q *ExecQueue) CanQueue() bool {
q.mu.Lock() q.mu.Lock()
ok := !q.isClosed() && len(q.funcs) < cap(q.funcs) ok := !q.isClosed() && len(q.funcs) < cap(q.funcs)
q.mu.Unlock() q.mu.Unlock()
return ok return ok
} }
// queue adds a function call to the execution queue. Returns true if successful. // Queue adds a function call to the execution Queue. Returns true if successful.
func (q *execQueue) queue(f func()) bool { func (q *ExecQueue) Queue(f func()) bool {
q.mu.Lock() q.mu.Lock()
ok := !q.isClosed() && len(q.funcs) < cap(q.funcs) ok := !q.isClosed() && len(q.funcs) < cap(q.funcs)
if ok { if ok {
@ -84,16 +84,17 @@ func (q *execQueue) queue(f func()) bool {
return ok return ok
} }
// clear drops all queued functions // Clear drops all queued functions.
func (q *execQueue) clear() { func (q *ExecQueue) Clear() {
q.mu.Lock() q.mu.Lock()
q.funcs = q.funcs[:0] q.funcs = q.funcs[:0]
q.mu.Unlock() q.mu.Unlock()
} }
// quit stops the exec queue. // Quit stops the exec Queue.
// quit waits for the current execution to finish before returning. //
func (q *execQueue) quit() { // Quit waits for the current execution to finish before returning.
func (q *ExecQueue) Quit() {
q.mu.Lock() q.mu.Lock()
if !q.isClosed() { if !q.isClosed() {
q.closeWait = make(chan struct{}) q.closeWait = make(chan struct{})

View File

@ -14,21 +14,19 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les package utils
import ( import "testing"
"testing"
)
func TestExecQueue(t *testing.T) { func TestExecQueue(t *testing.T) {
var ( var (
N = 10000 N = 10000
q = newExecQueue(N) q = NewExecQueue(N)
counter int counter int
execd = make(chan int) execd = make(chan int)
testexit = make(chan struct{}) testexit = make(chan struct{})
) )
defer q.quit() defer q.Quit()
defer close(testexit) defer close(testexit)
check := func(state string, wantOK bool) { check := func(state string, wantOK bool) {
@ -40,11 +38,11 @@ func TestExecQueue(t *testing.T) {
case <-testexit: case <-testexit:
} }
} }
if q.canQueue() != wantOK { if q.CanQueue() != wantOK {
t.Fatalf("canQueue() == %t for %s", !wantOK, state) t.Fatalf("CanQueue() == %t for %s", !wantOK, state)
} }
if q.queue(qf) != wantOK { if q.Queue(qf) != wantOK {
t.Fatalf("canQueue() == %t for %s", !wantOK, state) t.Fatalf("Queue() == %t for %s", !wantOK, state)
} }
} }
@ -57,6 +55,6 @@ func TestExecQueue(t *testing.T) {
t.Fatal("execution out of order") t.Fatal("execution out of order")
} }
} }
q.quit() q.Quit()
check("closed queue", false) check("closed queue", false)
} }

View File

@ -14,43 +14,30 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les package utils
import ( import "math/rand"
"math/rand"
)
// wrsItem interface should be implemented by any entries that are to be selected from // wrsItem interface should be implemented by any entries that are to be selected from
// a weightedRandomSelect set. Note that recalculating monotonously decreasing item // a WeightedRandomSelect set. Note that recalculating monotonously decreasing item
// weights on-demand (without constantly calling update) is allowed // weights on-demand (without constantly calling Update) is allowed
type wrsItem interface { type wrsItem interface {
Weight() int64 Weight() int64
} }
// weightedRandomSelect is capable of weighted random selection from a set of items // WeightedRandomSelect is capable of weighted random selection from a set of items
type weightedRandomSelect struct { type WeightedRandomSelect struct {
root *wrsNode root *wrsNode
idx map[wrsItem]int idx map[wrsItem]int
} }
// newWeightedRandomSelect returns a new weightedRandomSelect structure // NewWeightedRandomSelect returns a new WeightedRandomSelect structure
func newWeightedRandomSelect() *weightedRandomSelect { func NewWeightedRandomSelect() *WeightedRandomSelect {
return &weightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[wrsItem]int)} return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[wrsItem]int)}
}
// update updates an item's weight, adds it if it was non-existent or removes it if
// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
func (w *weightedRandomSelect) update(item wrsItem) {
w.setWeight(item, item.Weight())
}
// remove removes an item from the set
func (w *weightedRandomSelect) remove(item wrsItem) {
w.setWeight(item, 0)
} }
// setWeight sets an item's weight to a specific value (removes it if zero) // setWeight sets an item's weight to a specific value (removes it if zero)
func (w *weightedRandomSelect) setWeight(item wrsItem, weight int64) { func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) {
idx, ok := w.idx[item] idx, ok := w.idx[item]
if ok { if ok {
w.root.setWeight(idx, weight) w.root.setWeight(idx, weight)
@ -71,11 +58,22 @@ func (w *weightedRandomSelect) setWeight(item wrsItem, weight int64) {
} }
} }
// choose randomly selects an item from the set, with a chance proportional to its // Update updates an item's weight, adds it if it was non-existent or removes it if
// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
func (w *WeightedRandomSelect) Update(item wrsItem) {
w.setWeight(item, item.Weight())
}
// Remove removes an item from the set
func (w *WeightedRandomSelect) Remove(item wrsItem) {
w.setWeight(item, 0)
}
// Choose randomly selects an item from the set, with a chance proportional to its
// current weight. If the weight of the chosen element has been decreased since the // current weight. If the weight of the chosen element has been decreased since the
// last stored value, returns it with a newWeight/oldWeight chance, otherwise just // last stored value, returns it with a newWeight/oldWeight chance, otherwise just
// updates its weight and selects another one // updates its weight and selects another one
func (w *weightedRandomSelect) choose() wrsItem { func (w *WeightedRandomSelect) Choose() wrsItem {
for { for {
if w.root.sumWeight == 0 { if w.root.sumWeight == 0 {
return nil return nil
@ -154,7 +152,7 @@ func (n *wrsNode) setWeight(idx int, weight int64) int64 {
return diff return diff
} }
// choose recursively selects an item from the tree and returns it along with its weight // Choose recursively selects an item from the tree and returns it along with its weight
func (n *wrsNode) choose(val int64) (wrsItem, int64) { func (n *wrsNode) choose(val int64) (wrsItem, int64) {
for i, w := range n.weights { for i, w := range n.weights {
if val < w { if val < w {

View File

@ -14,7 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les package utils
import ( import (
"math/rand" "math/rand"
@ -36,15 +36,15 @@ func (t *testWrsItem) Weight() int64 {
func TestWeightedRandomSelect(t *testing.T) { func TestWeightedRandomSelect(t *testing.T) {
testFn := func(cnt int) { testFn := func(cnt int) {
s := newWeightedRandomSelect() s := NewWeightedRandomSelect()
w := -1 w := -1
list := make([]testWrsItem, cnt) list := make([]testWrsItem, cnt)
for i := range list { for i := range list {
list[i] = testWrsItem{idx: i, widx: &w} list[i] = testWrsItem{idx: i, widx: &w}
s.update(&list[i]) s.Update(&list[i])
} }
w = rand.Intn(cnt) w = rand.Intn(cnt)
c := s.choose() c := s.Choose()
if c == nil { if c == nil {
t.Errorf("expected item, got nil") t.Errorf("expected item, got nil")
} else { } else {
@ -53,7 +53,7 @@ func TestWeightedRandomSelect(t *testing.T) {
} }
} }
w = -2 w = -2
if s.choose() != nil { if s.Choose() != nil {
t.Errorf("expected nil, got item") t.Errorf("expected nil, got item")
} }
} }