symmetric diff iterator: detect common paths

This commit is contained in:
Roy Crihfield 2023-07-24 21:46:09 +08:00
parent fdcce5be35
commit ae30af8544
2 changed files with 55 additions and 9 deletions

View File

@ -8,9 +8,10 @@ import (
)
type symmDiffIterator struct {
a, b iterState // Nodes returned are those in b - a and a - b (keys only)
yieldFromA bool // Whether next node comes from a
count int // Number of nodes scanned on either trie
a, b iterState // Nodes returned are those in b - a and a - b (keys only)
yieldFromA bool // Whether next node comes from a
count int // Number of nodes scanned on either trie
eqPathIndex int // Count index of last pair of equal paths, to detect an updated key
}
// NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that
@ -47,6 +48,12 @@ func (it *symmDiffIterator) FromA() bool {
return it.yieldFromA
}
// CommonPath returns true if a node with the current path exists in each sub-iterator - i.e. it
// represents an updated node.
func (it *symmDiffIterator) CommonPath() bool {
return it.count-it.eqPathIndex <= 1
}
func (it *symmDiffIterator) Hash() common.Hash {
return it.curr().Hash()
}
@ -110,7 +117,12 @@ func (it *symmDiffIterator) seek() {
return
}
switch compareNodes(&it.a, &it.b) {
cmp := bytes.Compare(it.a.Path(), it.b.Path())
if cmp == 0 {
it.eqPathIndex = it.count
cmp = compareValues(&it.a, &it.b)
}
switch cmp {
case -1:
it.yieldFromA = true
return
@ -136,10 +148,7 @@ func (it *symmDiffIterator) Error() error {
}
// Compares nodes with equal paths by value
func compareNodes(a, b trie.NodeIterator) int {
if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
return cmp
}
func compareValues(a, b trie.NodeIterator) int {
if a.Leaf() && !b.Leaf() {
return -1
} else if b.Leaf() && !a.Leaf() {

View File

@ -50,7 +50,26 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
for di.Next(true) {
t.Errorf("iterator should not yield any elements")
}
assert.NotEqual(t, 0, count)
assert.Equal(t, 4, *count)
})
t.Run("with one difference", func(t *testing.T) {
dba := trie.NewDatabase(rawdb.NewMemoryDatabase())
triea := trie.NewEmpty(dba)
dbb := trie.NewDatabase(rawdb.NewMemoryDatabase())
trieb := trie.NewEmpty(dbb)
trieb.MustUpdate([]byte("foo"), []byte("bar"))
di, count := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
for di.Next(true) {
if di.Leaf() {
assert.Equal(t, []byte("foo"), di.LeafKey())
assert.Equal(t, []byte("bar"), di.LeafBlob())
assert.False(t, di.CommonPath())
}
}
assert.Equal(t, 2, *count)
})
dba := trie.NewDatabase(rawdb.NewMemoryDatabase())
@ -66,6 +85,7 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
onlyA := make(map[string]string)
onlyB := make(map[string]string)
var deletions, creations []string
it, _ := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
for it.Next(true) {
if !it.Leaf() {
@ -74,8 +94,14 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
key, value := string(it.LeafKey()), string(it.LeafBlob())
if it.FromA() {
onlyA[key] = value
if !it.CommonPath() {
deletions = append(deletions, key)
}
} else {
onlyB[key] = value
if !it.CommonPath() {
creations = append(creations, key)
}
}
}
@ -92,6 +118,17 @@ func TestSymmetricDifferenceIterator(t *testing.T) {
"foos": "aa",
"jars": "d",
}
expectedDeletions := []string{
"bard",
"food",
}
expectedCreations := []string{
"aardvark",
"foos",
"jars",
}
assert.Equal(t, expectedOnlyA, onlyA)
assert.Equal(t, expectedOnlyB, onlyB)
assert.Equal(t, expectedDeletions, deletions)
assert.Equal(t, expectedCreations, creations)
}