diff --git a/utils/iterator.go b/utils/iterator.go index f69d78f..a2486df 100644 --- a/utils/iterator.go +++ b/utils/iterator.go @@ -14,13 +14,15 @@ type symmDiffIterator struct { eqPathIndex int // Count index of last pair of equal paths, to detect an updated key } -// NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that -// are not in a. Returns the iterator, and a pointer to an integer recording the number -// of nodes seen. +// NewSymmetricDifferenceIterator constructs a trie.NodeIterator that iterates over the symmetric difference +// of elements in a and b, i.e., the elements in a that are not in b, and vice versa. +// Returns the iterator, and a pointer to an integer recording the number of nodes seen. func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) { it := &symmDiffIterator{ a: iterState{a, true}, b: iterState{b, true}, + // common paths are detected by a distance <=1 from this index, so put it out of reach + eqPathIndex: -2, } return it, &it.count } @@ -91,9 +93,15 @@ func (it *symmDiffIterator) AddResolver(resolver trie.NodeResolver) { } func (it *symmDiffIterator) Next(bool) bool { - // NodeIterators start in a "pre-valid" state, so you have to call Next before they point to a - // valid node. This delays advancing the sub-iterators until they are so initialized by seek. - if it.count != 0 { + // NodeIterators start in a "pre-valid" state, so the first Next advances to a valid node. + if it.count == 0 { + if it.a.Next(true) { + it.count++ + } + if it.b.Next(true) { + it.count++ + } + } else { if it.curr().Next(true) { it.count++ } @@ -120,7 +128,7 @@ func (it *symmDiffIterator) seek() { cmp := bytes.Compare(it.a.Path(), it.b.Path()) if cmp == 0 { it.eqPathIndex = it.count - cmp = compareValues(&it.a, &it.b) + cmp = compareNodes(&it.a, &it.b) } switch cmp { case -1: @@ -147,8 +155,10 @@ func (it *symmDiffIterator) Error() error { return it.b.Error() } -// Compares nodes with equal paths by value -func compareValues(a, b trie.NodeIterator) int { +func compareNodes(a, b trie.NodeIterator) int { + // if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { + // return cmp + // } if a.Leaf() && !b.Leaf() { return -1 } else if b.Leaf() && !a.Leaf() { diff --git a/utils/iterator_test.go b/utils/iterator_test.go index 2b4eb37..3473c63 100644 --- a/utils/iterator_test.go +++ b/utils/iterator_test.go @@ -14,13 +14,13 @@ type kvs struct{ k, v string } var ( testdata1 = []kvs{ - {"bar", "b"}, {"barb", "ba"}, {"bard", "bc"}, {"bars", "bb"}, + {"bar", "b"}, {"fab", "z"}, - {"foo", "a"}, {"food", "ab"}, + {"foo", "a"}, } testdata2 = []kvs{ @@ -38,22 +38,30 @@ var ( func TestSymmetricDifferenceIterator(t *testing.T) { t.Run("with no difference", func(t *testing.T) { db := trie.NewDatabase(rawdb.NewMemoryDatabase()) - tree := trie.NewEmpty(db) - di, count := utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil)) + triea := trie.NewEmpty(db) + di, count := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil)) for di.Next(true) { t.Errorf("iterator should not yield any elements") } assert.Equal(t, 0, *count) - tree.MustUpdate([]byte("foo"), []byte("bar")) - di, count = utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil)) + triea.MustUpdate([]byte("foo"), []byte("bar")) + di, count = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil)) for di.Next(true) { t.Errorf("iterator should not yield any elements") } assert.Equal(t, 4, *count) + + trieb := trie.NewEmpty(db) + di, count = utils.NewSymmetricDifferenceIterator(triea.NodeIterator([]byte("food")), trieb.NodeIterator(nil)) + for di.Next(true) { + t.Errorf("iterator should not yield any elements") + t.Logf("%s", di.LeafKey()) + } + assert.Equal(t, 0, *count) }) - t.Run("with one difference", func(t *testing.T) { + t.Run("small difference", func(t *testing.T) { dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) triea := trie.NewEmpty(dba) @@ -62,14 +70,31 @@ func TestSymmetricDifferenceIterator(t *testing.T) { trieb.MustUpdate([]byte("foo"), []byte("bar")) di, count := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + leaves := 0 for di.Next(true) { if di.Leaf() { - assert.Equal(t, []byte("foo"), di.LeafKey()) - assert.Equal(t, []byte("bar"), di.LeafBlob()) assert.False(t, di.CommonPath()) + assert.Equal(t, "foo", string(di.LeafKey())) + assert.Equal(t, "bar", string(di.LeafBlob())) + leaves++ } } + assert.Equal(t, 1, leaves) assert.Equal(t, 2, *count) + + trieb.MustUpdate([]byte("quux"), []byte("bars")) + di, count = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator([]byte("quux"))) + leaves = 0 + for di.Next(true) { + if di.Leaf() { + assert.False(t, di.CommonPath()) + assert.Equal(t, "quux", string(di.LeafKey())) + assert.Equal(t, "bars", string(di.LeafBlob())) + leaves++ + } + } + assert.Equal(t, 1, leaves) + assert.Equal(t, 1, *count) }) dba := trie.NewDatabase(rawdb.NewMemoryDatabase())