// Copyright 2020 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. package utils import ( "math/rand" "testing" "github.com/ethereum/go-ethereum/p2p/enode" ) const ( ltTolerance = 0.03 ltRounds = 7 ) type ( ltNode struct { addr, id int value, exp float64 cost uint reqRate float64 reqMax, runCount int lastTotalCost uint served, dropped int } ltResult struct { node *ltNode ch chan struct{} } limTest struct { limiter *Limiter results chan ltResult runCount int expCost, totalCost uint } ) func (lt *limTest) request(n *ltNode) { var ( address string id enode.ID ) if n.addr >= 0 { address = string([]byte{byte(n.addr)}) } else { var b [32]byte rand.Read(b[:]) address = string(b[:]) } if n.id >= 0 { id = enode.ID{byte(n.id)} } else { rand.Read(id[:]) } lt.runCount++ n.runCount++ cch := lt.limiter.Add(id, address, n.value, n.cost) go func() { lt.results <- ltResult{n, <-cch} }() } func (lt *limTest) moreRequests(n *ltNode) { maxStart := int(float64(lt.totalCost-n.lastTotalCost) * n.reqRate) if maxStart != 0 { n.lastTotalCost = lt.totalCost } for n.reqMax > n.runCount && maxStart > 0 { lt.request(n) maxStart-- } } func (lt *limTest) process() { res := <-lt.results lt.runCount-- res.node.runCount-- if res.ch != nil { res.node.served++ if res.node.exp != 0 { lt.expCost += res.node.cost } lt.totalCost += res.node.cost close(res.ch) } else { res.node.dropped++ } } func TestLimiter(t *testing.T) { limTests := [][]*ltNode{ { // one id from an individual address and two ids from a shared address {addr: 0, id: 0, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.5}, {addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, {addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, }, { // varying request costs {addr: 0, id: 0, value: 0, cost: 10, reqRate: 0.2, reqMax: 1, exp: 0.5}, {addr: 1, id: 1, value: 0, cost: 3, reqRate: 0.5, reqMax: 1, exp: 0.25}, {addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, }, { // different request rate {addr: 0, id: 0, value: 0, cost: 1, reqRate: 2, reqMax: 2, exp: 0.5}, {addr: 1, id: 1, value: 0, cost: 1, reqRate: 10, reqMax: 10, exp: 0.25}, {addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25}, }, { // adding value {addr: 0, id: 0, value: 3, cost: 1, reqRate: 1, reqMax: 1, exp: (0.5 + 0.3) / 2}, {addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25 / 2}, {addr: 1, id: 2, value: 7, cost: 1, reqRate: 1, reqMax: 1, exp: (0.25 + 0.7) / 2}, }, { // DoS attack from a single address with a single id {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 3, id: 3, value: 0, cost: 1, reqRate: 10, reqMax: 1000000000, exp: 0}, }, { // DoS attack from a single address with different ids {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 3, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0}, }, { // DDoS attack from different addresses with a single id {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: -1, id: 3, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0}, }, { // DDoS attack from different addresses with different ids {addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333}, {addr: -1, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0}, }, } lt := &limTest{ limiter: NewLimiter(100), results: make(chan ltResult), } for _, test := range limTests { lt.expCost, lt.totalCost = 0, 0 iterCount := 10000 for j := 0; j < ltRounds; j++ { // try to reach expected target range in multiple rounds with increasing iteration counts last := j == ltRounds-1 for _, n := range test { lt.request(n) } for i := 0; i < iterCount; i++ { lt.process() for _, n := range test { lt.moreRequests(n) } } for lt.runCount > 0 { lt.process() } if spamRatio := 1 - float64(lt.expCost)/float64(lt.totalCost); spamRatio > 0.5*(1+ltTolerance) { t.Errorf("Spam ratio too high (%f)", spamRatio) } fail, success := false, true for _, n := range test { if n.exp != 0 { if n.dropped > 0 { t.Errorf("Dropped %d requests of non-spam node", n.dropped) fail = true } r := float64(n.served) * float64(n.cost) / float64(lt.expCost) if r < n.exp*(1-ltTolerance) || r > n.exp*(1+ltTolerance) { if last { // print error only if the target is still not reached in the last round t.Errorf("Request ratio (%f) does not match expected value (%f)", r, n.exp) } success = false } } } if fail || success { break } // neither failed nor succeeded; try more iterations to reach probability targets iterCount *= 2 } } lt.limiter.Stop() }