diff --git a/utils/iterator.go b/utils/iterator.go index d4b585b..1284fd6 100644 --- a/utils/iterator.go +++ b/utils/iterator.go @@ -1 +1,155 @@ package utils + +import ( + "bytes" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/trie" +) + +type symmDiffIterator struct { + a, b iterState // Nodes returned are those in b - a and a - b (keys only) + yieldFromA bool // Whether next node comes from a + count int // Number of nodes scanned on either trie +} + +// NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that +// are not in a. Returns the iterator, and a pointer to an integer recording the number +// of nodes seen. +func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) { + it := &symmDiffIterator{ + a: iterState{a, true}, + b: iterState{b, true}, + } + return it, &it.count +} + +// pairs an iterator with a cache of its valid status +type iterState struct { + trie.NodeIterator + valid bool +} + +func (st *iterState) Next(descend bool) bool { + st.valid = st.NodeIterator.Next(descend) + return st.valid +} + +func (it *symmDiffIterator) curr() *iterState { + if it.yieldFromA { + return &it.a + } + return &it.b +} + +// FromA returns true if the current node is sourced from A. +func (it *symmDiffIterator) FromA() bool { + return it.yieldFromA +} + +func (it *symmDiffIterator) Hash() common.Hash { + return it.curr().Hash() +} + +func (it *symmDiffIterator) Parent() common.Hash { + return it.curr().Parent() +} + +func (it *symmDiffIterator) Leaf() bool { + return it.curr().Leaf() +} + +func (it *symmDiffIterator) LeafKey() []byte { + return it.curr().LeafKey() +} + +func (it *symmDiffIterator) LeafBlob() []byte { + return it.curr().LeafBlob() +} + +func (it *symmDiffIterator) LeafProof() [][]byte { + return it.curr().LeafProof() +} + +func (it *symmDiffIterator) Path() []byte { + return it.curr().Path() +} + +func (it *symmDiffIterator) NodeBlob() []byte { + return it.curr().NodeBlob() +} + +func (it *symmDiffIterator) AddResolver(resolver trie.NodeResolver) { + panic("not implemented") +} + +func (it *symmDiffIterator) Next(bool) bool { + // NodeIterators start in a "pre-valid" state, so you have to call Next before they point to a + // valid node. This delays advancing the sub-iterators until they are so initialized by seek. + if it.count != 0 { + if it.curr().Next(true) { + it.count++ + } + } + it.seek() + return it.a.valid || it.b.valid +} + +func (it *symmDiffIterator) seek() { + // Invariants: + // - At the end of the function, the sub-iterator with the lexically lesser path + // points to the next element + // - Said sub-iterator never points to an element present in the other + for { + if !it.b.valid { + it.yieldFromA = true + return + } + if !it.a.valid { + it.yieldFromA = false + return + } + + switch compareNodes(&it.a, &it.b) { + case -1: + it.yieldFromA = true + return + case 1: + it.yieldFromA = false + return + case 0: + if it.a.Next(true) { + it.count++ + } + if it.b.Next(true) { + it.count++ + } + } + } +} + +func (it *symmDiffIterator) Error() error { + if err := it.a.Error(); err != nil { + return err + } + return it.b.Error() +} + +// Compares nodes with equal paths by value +func compareNodes(a, b trie.NodeIterator) int { + if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { + return cmp + } + if a.Leaf() && !b.Leaf() { + return -1 + } else if b.Leaf() && !a.Leaf() { + return 1 + } + if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 { + return cmp + } + if a.Leaf() && b.Leaf() { + return bytes.Compare(a.LeafBlob(), b.LeafBlob()) + } + return 0 +} diff --git a/utils/iterator_test.go b/utils/iterator_test.go new file mode 100644 index 0000000..f22a0d1 --- /dev/null +++ b/utils/iterator_test.go @@ -0,0 +1,97 @@ +package utils_test + +import ( + "testing" + + "github.com/cerc-io/plugeth-statediff/utils" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/trie" + + "github.com/stretchr/testify/assert" +) + +type kvs struct{ k, v string } + +var ( + testdata1 = []kvs{ + {"bar", "b"}, + {"barb", "ba"}, + {"bard", "bc"}, + {"bars", "bb"}, + {"fab", "z"}, + {"foo", "a"}, + {"food", "ab"}, + } + + testdata2 = []kvs{ + {"aardvark", "c"}, + {"bar", "b"}, + {"barb", "bd"}, + {"bars", "be"}, + {"fab", "z"}, + {"foo", "a"}, + {"foos", "aa"}, + {"jars", "d"}, + } +) + +func TestSymmetricDifferenceIterator(t *testing.T) { + t.Run("with no difference", func(t *testing.T) { + db := trie.NewDatabase(rawdb.NewMemoryDatabase()) + tree := trie.NewEmpty(db) + di, count := utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil)) + for di.Next(true) { + t.Errorf("iterator should not yield any elements") + } + assert.Equal(t, 0, *count) + + tree.MustUpdate([]byte("foo"), []byte("bar")) + di, count = utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil)) + for di.Next(true) { + t.Errorf("iterator should not yield any elements") + } + assert.NotEqual(t, 0, count) + }) + + dba := trie.NewDatabase(rawdb.NewMemoryDatabase()) + triea := trie.NewEmpty(dba) + for _, val := range testdata1 { + triea.MustUpdate([]byte(val.k), []byte(val.v)) + } + dbb := trie.NewDatabase(rawdb.NewMemoryDatabase()) + trieb := trie.NewEmpty(dbb) + for _, val := range testdata2 { + trieb.MustUpdate([]byte(val.k), []byte(val.v)) + } + + onlyA := make(map[string]string) + onlyB := make(map[string]string) + it, _ := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + for it.Next(true) { + if !it.Leaf() { + continue + } + key, value := string(it.LeafKey()), string(it.LeafBlob()) + if it.FromA() { + onlyA[key] = value + } else { + onlyB[key] = value + } + } + + expectedOnlyA := map[string]string{ + "barb": "ba", + "bard": "bc", + "bars": "bb", + "food": "ab", + } + expectedOnlyB := map[string]string{ + "aardvark": "c", + "barb": "bd", + "bars": "be", + "foos": "aa", + "jars": "d", + } + assert.Equal(t, expectedOnlyA, onlyA) + assert.Equal(t, expectedOnlyB, onlyB) +}