diff --git a/utils/iterator_test.go b/utils/iterator_test.go index b440fbd..84b7289 100644 --- a/utils/iterator_test.go +++ b/utils/iterator_test.go @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/triedb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -43,16 +44,26 @@ var ( func TestSymmetricDifferenceIterator(t *testing.T) { t.Run("with no difference", func(t *testing.T) { - db := trie.NewDatabase(rawdb.NewMemoryDatabase()) + db := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil) triea := trie.NewEmpty(db) - di := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil)) + + ita, err := triea.NodeIterator(nil) + assert.NoError(t, err) + itb, err := triea.NodeIterator(nil) + assert.NoError(t, err) + di := utils.NewSymmetricDifferenceIterator(ita, itb) for di.Next(true) { t.Errorf("iterator should not yield any elements") } assert.Equal(t, 0, di.Count()) triea.MustUpdate([]byte("foo"), []byte("bar")) - di = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil)) + ita, err = triea.NodeIterator(nil) + assert.NoError(t, err) + itb, err = triea.NodeIterator(nil) + assert.NoError(t, err) + + di = utils.NewSymmetricDifferenceIterator(ita, itb) for di.Next(true) { t.Errorf("iterator should not yield any elements") } @@ -60,7 +71,11 @@ func TestSymmetricDifferenceIterator(t *testing.T) { assert.Equal(t, 2, di.Count()) trieb := trie.NewEmpty(db) - di = utils.NewSymmetricDifferenceIterator(triea.NodeIterator([]byte("jars")), trieb.NodeIterator(nil)) + ita, err = triea.NodeIterator([]byte("jars")) + assert.NoError(t, err) + itb, err = trieb.NodeIterator(nil) + assert.NoError(t, err) + di = utils.NewSymmetricDifferenceIterator(ita, itb) for di.Next(true) { t.Errorf("iterator should not yield any elements") } @@ -75,14 +90,18 @@ func TestSymmetricDifferenceIterator(t *testing.T) { }) t.Run("small difference", func(t *testing.T) { - dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) + dba := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil) triea := trie.NewEmpty(dba) - dbb := trie.NewDatabase(rawdb.NewMemoryDatabase()) + dbb := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil) trieb := trie.NewEmpty(dbb) trieb.MustUpdate([]byte("foo"), []byte("bar")) - di := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + ita, err := triea.NodeIterator(nil) + assert.NoError(t, err) + itb, err := triea.NodeIterator(nil) + assert.NoError(t, err) + di := utils.NewSymmetricDifferenceIterator(ita, itb) leaves := 0 for di.Next(true) { if di.Leaf() { @@ -96,7 +115,11 @@ func TestSymmetricDifferenceIterator(t *testing.T) { assert.Equal(t, 2, di.Count()) trieb.MustUpdate([]byte("quux"), []byte("bars")) - di = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator([]byte("quux"))) + ita, err = triea.NodeIterator(nil) + assert.NoError(t, err) + itb, err = triea.NodeIterator([]byte("quux")) + assert.NoError(t, err) + di = utils.NewSymmetricDifferenceIterator(ita, itb) leaves = 0 for di.Next(true) { if di.Leaf() { @@ -110,12 +133,12 @@ func TestSymmetricDifferenceIterator(t *testing.T) { assert.Equal(t, 1, di.Count()) }) - dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) + dba := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil) triea := trie.NewEmpty(dba) for _, val := range testdata1 { triea.MustUpdate([]byte(val.k), []byte(val.v)) } - dbb := trie.NewDatabase(rawdb.NewMemoryDatabase()) + dbb := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil) trieb := trie.NewEmpty(dbb) for _, val := range testdata2 { trieb.MustUpdate([]byte(val.k), []byte(val.v)) @@ -124,7 +147,11 @@ func TestSymmetricDifferenceIterator(t *testing.T) { onlyA := make(map[string]string) onlyB := make(map[string]string) var deletions, creations []string - it := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + ita, err := triea.NodeIterator(nil) + assert.NoError(t, err) + itb, err := triea.NodeIterator(nil) + assert.NoError(t, err) + it := utils.NewSymmetricDifferenceIterator(ita, itb) for it.Next(true) { if !it.Leaf() { continue @@ -177,7 +204,7 @@ func TestCompareDifferenceIterators(t *testing.T) { test_helpers.QuietLogs() db := rawdb.NewMemoryDatabase() - core.DefaultGenesisBlock().MustCommit(db) + core.DefaultGenesisBlock().MustCommit(db, triedb.NewDatabase(db, nil)) blocks := mainnet.GetBlocks() chain, _ := core.NewBlockChain(db, nil, nil, nil, ethash.NewFaker(), vm.Config{}, nil, nil) _, err := chain.InsertChain(blocks[1:]) @@ -196,16 +223,24 @@ func TestCompareDifferenceIterators(t *testing.T) { // collect the paths of nodes exclusive to A and B separately, then make sure the symmetric // iterator produces the same sets var pathsA, pathsB [][]byte - itBonly, _ := trie.NewDifferenceIterator(treeA.NodeIterator(nil), treeB.NodeIterator(nil)) + ita, err := treeA.NodeIterator(nil) + assert.NoError(t, err) + itb, err := treeB.NodeIterator(nil) + assert.NoError(t, err) + itBonly, _ := trie.NewDifferenceIterator(ita, itb) for itBonly.Next(true) { pathsB = append(pathsB, itBonly.Path()) } - itAonly, _ := trie.NewDifferenceIterator(treeB.NodeIterator(nil), treeA.NodeIterator(nil)) + itAonly, _ := trie.NewDifferenceIterator(itb, ita) for itAonly.Next(true) { pathsA = append(pathsA, itAonly.Path()) } - itSym := utils.NewSymmetricDifferenceIterator(treeA.NodeIterator(nil), treeB.NodeIterator(nil)) + ita, err = treeA.NodeIterator(nil) + assert.NoError(t, err) + itb, err = treeB.NodeIterator(nil) + assert.NoError(t, err) + itSym := utils.NewSymmetricDifferenceIterator(ita, itb) var idxA, idxB int for itSym.Next(true) { if itSym.FromA() {