diff --git a/p2p/discover/table.go b/p2p/discover/table.go index d08f8a6c6..41d5ac6e3 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -672,15 +672,14 @@ func (h *nodesByDistance) push(n *node, maxElems int) { ix := sort.Search(len(h.entries), func(i int) bool { return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 }) + + end := len(h.entries) if len(h.entries) < maxElems { h.entries = append(h.entries, n) } - if ix == len(h.entries) { - // farther away than all nodes we already have. - // if there was room for it, the node is now the last element. - } else { - // slide existing entries down to make room - // this will overwrite the entry we just appended. + if ix < end { + // Slide existing entries down to make room. + // This will overwrite the entry we just appended. copy(h.entries[ix+1:], h.entries[ix:]) h.entries[ix] = n } diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 5f40c967f..1ef63fe01 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -396,6 +396,59 @@ func TestTable_revalidateSyncRecord(t *testing.T) { } } +func TestNodesPush(t *testing.T) { + var target enode.ID + n1 := nodeAtDistance(target, 255, intIP(1)) + n2 := nodeAtDistance(target, 254, intIP(2)) + n3 := nodeAtDistance(target, 253, intIP(3)) + perm := [][]*node{ + {n3, n2, n1}, + {n3, n1, n2}, + {n2, n3, n1}, + {n2, n1, n3}, + {n1, n3, n2}, + {n1, n2, n3}, + } + + // Insert all permutations into lists with size limit 3. + for _, nodes := range perm { + list := nodesByDistance{target: target} + for _, n := range nodes { + list.push(n, 3) + } + if !slicesEqual(list.entries, perm[0], nodeIDEqual) { + t.Fatal("not equal") + } + } + + // Insert all permutations into lists with size limit 2. + for _, nodes := range perm { + list := nodesByDistance{target: target} + for _, n := range nodes { + list.push(n, 2) + } + if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) { + t.Fatal("not equal") + } + } +} + +func nodeIDEqual(n1, n2 *node) bool { + return n1.ID() == n2.ID() +} + +func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if !check(s1[i], s2[i]) { + return false + } + } + return true +} + // gen wraps quick.Value so it's easier to use. // it generates a random value of the given value's type. func gen(typ interface{}, rand *rand.Rand) interface{} {