tests, trie: use slices package for sorting (#27496)

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Dan Laine 2023-06-19 05:41:31 -04:00 committed by GitHub
parent 87e510d963
commit 50ecb16de0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 52 additions and 67 deletions

View File

@ -21,12 +21,12 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"sort"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"golang.org/x/exp/slices"
) )
type kv struct { type kv struct {
@ -34,12 +34,6 @@ type kv struct {
t bool t bool
} }
type entrySlice []*kv
func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
type fuzzer struct { type fuzzer struct {
input io.Reader input io.Reader
exhausted bool exhausted bool
@ -97,14 +91,16 @@ func (f *fuzzer) fuzz() int {
if f.exhausted { if f.exhausted {
return 0 // input too short return 0 // input too short
} }
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
if len(entries) <= 1 { if len(entries) <= 1 {
return 0 return 0
} }
sort.Sort(entries) slices.SortFunc(entries, func(a, b *kv) bool {
return bytes.Compare(a.k, b.k) < 0
})
var ok = 0 var ok = 0
for { for {

View File

@ -23,7 +23,6 @@ import (
"fmt" "fmt"
"hash" "hash"
"io" "io"
"sort"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
@ -33,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode" "github.com/ethereum/go-ethereum/trie/trienode"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
"golang.org/x/exp/slices"
) )
type fuzzer struct { type fuzzer struct {
@ -104,19 +104,6 @@ func (b *spongeBatch) Replay(w ethdb.KeyValueWriter) error { return nil }
type kv struct { type kv struct {
k, v []byte k, v []byte
} }
type kvs []kv
func (k kvs) Len() int {
return len(k)
}
func (k kvs) Less(i, j int) bool {
return bytes.Compare(k[i].k, k[j].k) < 0
}
func (k kvs) Swap(i, j int) {
k[j], k[i] = k[i], k[j]
}
// Fuzz is the fuzzing entry-point. // Fuzz is the fuzzing entry-point.
// The function must return // The function must return
@ -156,7 +143,7 @@ func (f *fuzzer) fuzz() int {
trieB = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { trieB = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(spongeB, owner, path, hash, blob, dbB.Scheme()) rawdb.WriteTrieNode(spongeB, owner, path, hash, blob, dbB.Scheme())
}) })
vals kvs vals []kv
useful bool useful bool
maxElements = 10000 maxElements = 10000
// operate on unique keys only // operate on unique keys only
@ -192,7 +179,9 @@ func (f *fuzzer) fuzz() int {
dbA.Commit(rootA, false) dbA.Commit(rootA, false)
// Stacktrie requires sorted insertion // Stacktrie requires sorted insertion
sort.Sort(vals) slices.SortFunc(vals, func(a, b kv) bool {
return bytes.Compare(a.k, b.k) < 0
})
for _, kv := range vals { for _, kv := range vals {
if f.debugging { if f.debugging {
fmt.Printf("{\"%#x\" , \"%#x\"} // stacktrie.Update\n", kv.k, kv.v) fmt.Printf("{\"%#x\" , \"%#x\"} // stacktrie.Update\n", kv.k, kv.v)

View File

@ -84,6 +84,10 @@ type kv struct {
t bool t bool
} }
func (k *kv) less(other *kv) bool {
return bytes.Compare(k.k, other.k) < 0
}
func TestIteratorLargeData(t *testing.T) { func TestIteratorLargeData(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv) vals := make(map[string]*kv)

View File

@ -22,13 +22,13 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
mrand "math/rand" mrand "math/rand"
"sort"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb/memorydb"
"golang.org/x/exp/slices"
) )
// Prng is a pseudo random number generator seeded by strong randomness. // Prng is a pseudo random number generator seeded by strong randomness.
@ -165,21 +165,15 @@ func TestMissingKeyProof(t *testing.T) {
} }
} }
type entrySlice []*kv
func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// TestRangeProof tests normal range proof with both edge proofs // TestRangeProof tests normal range proof with both edge proofs
// as the existent proof. The test cases are generated randomly. // as the existent proof. The test cases are generated randomly.
func TestRangeProof(t *testing.T) { func TestRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries)) start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1 end := mrand.Intn(len(entries)-start) + start + 1
@ -208,11 +202,11 @@ func TestRangeProof(t *testing.T) {
// The test cases are generated randomly. // The test cases are generated randomly.
func TestRangeProofWithNonExistentProof(t *testing.T) { func TestRangeProofWithNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries)) start := mrand.Intn(len(entries))
end := mrand.Intn(len(entries)-start) + start + 1 end := mrand.Intn(len(entries)-start) + start + 1
@ -280,11 +274,11 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
// - There exists a gap between the last element and the right edge proof // - There exists a gap between the last element and the right edge proof
func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
// Case 1 // Case 1
start, end := 100, 200 start, end := 100, 200
@ -337,11 +331,11 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
// non-existent one. // non-existent one.
func TestOneElementRangeProof(t *testing.T) { func TestOneElementRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
// One element with existent edge proof, both edge proofs // One element with existent edge proof, both edge proofs
// point to the SAME key. // point to the SAME key.
@ -424,11 +418,11 @@ func TestOneElementRangeProof(t *testing.T) {
// The edge proofs can be nil. // The edge proofs can be nil.
func TestAllElementsProof(t *testing.T) { func TestAllElementsProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
var k [][]byte var k [][]byte
var v [][]byte var v [][]byte
@ -474,13 +468,13 @@ func TestAllElementsProof(t *testing.T) {
func TestSingleSideRangeProof(t *testing.T) { func TestSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice var entries []*kv
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.MustUpdate(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases { for _, pos := range cases {
@ -509,13 +503,13 @@ func TestSingleSideRangeProof(t *testing.T) {
func TestReverseSingleSideRangeProof(t *testing.T) { func TestReverseSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice var entries []*kv
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.MustUpdate(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases { for _, pos := range cases {
@ -545,11 +539,11 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
// The prover is expected to detect the error. // The prover is expected to detect the error.
func TestBadRangeProof(t *testing.T) { func TestBadRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
start := mrand.Intn(len(entries)) start := mrand.Intn(len(entries))
@ -648,11 +642,11 @@ func TestGappedRangeProof(t *testing.T) {
// TestSameSideProofs tests the element is not in the range covered by proofs // TestSameSideProofs tests the element is not in the range covered by proofs
func TestSameSideProofs(t *testing.T) { func TestSameSideProofs(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
pos := 1000 pos := 1000
first := decreaseKey(common.CopyBytes(entries[pos].k)) first := decreaseKey(common.CopyBytes(entries[pos].k))
@ -690,13 +684,13 @@ func TestSameSideProofs(t *testing.T) {
func TestHasRightElement(t *testing.T) { func TestHasRightElement(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice var entries []*kv
for i := 0; i < 4096; i++ { for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false} value := &kv{randBytes(32), randBytes(20), false}
trie.MustUpdate(value.k, value.v) trie.MustUpdate(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
var cases = []struct { var cases = []struct {
start int start int
@ -764,11 +758,11 @@ func TestHasRightElement(t *testing.T) {
// The first edge proof must be a non-existent proof. // The first edge proof must be a non-existent proof.
func TestEmptyRangeProof(t *testing.T) { func TestEmptyRangeProof(t *testing.T) {
trie, vals := randomTrie(4096) trie, vals := randomTrie(4096)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
var cases = []struct { var cases = []struct {
pos int pos int
@ -799,11 +793,11 @@ func TestEmptyRangeProof(t *testing.T) {
func TestBloatedProof(t *testing.T) { func TestBloatedProof(t *testing.T) {
// Use a small trie // Use a small trie
trie, kvs := nonRandomTrie(100) trie, kvs := nonRandomTrie(100)
var entries entrySlice var entries []*kv
for _, kv := range kvs { for _, kv := range kvs {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
var keys [][]byte var keys [][]byte
var vals [][]byte var vals [][]byte
@ -833,11 +827,11 @@ func TestBloatedProof(t *testing.T) {
// noop technically, but practically should be rejected. // noop technically, but practically should be rejected.
func TestEmptyValueRangeProof(t *testing.T) { func TestEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512) trie, values := randomTrie(512)
var entries entrySlice var entries []*kv
for _, kv := range values { for _, kv := range values {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
// Create a new entry with a slightly modified key // Create a new entry with a slightly modified key
mid := len(entries) / 2 mid := len(entries) / 2
@ -877,11 +871,11 @@ func TestEmptyValueRangeProof(t *testing.T) {
// practically should be rejected. // practically should be rejected.
func TestAllElementsEmptyValueRangeProof(t *testing.T) { func TestAllElementsEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(512) trie, values := randomTrie(512)
var entries entrySlice var entries []*kv
for _, kv := range values { for _, kv := range values {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
// Create a new entry with a slightly modified key // Create a new entry with a slightly modified key
mid := len(entries) / 2 mid := len(entries) / 2
@ -983,11 +977,11 @@ func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b,
func benchmarkVerifyRangeProof(b *testing.B, size int) { func benchmarkVerifyRangeProof(b *testing.B, size int) {
trie, vals := randomTrie(8192) trie, vals := randomTrie(8192)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
start := 2 start := 2
end := start + size end := start + size
@ -1020,11 +1014,11 @@ func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof
func benchmarkVerifyRangeNoProof(b *testing.B, size int) { func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
trie, vals := randomTrie(size) trie, vals := randomTrie(size)
var entries entrySlice var entries []*kv
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, kv) entries = append(entries, kv)
} }
sort.Sort(entries) slices.SortFunc(entries, (*kv).less)
var keys [][]byte var keys [][]byte
var values [][]byte var values [][]byte

View File

@ -18,10 +18,10 @@ package trienode
import ( import (
"fmt" "fmt"
"sort"
"strings" "strings"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"golang.org/x/exp/slices"
) )
// Node is a wrapper which contains the encoded blob of the trie node and its // Node is a wrapper which contains the encoded blob of the trie node and its
@ -100,12 +100,14 @@ func NewNodeSet(owner common.Hash) *NodeSet {
// ForEachWithOrder iterates the nodes with the order from bottom to top, // ForEachWithOrder iterates the nodes with the order from bottom to top,
// right to left, nodes with the longest path will be iterated first. // right to left, nodes with the longest path will be iterated first.
func (set *NodeSet) ForEachWithOrder(callback func(path string, n *Node)) { func (set *NodeSet) ForEachWithOrder(callback func(path string, n *Node)) {
var paths sort.StringSlice var paths []string
for path := range set.Nodes { for path := range set.Nodes {
paths = append(paths, path) paths = append(paths, path)
} }
// Bottom-up, longest path first // Bottom-up, longest path first
sort.Sort(sort.Reverse(paths)) slices.SortFunc(paths, func(a, b string) bool {
return a > b // Sort in reverse order
})
for _, path := range paths { for _, path := range paths {
callback(path, set.Nodes[path].Unwrap()) callback(path, set.Nodes[path].Unwrap())
} }