Implement symmetric difference iterator

This commit is contained in:
Roy Crihfield 2023-07-22 15:09:26 +08:00
parent 6b1f676c60
commit 13f6234ae8
2 changed files with 253 additions and 0 deletions

View File

@ -1 +1,155 @@
package utils package utils
import (
"bytes"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/trie"
)
type symmDiffIterator struct {
a, b iterState // Nodes returned are those in b - a and a - b (keys only)
yieldFromA bool // Whether next node comes from a
count int // Number of nodes scanned on either trie
}
// NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that
// are not in a. Returns the iterator, and a pointer to an integer recording the number
// of nodes seen.
func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) {
it := &symmDiffIterator{
a: iterState{a, true},
b: iterState{b, true},
}
return it, &it.count
}
// pairs an iterator with a cache of its valid status
type iterState struct {
trie.NodeIterator
valid bool
}
func (st *iterState) Next(descend bool) bool {
st.valid = st.NodeIterator.Next(descend)
return st.valid
}
func (it *symmDiffIterator) curr() *iterState {
if it.yieldFromA {
return &it.a
}
return &it.b
}
// FromA returns true if the current node is sourced from A.
func (it *symmDiffIterator) FromA() bool {
return it.yieldFromA
}
func (it *symmDiffIterator) Hash() common.Hash {
return it.curr().Hash()
}
func (it *symmDiffIterator) Parent() common.Hash {
return it.curr().Parent()
}
func (it *symmDiffIterator) Leaf() bool {
return it.curr().Leaf()
}
func (it *symmDiffIterator) LeafKey() []byte {
return it.curr().LeafKey()
}
func (it *symmDiffIterator) LeafBlob() []byte {
return it.curr().LeafBlob()
}
func (it *symmDiffIterator) LeafProof() [][]byte {
return it.curr().LeafProof()
}
func (it *symmDiffIterator) Path() []byte {
return it.curr().Path()
}
func (it *symmDiffIterator) NodeBlob() []byte {
return it.curr().NodeBlob()
}
func (it *symmDiffIterator) AddResolver(resolver trie.NodeResolver) {
panic("not implemented")
}
func (it *symmDiffIterator) Next(bool) bool {
// NodeIterators start in a "pre-valid" state, so you have to call Next before they point to a
// valid node. This delays advancing the sub-iterators until they are so initialized by seek.
if it.count != 0 {
if it.curr().Next(true) {
it.count++
}
}
it.seek()
return it.a.valid || it.b.valid
}
func (it *symmDiffIterator) seek() {
// Invariants:
// - At the end of the function, the sub-iterator with the lexically lesser path
// points to the next element
// - Said sub-iterator never points to an element present in the other
for {
if !it.b.valid {
it.yieldFromA = true
return
}
if !it.a.valid {
it.yieldFromA = false
return
}
switch compareNodes(&it.a, &it.b) {
case -1:
it.yieldFromA = true
return
case 1:
it.yieldFromA = false
return
case 0:
if it.b.Next(true) {
it.count++
}
if it.a.Next(true) {
it.count++
}
}
}
}
func (it *symmDiffIterator) Error() error {
if err := it.a.Error(); err != nil {
return err
}
return it.b.Error()
}
// Compares nodes with equal paths by value
func compareNodes(a, b trie.NodeIterator) int {
if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
return cmp
}
if a.Leaf() && !b.Leaf() {
return -1
} else if b.Leaf() && !a.Leaf() {
return 1
}
if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
return cmp
}
if a.Leaf() && b.Leaf() {
return bytes.Compare(a.LeafBlob(), b.LeafBlob())
}
return 0
}

99
utils/iterator_test.go Normal file
View File

@ -0,0 +1,99 @@
package utils_test
import (
"testing"
"github.com/cerc-io/plugeth-statediff/utils"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/trie"
"github.com/stretchr/testify/assert"
)
type kvs struct{ k, v string }
var (
testdata1 = []kvs{
{"bar", "b"},
{"barb", "ba"},
{"bard", "bc"},
{"bars", "bb"},
{"fab", "z"},
{"foo", "a"},
{"food", "ab"},
}
testdata2 = []kvs{
{"aardvark", "c"},
{"bar", "b"},
{"barb", "bd"},
{"bars", "be"},
{"fab", "z"},
{"foo", "a"},
{"foos", "aa"},
{"jars", "d"},
}
)
func TestSymmetricDifferenceIterator(t *testing.T) {
t.Run("empty case", func(t *testing.T) {
db := trie.NewDatabase(rawdb.NewMemoryDatabase())
tree := trie.NewEmpty(db)
di, count := utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil))
for di.Next(true) {
t.Errorf("iterator should not yield any elements")
}
if *count != 0 {
t.Errorf("node count should be 0 for empty trie, got %d", *count)
}
})
dba := trie.NewDatabase(rawdb.NewMemoryDatabase())
triea := trie.NewEmpty(dba)
for _, val := range testdata1 {
triea.MustUpdate([]byte(val.k), []byte(val.v))
}
rootA, nodesA := triea.Commit(false)
dba.Update(trie.NewWithNodeSet(nodesA))
triea, _ = trie.New(trie.TrieID(rootA), dba)
dbb := trie.NewDatabase(rawdb.NewMemoryDatabase())
trieb := trie.NewEmpty(dbb)
for _, val := range testdata2 {
trieb.MustUpdate([]byte(val.k), []byte(val.v))
}
rootB, nodesB := trieb.Commit(false)
dbb.Update(trie.NewWithNodeSet(nodesB))
trieb, _ = trie.New(trie.TrieID(rootB), dbb)
onlyA := make(map[string]string)
onlyB := make(map[string]string)
it, _ := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
for it.Next(true) {
if !it.Leaf() {
continue
}
key, value := string(it.LeafKey()), string(it.LeafBlob())
if it.FromA() {
onlyA[key] = value
} else {
onlyB[key] = value
}
}
expectedOnlyA := map[string]string{
"barb": "ba",
"bard": "bc",
"bars": "bb",
"food": "ab",
}
expectedOnlyB := map[string]string{
"aardvark": "c",
"barb": "bd",
"bars": "be",
"foos": "aa",
"jars": "d",
}
assert.Equal(t, expectedOnlyA, onlyA)
assert.Equal(t, expectedOnlyB, onlyB)
}