missed utils package

This commit is contained in:
Roy Crihfield 2024-03-19 12:46:46 +08:00
parent 45b051ba61
commit d559f79e90

View File

@ -9,6 +9,7 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/triedb"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -43,16 +44,26 @@ var (
func TestSymmetricDifferenceIterator(t *testing.T) { func TestSymmetricDifferenceIterator(t *testing.T) {
t.Run("with no difference", func(t *testing.T) { t.Run("with no difference", func(t *testing.T) {
db := trie.NewDatabase(rawdb.NewMemoryDatabase()) db := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)
triea := trie.NewEmpty(db) triea := trie.NewEmpty(db)
di := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil))
ita, err := triea.NodeIterator(nil)
assert.NoError(t, err)
itb, err := triea.NodeIterator(nil)
assert.NoError(t, err)
di := utils.NewSymmetricDifferenceIterator(ita, itb)
for di.Next(true) { for di.Next(true) {
t.Errorf("iterator should not yield any elements") t.Errorf("iterator should not yield any elements")
} }
assert.Equal(t, 0, di.Count()) assert.Equal(t, 0, di.Count())
triea.MustUpdate([]byte("foo"), []byte("bar")) triea.MustUpdate([]byte("foo"), []byte("bar"))
di = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil)) ita, err = triea.NodeIterator(nil)
assert.NoError(t, err)
itb, err = triea.NodeIterator(nil)
assert.NoError(t, err)
di = utils.NewSymmetricDifferenceIterator(ita, itb)
for di.Next(true) { for di.Next(true) {
t.Errorf("iterator should not yield any elements") t.Errorf("iterator should not yield any elements")
} }
@ -60,7 +71,11 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
assert.Equal(t, 2, di.Count()) assert.Equal(t, 2, di.Count())
trieb := trie.NewEmpty(db) trieb := trie.NewEmpty(db)
di = utils.NewSymmetricDifferenceIterator(triea.NodeIterator([]byte("jars")), trieb.NodeIterator(nil)) ita, err = triea.NodeIterator([]byte("jars"))
assert.NoError(t, err)
itb, err = trieb.NodeIterator(nil)
assert.NoError(t, err)
di = utils.NewSymmetricDifferenceIterator(ita, itb)
for di.Next(true) { for di.Next(true) {
t.Errorf("iterator should not yield any elements") t.Errorf("iterator should not yield any elements")
} }
@ -75,14 +90,18 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
}) })
t.Run("small difference", func(t *testing.T) { t.Run("small difference", func(t *testing.T) {
dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) dba := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)
triea := trie.NewEmpty(dba) triea := trie.NewEmpty(dba)
dbb := trie.NewDatabase(rawdb.NewMemoryDatabase()) dbb := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)
trieb := trie.NewEmpty(dbb) trieb := trie.NewEmpty(dbb)
trieb.MustUpdate([]byte("foo"), []byte("bar")) trieb.MustUpdate([]byte("foo"), []byte("bar"))
di := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) ita, err := triea.NodeIterator(nil)
assert.NoError(t, err)
itb, err := trieb.NodeIterator(nil)
assert.NoError(t, err)
di := utils.NewSymmetricDifferenceIterator(ita, itb)
leaves := 0 leaves := 0
for di.Next(true) { for di.Next(true) {
if di.Leaf() { if di.Leaf() {
@ -96,7 +115,11 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
assert.Equal(t, 2, di.Count()) assert.Equal(t, 2, di.Count())
trieb.MustUpdate([]byte("quux"), []byte("bars")) trieb.MustUpdate([]byte("quux"), []byte("bars"))
di = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator([]byte("quux"))) ita, err = triea.NodeIterator(nil)
assert.NoError(t, err)
itb, err = trieb.NodeIterator([]byte("quux"))
assert.NoError(t, err)
di = utils.NewSymmetricDifferenceIterator(ita, itb)
leaves = 0 leaves = 0
for di.Next(true) { for di.Next(true) {
if di.Leaf() { if di.Leaf() {
@ -110,12 +133,12 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
assert.Equal(t, 1, di.Count()) assert.Equal(t, 1, di.Count())
}) })
dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) dba := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)
triea := trie.NewEmpty(dba) triea := trie.NewEmpty(dba)
for _, val := range testdata1 { for _, val := range testdata1 {
triea.MustUpdate([]byte(val.k), []byte(val.v)) triea.MustUpdate([]byte(val.k), []byte(val.v))
} }
dbb := trie.NewDatabase(rawdb.NewMemoryDatabase()) dbb := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)
trieb := trie.NewEmpty(dbb) trieb := trie.NewEmpty(dbb)
for _, val := range testdata2 { for _, val := range testdata2 {
trieb.MustUpdate([]byte(val.k), []byte(val.v)) trieb.MustUpdate([]byte(val.k), []byte(val.v))
@ -124,7 +147,11 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
onlyA := make(map[string]string) onlyA := make(map[string]string)
onlyB := make(map[string]string) onlyB := make(map[string]string)
var deletions, creations []string var deletions, creations []string
it := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) ita, err := triea.NodeIterator(nil)
assert.NoError(t, err)
itb, err := trieb.NodeIterator(nil)
assert.NoError(t, err)
it := utils.NewSymmetricDifferenceIterator(ita, itb)
for it.Next(true) { for it.Next(true) {
if !it.Leaf() { if !it.Leaf() {
continue continue
@ -177,7 +204,7 @@ func TestCompareDifferenceIterators(t *testing.T) {
test_helpers.QuietLogs() test_helpers.QuietLogs()
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
core.DefaultGenesisBlock().MustCommit(db) core.DefaultGenesisBlock().MustCommit(db, triedb.NewDatabase(db, nil))
blocks := mainnet.GetBlocks() blocks := mainnet.GetBlocks()
chain, _ := core.NewBlockChain(db, nil, nil, nil, ethash.NewFaker(), vm.Config{}, nil, nil) chain, _ := core.NewBlockChain(db, nil, nil, nil, ethash.NewFaker(), vm.Config{}, nil, nil)
_, err := chain.InsertChain(blocks[1:]) _, err := chain.InsertChain(blocks[1:])
@ -196,16 +223,28 @@ func TestCompareDifferenceIterators(t *testing.T) {
// collect the paths of nodes exclusive to A and B separately, then make sure the symmetric // collect the paths of nodes exclusive to A and B separately, then make sure the symmetric
// iterator produces the same sets // iterator produces the same sets
var pathsA, pathsB [][]byte var pathsA, pathsB [][]byte
itBonly, _ := trie.NewDifferenceIterator(treeA.NodeIterator(nil), treeB.NodeIterator(nil)) ita, err := treeA.NodeIterator(nil)
assert.NoError(t, err)
itb, err := treeB.NodeIterator(nil)
assert.NoError(t, err)
itBonly, _ := trie.NewDifferenceIterator(ita, itb)
for itBonly.Next(true) { for itBonly.Next(true) {
pathsB = append(pathsB, itBonly.Path()) pathsB = append(pathsB, itBonly.Path())
} }
itAonly, _ := trie.NewDifferenceIterator(treeB.NodeIterator(nil), treeA.NodeIterator(nil)) ita, err = treeA.NodeIterator(nil)
assert.NoError(t, err)
itb, err = treeB.NodeIterator(nil)
assert.NoError(t, err)
itAonly, _ := trie.NewDifferenceIterator(itb, ita)
for itAonly.Next(true) { for itAonly.Next(true) {
pathsA = append(pathsA, itAonly.Path()) pathsA = append(pathsA, itAonly.Path())
} }
itSym := utils.NewSymmetricDifferenceIterator(treeA.NodeIterator(nil), treeB.NodeIterator(nil)) ita, err = treeA.NodeIterator(nil)
assert.NoError(t, err)
itb, err = treeB.NodeIterator(nil)
assert.NoError(t, err)
itSym := utils.NewSymmetricDifferenceIterator(ita, itb)
var idxA, idxB int var idxA, idxB int
for itSym.Next(true) { for itSym.Next(true) {
if itSym.FromA() { if itSym.FromA() {