Implement symmetric difference iterator
This commit is contained in:
parent
6b1f676c60
commit
13f6234ae8
@ -1 +1,155 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common"
|
||||||
|
"github.com/ethereum/go-ethereum/trie"
|
||||||
|
)
|
||||||
|
|
||||||
|
type symmDiffIterator struct {
|
||||||
|
a, b iterState // Nodes returned are those in b - a and a - b (keys only)
|
||||||
|
yieldFromA bool // Whether next node comes from a
|
||||||
|
count int // Number of nodes scanned on either trie
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that
|
||||||
|
// are not in a. Returns the iterator, and a pointer to an integer recording the number
|
||||||
|
// of nodes seen.
|
||||||
|
func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) {
|
||||||
|
it := &symmDiffIterator{
|
||||||
|
a: iterState{a, true},
|
||||||
|
b: iterState{b, true},
|
||||||
|
}
|
||||||
|
return it, &it.count
|
||||||
|
}
|
||||||
|
|
||||||
|
// pairs an iterator with a cache of its valid status
|
||||||
|
type iterState struct {
|
||||||
|
trie.NodeIterator
|
||||||
|
valid bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *iterState) Next(descend bool) bool {
|
||||||
|
st.valid = st.NodeIterator.Next(descend)
|
||||||
|
return st.valid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) curr() *iterState {
|
||||||
|
if it.yieldFromA {
|
||||||
|
return &it.a
|
||||||
|
}
|
||||||
|
return &it.b
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromA returns true if the current node is sourced from A.
|
||||||
|
func (it *symmDiffIterator) FromA() bool {
|
||||||
|
return it.yieldFromA
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) Hash() common.Hash {
|
||||||
|
return it.curr().Hash()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) Parent() common.Hash {
|
||||||
|
return it.curr().Parent()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) Leaf() bool {
|
||||||
|
return it.curr().Leaf()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) LeafKey() []byte {
|
||||||
|
return it.curr().LeafKey()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) LeafBlob() []byte {
|
||||||
|
return it.curr().LeafBlob()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) LeafProof() [][]byte {
|
||||||
|
return it.curr().LeafProof()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) Path() []byte {
|
||||||
|
return it.curr().Path()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) NodeBlob() []byte {
|
||||||
|
return it.curr().NodeBlob()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) AddResolver(resolver trie.NodeResolver) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) Next(bool) bool {
|
||||||
|
// NodeIterators start in a "pre-valid" state, so you have to call Next before they point to a
|
||||||
|
// valid node. This delays advancing the sub-iterators until they are so initialized by seek.
|
||||||
|
if it.count != 0 {
|
||||||
|
if it.curr().Next(true) {
|
||||||
|
it.count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
it.seek()
|
||||||
|
return it.a.valid || it.b.valid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) seek() {
|
||||||
|
// Invariants:
|
||||||
|
// - At the end of the function, the sub-iterator with the lexically lesser path
|
||||||
|
// points to the next element
|
||||||
|
// - Said sub-iterator never points to an element present in the other
|
||||||
|
for {
|
||||||
|
if !it.b.valid {
|
||||||
|
it.yieldFromA = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !it.a.valid {
|
||||||
|
it.yieldFromA = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch compareNodes(&it.a, &it.b) {
|
||||||
|
case -1:
|
||||||
|
it.yieldFromA = true
|
||||||
|
return
|
||||||
|
case 1:
|
||||||
|
it.yieldFromA = false
|
||||||
|
return
|
||||||
|
case 0:
|
||||||
|
if it.b.Next(true) {
|
||||||
|
it.count++
|
||||||
|
}
|
||||||
|
if it.a.Next(true) {
|
||||||
|
it.count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *symmDiffIterator) Error() error {
|
||||||
|
if err := it.a.Error(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return it.b.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares nodes with equal paths by value
|
||||||
|
func compareNodes(a, b trie.NodeIterator) int {
|
||||||
|
if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
|
||||||
|
return cmp
|
||||||
|
}
|
||||||
|
if a.Leaf() && !b.Leaf() {
|
||||||
|
return -1
|
||||||
|
} else if b.Leaf() && !a.Leaf() {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
|
||||||
|
return cmp
|
||||||
|
}
|
||||||
|
if a.Leaf() && b.Leaf() {
|
||||||
|
return bytes.Compare(a.LeafBlob(), b.LeafBlob())
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
99
utils/iterator_test.go
Normal file
99
utils/iterator_test.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package utils_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cerc-io/plugeth-statediff/utils"
|
||||||
|
"github.com/ethereum/go-ethereum/core/rawdb"
|
||||||
|
"github.com/ethereum/go-ethereum/trie"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type kvs struct{ k, v string }
|
||||||
|
|
||||||
|
var (
|
||||||
|
testdata1 = []kvs{
|
||||||
|
{"bar", "b"},
|
||||||
|
{"barb", "ba"},
|
||||||
|
{"bard", "bc"},
|
||||||
|
{"bars", "bb"},
|
||||||
|
{"fab", "z"},
|
||||||
|
{"foo", "a"},
|
||||||
|
{"food", "ab"},
|
||||||
|
}
|
||||||
|
|
||||||
|
testdata2 = []kvs{
|
||||||
|
{"aardvark", "c"},
|
||||||
|
{"bar", "b"},
|
||||||
|
{"barb", "bd"},
|
||||||
|
{"bars", "be"},
|
||||||
|
{"fab", "z"},
|
||||||
|
{"foo", "a"},
|
||||||
|
{"foos", "aa"},
|
||||||
|
{"jars", "d"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSymmetricDifferenceIterator(t *testing.T) {
|
||||||
|
t.Run("empty case", func(t *testing.T) {
|
||||||
|
db := trie.NewDatabase(rawdb.NewMemoryDatabase())
|
||||||
|
tree := trie.NewEmpty(db)
|
||||||
|
di, count := utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil))
|
||||||
|
for di.Next(true) {
|
||||||
|
t.Errorf("iterator should not yield any elements")
|
||||||
|
}
|
||||||
|
if *count != 0 {
|
||||||
|
t.Errorf("node count should be 0 for empty trie, got %d", *count)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
dba := trie.NewDatabase(rawdb.NewMemoryDatabase())
|
||||||
|
triea := trie.NewEmpty(dba)
|
||||||
|
for _, val := range testdata1 {
|
||||||
|
triea.MustUpdate([]byte(val.k), []byte(val.v))
|
||||||
|
}
|
||||||
|
rootA, nodesA := triea.Commit(false)
|
||||||
|
dba.Update(trie.NewWithNodeSet(nodesA))
|
||||||
|
triea, _ = trie.New(trie.TrieID(rootA), dba)
|
||||||
|
|
||||||
|
dbb := trie.NewDatabase(rawdb.NewMemoryDatabase())
|
||||||
|
trieb := trie.NewEmpty(dbb)
|
||||||
|
for _, val := range testdata2 {
|
||||||
|
trieb.MustUpdate([]byte(val.k), []byte(val.v))
|
||||||
|
}
|
||||||
|
rootB, nodesB := trieb.Commit(false)
|
||||||
|
dbb.Update(trie.NewWithNodeSet(nodesB))
|
||||||
|
trieb, _ = trie.New(trie.TrieID(rootB), dbb)
|
||||||
|
|
||||||
|
onlyA := make(map[string]string)
|
||||||
|
onlyB := make(map[string]string)
|
||||||
|
it, _ := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
|
||||||
|
for it.Next(true) {
|
||||||
|
if !it.Leaf() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key, value := string(it.LeafKey()), string(it.LeafBlob())
|
||||||
|
if it.FromA() {
|
||||||
|
onlyA[key] = value
|
||||||
|
} else {
|
||||||
|
onlyB[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedOnlyA := map[string]string{
|
||||||
|
"barb": "ba",
|
||||||
|
"bard": "bc",
|
||||||
|
"bars": "bb",
|
||||||
|
"food": "ab",
|
||||||
|
}
|
||||||
|
expectedOnlyB := map[string]string{
|
||||||
|
"aardvark": "c",
|
||||||
|
"barb": "bd",
|
||||||
|
"bars": "be",
|
||||||
|
"foos": "aa",
|
||||||
|
"jars": "d",
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedOnlyA, onlyA)
|
||||||
|
assert.Equal(t, expectedOnlyB, onlyB)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user