iterator fixes

This commit is contained in:
Roy Crihfield 2023-08-01 22:50:59 +08:00
parent a6562c81c4
commit b29d60542c
2 changed files with 53 additions and 18 deletions

View File

@ -14,13 +14,15 @@ type symmDiffIterator struct {
eqPathIndex int // Count index of last pair of equal paths, to detect an updated key eqPathIndex int // Count index of last pair of equal paths, to detect an updated key
} }
// NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that // NewSymmetricDifferenceIterator constructs a trie.NodeIterator that iterates over the symmetric difference
// are not in a. Returns the iterator, and a pointer to an integer recording the number // of elements in a and b, i.e., the elements in a that are not in b, and vice versa.
// of nodes seen. // Returns the iterator, and a pointer to an integer recording the number of nodes seen.
func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) { func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) {
it := &symmDiffIterator{ it := &symmDiffIterator{
a: iterState{a, true}, a: iterState{a, true},
b: iterState{b, true}, b: iterState{b, true},
// common paths are detected by a distance <=1 from this index, so put it out of reach
eqPathIndex: -2,
} }
return it, &it.count return it, &it.count
} }
@ -91,9 +93,15 @@ func (it *symmDiffIterator) AddResolver(resolver trie.NodeResolver) {
} }
func (it *symmDiffIterator) Next(bool) bool { func (it *symmDiffIterator) Next(bool) bool {
// NodeIterators start in a "pre-valid" state, so you have to call Next before they point to a // NodeIterators start in a "pre-valid" state, so the first Next advances to a valid node.
// valid node. This delays advancing the sub-iterators until they are so initialized by seek. if it.count == 0 {
if it.count != 0 { if it.a.Next(true) {
it.count++
}
if it.b.Next(true) {
it.count++
}
} else {
if it.curr().Next(true) { if it.curr().Next(true) {
it.count++ it.count++
} }
@ -120,7 +128,7 @@ func (it *symmDiffIterator) seek() {
cmp := bytes.Compare(it.a.Path(), it.b.Path()) cmp := bytes.Compare(it.a.Path(), it.b.Path())
if cmp == 0 { if cmp == 0 {
it.eqPathIndex = it.count it.eqPathIndex = it.count
cmp = compareValues(&it.a, &it.b) cmp = compareNodes(&it.a, &it.b)
} }
switch cmp { switch cmp {
case -1: case -1:
@ -147,8 +155,10 @@ func (it *symmDiffIterator) Error() error {
return it.b.Error() return it.b.Error()
} }
// Compares nodes with equal paths by value func compareNodes(a, b trie.NodeIterator) int {
func compareValues(a, b trie.NodeIterator) int { // if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
// return cmp
// }
if a.Leaf() && !b.Leaf() { if a.Leaf() && !b.Leaf() {
return -1 return -1
} else if b.Leaf() && !a.Leaf() { } else if b.Leaf() && !a.Leaf() {

View File

@ -14,13 +14,13 @@ type kvs struct{ k, v string }
var ( var (
testdata1 = []kvs{ testdata1 = []kvs{
{"bar", "b"},
{"barb", "ba"}, {"barb", "ba"},
{"bard", "bc"}, {"bard", "bc"},
{"bars", "bb"}, {"bars", "bb"},
{"bar", "b"},
{"fab", "z"}, {"fab", "z"},
{"foo", "a"},
{"food", "ab"}, {"food", "ab"},
{"foo", "a"},
} }
testdata2 = []kvs{ testdata2 = []kvs{
@ -38,22 +38,30 @@ var (
func TestSymmetricDifferenceIterator(t *testing.T) { func TestSymmetricDifferenceIterator(t *testing.T) {
t.Run("with no difference", func(t *testing.T) { t.Run("with no difference", func(t *testing.T) {
db := trie.NewDatabase(rawdb.NewMemoryDatabase()) db := trie.NewDatabase(rawdb.NewMemoryDatabase())
tree := trie.NewEmpty(db) triea := trie.NewEmpty(db)
di, count := utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil)) di, count := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil))
for di.Next(true) { for di.Next(true) {
t.Errorf("iterator should not yield any elements") t.Errorf("iterator should not yield any elements")
} }
assert.Equal(t, 0, *count) assert.Equal(t, 0, *count)
tree.MustUpdate([]byte("foo"), []byte("bar")) triea.MustUpdate([]byte("foo"), []byte("bar"))
di, count = utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil)) di, count = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), triea.NodeIterator(nil))
for di.Next(true) { for di.Next(true) {
t.Errorf("iterator should not yield any elements") t.Errorf("iterator should not yield any elements")
} }
assert.Equal(t, 4, *count) assert.Equal(t, 4, *count)
trieb := trie.NewEmpty(db)
di, count = utils.NewSymmetricDifferenceIterator(triea.NodeIterator([]byte("food")), trieb.NodeIterator(nil))
for di.Next(true) {
t.Errorf("iterator should not yield any elements")
t.Logf("%s", di.LeafKey())
}
assert.Equal(t, 0, *count)
}) })
t.Run("with one difference", func(t *testing.T) { t.Run("small difference", func(t *testing.T) {
dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) dba := trie.NewDatabase(rawdb.NewMemoryDatabase())
triea := trie.NewEmpty(dba) triea := trie.NewEmpty(dba)
@ -62,14 +70,31 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
trieb.MustUpdate([]byte("foo"), []byte("bar")) trieb.MustUpdate([]byte("foo"), []byte("bar"))
di, count := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) di, count := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
leaves := 0
for di.Next(true) { for di.Next(true) {
if di.Leaf() { if di.Leaf() {
assert.Equal(t, []byte("foo"), di.LeafKey())
assert.Equal(t, []byte("bar"), di.LeafBlob())
assert.False(t, di.CommonPath()) assert.False(t, di.CommonPath())
assert.Equal(t, "foo", string(di.LeafKey()))
assert.Equal(t, "bar", string(di.LeafBlob()))
leaves++
} }
} }
assert.Equal(t, 1, leaves)
assert.Equal(t, 2, *count) assert.Equal(t, 2, *count)
trieb.MustUpdate([]byte("quux"), []byte("bars"))
di, count = utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator([]byte("quux")))
leaves = 0
for di.Next(true) {
if di.Leaf() {
assert.False(t, di.CommonPath())
assert.Equal(t, "quux", string(di.LeafKey()))
assert.Equal(t, "bars", string(di.LeafBlob()))
leaves++
}
}
assert.Equal(t, 1, leaves)
assert.Equal(t, 1, *count)
}) })
dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) dba := trie.NewDatabase(rawdb.NewMemoryDatabase())