diff --git a/les/utils/limiter.go b/les/utils/limiter.go new file mode 100644 index 000000000..0cc2d7b26 --- /dev/null +++ b/les/utils/limiter.go @@ -0,0 +1,405 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package utils + +import ( + "sort" + "sync" + + "github.com/ethereum/go-ethereum/p2p/enode" +) + +const maxSelectionWeight = 1000000000 // maximum selection weight of each individual node/address group + +// Limiter protects a network request serving mechanism from denial-of-service attacks. +// It limits the total amount of resources used for serving requests while ensuring that +// the most valuable connections always have a reasonable chance of being served. +type Limiter struct { + lock sync.Mutex + cond *sync.Cond + quit bool + + nodes map[enode.ID]*nodeQueue + addresses map[string]*addressGroup + addressSelect, valueSelect *WeightedRandomSelect + maxValue float64 + maxCost, sumCost, sumCostLimit uint + selectAddressNext bool +} + +// nodeQueue represents queued requests coming from a single node ID +type nodeQueue struct { + queue []request // always nil if penaltyCost != 0 + id enode.ID + address string + value float64 + flatWeight, valueWeight uint64 // current selection weights in the address/value selectors + sumCost uint // summed cost of requests queued by the node + penaltyCost uint // cumulative cost of dropped requests since last processed request + groupIndex int +} + +// addressGroup is a group of node IDs that have sent their last requests from the same +// network address +type addressGroup struct { + nodes []*nodeQueue + nodeSelect *WeightedRandomSelect + sumFlatWeight, groupWeight uint64 +} + +// request represents an incoming request scheduled for processing +type request struct { + process chan chan struct{} + cost uint +} + +// flatWeight distributes weights equally between each active network address +func flatWeight(item interface{}) uint64 { return item.(*nodeQueue).flatWeight } + +// add adds the node queue to the address group. It is the caller's responsibility to +// add the address group to the address map and the address selector if it wasn't +// there before. +func (ag *addressGroup) add(nq *nodeQueue) { + if nq.groupIndex != -1 { + panic("added node queue is already in an address group") + } + l := len(ag.nodes) + nq.groupIndex = l + ag.nodes = append(ag.nodes, nq) + ag.sumFlatWeight += nq.flatWeight + ag.groupWeight = ag.sumFlatWeight / uint64(l+1) + ag.nodeSelect.Update(ag.nodes[l]) +} + +// update updates the selection weight of the node queue inside the address group. +// It is the caller's responsibility to update the group's selection weight in the +// address selector. +func (ag *addressGroup) update(nq *nodeQueue, weight uint64) { + if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq { + panic("updated node queue is not in this address group") + } + ag.sumFlatWeight += weight - nq.flatWeight + nq.flatWeight = weight + ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes)) + ag.nodeSelect.Update(nq) +} + +// remove removes the node queue from the address group. It is the caller's responsibility +// to remove the address group from the address map if it is empty. +func (ag *addressGroup) remove(nq *nodeQueue) { + if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq { + panic("removed node queue is not in this address group") + } + + l := len(ag.nodes) - 1 + if nq.groupIndex != l { + ag.nodes[nq.groupIndex] = ag.nodes[l] + ag.nodes[nq.groupIndex].groupIndex = nq.groupIndex + } + nq.groupIndex = -1 + ag.nodes = ag.nodes[:l] + ag.sumFlatWeight -= nq.flatWeight + if l >= 1 { + ag.groupWeight = ag.sumFlatWeight / uint64(l) + } else { + ag.groupWeight = 0 + } + ag.nodeSelect.Remove(nq) +} + +// choose selects one of the node queues belonging to the address group +func (ag *addressGroup) choose() *nodeQueue { + return ag.nodeSelect.Choose().(*nodeQueue) +} + +// NewLimiter creates a new Limiter +func NewLimiter(sumCostLimit uint) *Limiter { + l := &Limiter{ + addressSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }), + valueSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }), + nodes: make(map[enode.ID]*nodeQueue), + addresses: make(map[string]*addressGroup), + sumCostLimit: sumCostLimit, + } + l.cond = sync.NewCond(&l.lock) + go l.processLoop() + return l +} + +// selectionWeights calculates the selection weights of a node for both the address and +// the value selector. The selection weight depends on the next request cost or the +// summed cost of recently dropped requests. +func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) { + if value > l.maxValue { + l.maxValue = value + } + if value > 0 { + // normalize value to <= 1 + value /= l.maxValue + } + if reqCost > l.maxCost { + l.maxCost = reqCost + } + relCost := float64(reqCost) / float64(l.maxCost) + var f float64 + if relCost <= 0.001 { + f = 1 + } else { + f = 0.001 / relCost + } + f *= maxSelectionWeight + flatWeight, valueWeight = uint64(f), uint64(f*value) + if flatWeight == 0 { + flatWeight = 1 + } + return +} + +// Add adds a new request to the node queue belonging to the given id. Value belongs +// to the requesting node. A higher value gives the request a higher chance of being +// served quickly in case of heavy load or a DDoS attack. Cost is a rough estimate +// of the serving cost of the request. A lower cost also gives the request a +// better chance. +func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} { + l.lock.Lock() + defer l.lock.Unlock() + + process := make(chan chan struct{}, 1) + if l.quit { + close(process) + return process + } + if reqCost == 0 { + reqCost = 1 + } + if nq, ok := l.nodes[id]; ok { + if nq.queue != nil { + nq.queue = append(nq.queue, request{process, reqCost}) + nq.sumCost += reqCost + nq.value = value + if address != nq.address { + // known id sending request from a new address, move to different address group + l.removeFromGroup(nq) + l.addToGroup(nq, address) + } + } else { + // already waiting on a penalty, just add to the penalty cost and drop the request + nq.penaltyCost += reqCost + l.update(nq) + close(process) + return process + } + } else { + nq := &nodeQueue{ + queue: []request{{process, reqCost}}, + id: id, + value: value, + sumCost: reqCost, + groupIndex: -1, + } + nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value) + if len(l.nodes) == 0 { + l.cond.Signal() + } + l.nodes[id] = nq + if nq.valueWeight != 0 { + l.valueSelect.Update(nq) + } + l.addToGroup(nq, address) + } + l.sumCost += reqCost + if l.sumCost > l.sumCostLimit { + l.dropRequests() + } + return process +} + +// update updates the selection weights of the node queue +func (l *Limiter) update(nq *nodeQueue) { + var cost uint + if nq.queue != nil { + cost = nq.queue[0].cost + } else { + cost = nq.penaltyCost + } + flatWeight, valueWeight := l.selectionWeights(cost, nq.value) + ag := l.addresses[nq.address] + ag.update(nq, flatWeight) + l.addressSelect.Update(ag) + nq.valueWeight = valueWeight + l.valueSelect.Update(nq) +} + +// addToGroup adds the node queue to the given address group. The group is created if +// it does not exist yet. +func (l *Limiter) addToGroup(nq *nodeQueue, address string) { + nq.address = address + ag := l.addresses[address] + if ag == nil { + ag = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)} + l.addresses[address] = ag + } + ag.add(nq) + l.addressSelect.Update(ag) +} + +// removeFromGroup removes the node queue from its address group +func (l *Limiter) removeFromGroup(nq *nodeQueue) { + ag := l.addresses[nq.address] + ag.remove(nq) + if len(ag.nodes) == 0 { + delete(l.addresses, nq.address) + } + l.addressSelect.Update(ag) +} + +// remove removes the node queue from its address group, the nodes map and the value +// selector +func (l *Limiter) remove(nq *nodeQueue) { + l.removeFromGroup(nq) + if nq.valueWeight != 0 { + l.valueSelect.Remove(nq) + } + delete(l.nodes, nq.id) +} + +// choose selects the next node queue to process. +func (l *Limiter) choose() *nodeQueue { + if l.valueSelect.IsEmpty() || l.selectAddressNext { + if ag, ok := l.addressSelect.Choose().(*addressGroup); ok { + l.selectAddressNext = false + return ag.choose() + } + } + nq, _ := l.valueSelect.Choose().(*nodeQueue) + l.selectAddressNext = true + return nq +} + +// processLoop processes requests sequentially +func (l *Limiter) processLoop() { + l.lock.Lock() + defer l.lock.Unlock() + + for { + if l.quit { + for _, nq := range l.nodes { + for _, request := range nq.queue { + close(request.process) + } + } + return + } + nq := l.choose() + if nq == nil { + l.cond.Wait() + continue + } + if nq.queue != nil { + request := nq.queue[0] + nq.queue = nq.queue[1:] + nq.sumCost -= request.cost + l.sumCost -= request.cost + l.lock.Unlock() + ch := make(chan struct{}) + request.process <- ch + <-ch + l.lock.Lock() + if len(nq.queue) > 0 { + l.update(nq) + } else { + l.remove(nq) + } + } else { + // penalized queue removed, next request will be added to a clean queue + l.remove(nq) + } + } +} + +// Stop stops the processing loop. All queued and future requests are rejected. +func (l *Limiter) Stop() { + l.lock.Lock() + defer l.lock.Unlock() + + l.quit = true + l.cond.Signal() +} + +type ( + dropList []dropListItem + dropListItem struct { + nq *nodeQueue + priority float64 + } +) + +func (l dropList) Len() int { + return len(l) +} + +func (l dropList) Less(i, j int) bool { + return l[i].priority < l[j].priority +} + +func (l dropList) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} + +// dropRequests selects the nodes with the highest queued request cost to selection +// weight ratio and drops their queued request. The empty node queues stay in the +// selectors with a low selection weight in order to penalize these nodes. +func (l *Limiter) dropRequests() { + var ( + sumValue float64 + list dropList + ) + for _, nq := range l.nodes { + sumValue += nq.value + } + for _, nq := range l.nodes { + if nq.sumCost == 0 { + continue + } + w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes)) + if sumValue > 0 { + w += nq.value / sumValue + } + list = append(list, dropListItem{ + nq: nq, + priority: w / float64(nq.sumCost), + }) + } + sort.Sort(list) + for _, item := range list { + for _, request := range item.nq.queue { + close(request.process) + } + // make the queue penalized; no more requests are accepted until the node is + // selected based on the penalty cost which is the cumulative cost of all dropped + // requests. This ensures that sending excess requests is always penalized + // and incentivizes the sender to stop for a while if no replies are received. + item.nq.queue = nil + item.nq.penaltyCost = item.nq.sumCost + l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost + item.nq.sumCost = 0 + l.update(item.nq) + if l.sumCost <= l.sumCostLimit/2 { + return + } + } +} diff --git a/les/utils/limiter_test.go b/les/utils/limiter_test.go new file mode 100644 index 000000000..43af3309a --- /dev/null +++ b/les/utils/limiter_test.go @@ -0,0 +1,206 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package utils + +import ( + "math/rand" + "testing" + + "github.com/ethereum/go-ethereum/p2p/enode" +) + +const ( + ltTolerance = 0.03 + ltRounds = 7 +) + +type ( + ltNode struct { + addr, id int + value, exp float64 + cost uint + reqRate float64 + reqMax, runCount int + lastTotalCost uint + + served, dropped int + } + + ltResult struct { + node *ltNode + ch chan struct{} + } + + limTest struct { + limiter *Limiter + results chan ltResult + runCount int + expCost, totalCost uint + } +) + +func (lt *limTest) request(n *ltNode) { + var ( + address string + id enode.ID + ) + if n.addr >= 0 { + address = string([]byte{byte(n.addr)}) + } else { + var b [32]byte + rand.Read(b[:]) + address = string(b[:]) + } + if n.id >= 0 { + id = enode.ID{byte(n.id)} + } else { + rand.Read(id[:]) + } + lt.runCount++ + n.runCount++ + cch := lt.limiter.Add(id, address, n.value, n.cost) + go func() { + lt.results <- ltResult{n, <-cch} + }() +} + +func (lt *limTest) moreRequests(n *ltNode) { + maxStart := int(float64(lt.totalCost-n.lastTotalCost) * n.reqRate) + if maxStart != 0 { + n.lastTotalCost = lt.totalCost + } + for n.reqMax > n.runCount && maxStart > 0 { + lt.request(n) + maxStart-- + } +} + +func (lt *limTest) process() { + res := <-lt.results + lt.runCount-- + res.node.runCount-- + if res.ch != nil { + res.node.served++ + if res.node.exp != 0 { + lt.expCost += res.node.cost + } + lt.totalCost += res.node.cost + close(res.ch) + } else { + res.node.dropped++ + } +} + +func TestLimiter(t *testing.T) { + limTests := [][]*ltNode{ + { // one id from an individual address and two ids from a shared address + {addr: 0, id: 0, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.5}, + {addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, + {addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, + }, + { // varying request costs + {addr: 0, id: 0, value: 0, cost: 10, reqRate: 0.2, reqMax: 1, exp: 0.5}, + {addr: 1, id: 1, value: 0, cost: 3, reqRate: 0.5, reqMax: 1, exp: 0.25}, + {addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, + }, + { // different request rate + {addr: 0, id: 0, value: 0, cost: 1, reqRate: 2, reqMax: 2, exp: 0.5}, + {addr: 1, id: 1, value: 0, cost: 1, reqRate: 10, reqMax: 10, exp: 0.25}, + {addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, + }, + { // adding value + {addr: 0, id: 0, value: 3, cost: 1, reqRate: 1, reqMax: 1, exp: (0.5 + 0.3) / 2}, + {addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25 / 2}, + {addr: 1, id: 2, value: 7, cost: 1, reqRate: 1, reqMax: 1, exp: (0.25 + 0.7) / 2}, + }, + { // DoS attack from a single address with a single id + {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 3, id: 3, value: 0, cost: 1, reqRate: 10, reqMax: 1000000000, exp: 0}, + }, + { // DoS attack from a single address with different ids + {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 3, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0}, + }, + { // DDoS attack from different addresses with a single id + {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: -1, id: 3, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0}, + }, + { // DDoS attack from different addresses with different ids + {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, + {addr: -1, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0}, + }, + } + + lt := &limTest{ + limiter: NewLimiter(100), + results: make(chan ltResult), + } + for _, test := range limTests { + lt.expCost, lt.totalCost = 0, 0 + iterCount := 10000 + for j := 0; j < ltRounds; j++ { + // try to reach expected target range in multiple rounds with increasing iteration counts + last := j == ltRounds-1 + for _, n := range test { + lt.request(n) + } + for i := 0; i < iterCount; i++ { + lt.process() + for _, n := range test { + lt.moreRequests(n) + } + } + for lt.runCount > 0 { + lt.process() + } + if spamRatio := 1 - float64(lt.expCost)/float64(lt.totalCost); spamRatio > 0.5*(1+ltTolerance) { + t.Errorf("Spam ratio too high (%f)", spamRatio) + } + fail, success := false, true + for _, n := range test { + if n.exp != 0 { + if n.dropped > 0 { + t.Errorf("Dropped %d requests of non-spam node", n.dropped) + fail = true + } + r := float64(n.served) * float64(n.cost) / float64(lt.expCost) + if r < n.exp*(1-ltTolerance) || r > n.exp*(1+ltTolerance) { + if last { + // print error only if the target is still not reached in the last round + t.Errorf("Request ratio (%f) does not match expected value (%f)", r, n.exp) + } + success = false + } + } + } + if fail || success { + break + } + // neither failed nor succeeded; try more iterations to reach probability targets + iterCount *= 2 + } + } + lt.limiter.Stop() +} diff --git a/les/utils/weighted_select.go b/les/utils/weighted_select.go index 9b4413afb..486b00820 100644 --- a/les/utils/weighted_select.go +++ b/les/utils/weighted_select.go @@ -52,17 +52,17 @@ func (w *WeightedRandomSelect) Remove(item WrsItem) { // IsEmpty returns true if the set is empty func (w *WeightedRandomSelect) IsEmpty() bool { - return w.root.sumWeight == 0 + return w.root.sumCost == 0 } // setWeight sets an item's weight to a specific value (removes it if zero) func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) { - if weight > math.MaxInt64-w.root.sumWeight { - // old weight is still included in sumWeight, remove and check again + if weight > math.MaxInt64-w.root.sumCost { + // old weight is still included in sumCost, remove and check again w.setWeight(item, 0) - if weight > math.MaxInt64-w.root.sumWeight { - log.Error("WeightedRandomSelect overflow", "sumWeight", w.root.sumWeight, "new weight", weight) - weight = math.MaxInt64 - w.root.sumWeight + if weight > math.MaxInt64-w.root.sumCost { + log.Error("WeightedRandomSelect overflow", "sumCost", w.root.sumCost, "new weight", weight) + weight = math.MaxInt64 - w.root.sumCost } } idx, ok := w.idx[item] @@ -75,9 +75,9 @@ func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) { if weight != 0 { if w.root.itemCnt == w.root.maxItems { // add a new level - newRoot := &wrsNode{sumWeight: w.root.sumWeight, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches} + newRoot := &wrsNode{sumCost: w.root.sumCost, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches} newRoot.items[0] = w.root - newRoot.weights[0] = w.root.sumWeight + newRoot.weights[0] = w.root.sumCost w.root = newRoot } w.idx[item] = w.root.insert(item, weight) @@ -91,10 +91,10 @@ func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) { // updates its weight and selects another one func (w *WeightedRandomSelect) Choose() WrsItem { for { - if w.root.sumWeight == 0 { + if w.root.sumCost == 0 { return nil } - val := uint64(rand.Int63n(int64(w.root.sumWeight))) + val := uint64(rand.Int63n(int64(w.root.sumCost))) choice, lastWeight := w.root.choose(val) weight := w.wfn(choice) if weight != lastWeight { @@ -112,7 +112,7 @@ const wrsBranches = 8 // max number of branches in the wrsNode tree type wrsNode struct { items [wrsBranches]interface{} weights [wrsBranches]uint64 - sumWeight uint64 + sumCost uint64 level, itemCnt, maxItems int } @@ -126,7 +126,7 @@ func (n *wrsNode) insert(item WrsItem, weight uint64) int { } } n.itemCnt++ - n.sumWeight += weight + n.sumCost += weight n.weights[branch] += weight if n.level == 0 { n.items[branch] = item @@ -150,7 +150,7 @@ func (n *wrsNode) setWeight(idx int, weight uint64) uint64 { oldWeight := n.weights[idx] n.weights[idx] = weight diff := weight - oldWeight - n.sumWeight += diff + n.sumCost += diff if weight == 0 { n.items[idx] = nil n.itemCnt-- @@ -161,7 +161,7 @@ func (n *wrsNode) setWeight(idx int, weight uint64) uint64 { branch := idx / branchItems diff := n.items[branch].(*wrsNode).setWeight(idx-branch*branchItems, weight) n.weights[branch] += diff - n.sumWeight += diff + n.sumCost += diff if weight == 0 { n.itemCnt-- }