diff --git a/utils/iterator.go b/utils/iterator.go index 1284fd6..f69d78f 100644 --- a/utils/iterator.go +++ b/utils/iterator.go @@ -8,9 +8,10 @@ import ( ) type symmDiffIterator struct { - a, b iterState // Nodes returned are those in b - a and a - b (keys only) - yieldFromA bool // Whether next node comes from a - count int // Number of nodes scanned on either trie + a, b iterState // Nodes returned are those in b - a and a - b (keys only) + yieldFromA bool // Whether next node comes from a + count int // Number of nodes scanned on either trie + eqPathIndex int // Count index of last pair of equal paths, to detect an updated key } // NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that @@ -47,6 +48,12 @@ func (it *symmDiffIterator) FromA() bool { return it.yieldFromA } +// CommonPath returns true if a node with the current path exists in each sub-iterator - i.e. it +// represents an updated node. +func (it *symmDiffIterator) CommonPath() bool { + return it.count-it.eqPathIndex <= 1 +} + func (it *symmDiffIterator) Hash() common.Hash { return it.curr().Hash() } @@ -110,7 +117,12 @@ func (it *symmDiffIterator) seek() { return } - switch compareNodes(&it.a, &it.b) { + cmp := bytes.Compare(it.a.Path(), it.b.Path()) + if cmp == 0 { + it.eqPathIndex = it.count + cmp = compareValues(&it.a, &it.b) + } + switch cmp { case -1: it.yieldFromA = true return @@ -136,10 +148,7 @@ func (it *symmDiffIterator) Error() error { } // Compares nodes with equal paths by value -func compareNodes(a, b trie.NodeIterator) int { - if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { - return cmp - } +func compareValues(a, b trie.NodeIterator) int { if a.Leaf() && !b.Leaf() { return -1 } else if b.Leaf() && !a.Leaf() { diff --git a/utils/iterator_test.go b/utils/iterator_test.go index f22a0d1..2b4eb37 100644 --- a/utils/iterator_test.go +++ b/utils/iterator_test.go @@ -50,7 +50,26 @@ func TestSymmetricDifferenceIterator(t *testing.T) { for di.Next(true) { t.Errorf("iterator should not yield any elements") } - assert.NotEqual(t, 0, count) + assert.Equal(t, 4, *count) + }) + + t.Run("with one difference", func(t *testing.T) { + dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) + triea := trie.NewEmpty(dba) + + dbb := trie.NewDatabase(rawdb.NewMemoryDatabase()) + trieb := trie.NewEmpty(dbb) + trieb.MustUpdate([]byte("foo"), []byte("bar")) + + di, count := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + for di.Next(true) { + if di.Leaf() { + assert.Equal(t, []byte("foo"), di.LeafKey()) + assert.Equal(t, []byte("bar"), di.LeafBlob()) + assert.False(t, di.CommonPath()) + } + } + assert.Equal(t, 2, *count) }) dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) @@ -66,6 +85,7 @@ func TestSymmetricDifferenceIterator(t *testing.T) { onlyA := make(map[string]string) onlyB := make(map[string]string) + var deletions, creations []string it, _ := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) for it.Next(true) { if !it.Leaf() { @@ -74,8 +94,14 @@ func TestSymmetricDifferenceIterator(t *testing.T) { key, value := string(it.LeafKey()), string(it.LeafBlob()) if it.FromA() { onlyA[key] = value + if !it.CommonPath() { + deletions = append(deletions, key) + } } else { onlyB[key] = value + if !it.CommonPath() { + creations = append(creations, key) + } } } @@ -92,6 +118,17 @@ func TestSymmetricDifferenceIterator(t *testing.T) { "foos": "aa", "jars": "d", } + expectedDeletions := []string{ + "bard", + "food", + } + expectedCreations := []string{ + "aardvark", + "foos", + "jars", + } assert.Equal(t, expectedOnlyA, onlyA) assert.Equal(t, expectedOnlyB, onlyB) + assert.Equal(t, expectedDeletions, deletions) + assert.Equal(t, expectedCreations, creations) }