Implement symmetric difference iterator
This commit is contained in:
parent
6b1f676c60
commit
b9d988e0ec
@ -1 +1,155 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/trie"
|
||||
)
|
||||
|
||||
type symDiffIterator struct {
|
||||
a, b iterState // Nodes returned are those in b - a and a - b (keys only)
|
||||
yieldFromA bool // Whether next node comes from a
|
||||
count int // Number of nodes scanned on either trie
|
||||
}
|
||||
|
||||
// NewDifferenceIterator constructs a trie.NodeIterator that iterates over the exclusive elements in b that
|
||||
// are not in a. Returns the iterator, and a pointer to an integer recording the number
|
||||
// of nodes seen.
|
||||
func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symDiffIterator, *int) {
|
||||
it := &symDiffIterator{
|
||||
a: iterState{a, true},
|
||||
b: iterState{b, true},
|
||||
}
|
||||
return it, &it.count
|
||||
}
|
||||
|
||||
// pairs an iterator with a cache of its valid status
|
||||
type iterState struct {
|
||||
trie.NodeIterator
|
||||
valid bool
|
||||
}
|
||||
|
||||
func (st *iterState) Next(descend bool) bool {
|
||||
st.valid = st.NodeIterator.Next(descend)
|
||||
return st.valid
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) curr() *iterState {
|
||||
if it.yieldFromA {
|
||||
return &it.a
|
||||
}
|
||||
return &it.b
|
||||
}
|
||||
|
||||
// FromA returns true if the current node is sourced from A.
|
||||
func (it *symDiffIterator) FromA() bool {
|
||||
return it.yieldFromA
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) Hash() common.Hash {
|
||||
return it.curr().Hash()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) Parent() common.Hash {
|
||||
return it.curr().Parent()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) Leaf() bool {
|
||||
return it.curr().Leaf()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) LeafKey() []byte {
|
||||
return it.curr().LeafKey()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) LeafBlob() []byte {
|
||||
return it.curr().LeafBlob()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) LeafProof() [][]byte {
|
||||
return it.curr().LeafProof()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) Path() []byte {
|
||||
return it.curr().Path()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) NodeBlob() []byte {
|
||||
return it.curr().NodeBlob()
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) AddResolver(resolver trie.NodeResolver) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) Next(bool) bool {
|
||||
// NodeIterators start in a "pre-valid" state, so you have to call Next before they point to a
|
||||
// valid node. This delays advancing the sub-iterators until they are so initialized by seek.
|
||||
if it.count != 0 {
|
||||
if it.curr().Next(true) {
|
||||
it.count++
|
||||
}
|
||||
}
|
||||
it.seek()
|
||||
return it.a.valid || it.b.valid
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) seek() {
|
||||
// Invariants:
|
||||
// - At the end of the function, the sub-iterator with the lexically lesser path
|
||||
// points to the next element
|
||||
// - Said sub-iterator never points to an element present in the other
|
||||
for {
|
||||
if !it.b.valid {
|
||||
it.yieldFromA = true
|
||||
return
|
||||
}
|
||||
if !it.a.valid {
|
||||
it.yieldFromA = false
|
||||
return
|
||||
}
|
||||
|
||||
switch compareNodes(&it.a, &it.b) {
|
||||
case -1:
|
||||
it.yieldFromA = true
|
||||
return
|
||||
case 1:
|
||||
it.yieldFromA = false
|
||||
return
|
||||
case 0:
|
||||
if it.b.Next(true) {
|
||||
it.count++
|
||||
}
|
||||
if it.a.Next(true) {
|
||||
it.count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (it *symDiffIterator) Error() error {
|
||||
if err := it.a.Error(); err != nil {
|
||||
return err
|
||||
}
|
||||
return it.b.Error()
|
||||
}
|
||||
|
||||
// Compares nodes with equal paths by value
|
||||
func compareNodes(a, b trie.NodeIterator) int {
|
||||
if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
if a.Leaf() && !b.Leaf() {
|
||||
return -1
|
||||
} else if b.Leaf() && !a.Leaf() {
|
||||
return 1
|
||||
}
|
||||
if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
if a.Leaf() && b.Leaf() {
|
||||
return bytes.Compare(a.LeafBlob(), b.LeafBlob())
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
99
utils/iterator_test.go
Normal file
99
utils/iterator_test.go
Normal file
@ -0,0 +1,99 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cerc-io/plugeth-statediff/utils"
|
||||
"github.com/ethereum/go-ethereum/core/rawdb"
|
||||
"github.com/ethereum/go-ethereum/trie"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// kvs is a key/value pair used to populate the tries under test.
type kvs struct {
	k string // key
	v string // value
}
|
||||
|
||||
var (
|
||||
testdata1 = []kvs{
|
||||
{"bar", "b"},
|
||||
{"barb", "ba"},
|
||||
{"bard", "bc"},
|
||||
{"bars", "bb"},
|
||||
{"fab", "z"},
|
||||
{"foo", "a"},
|
||||
{"food", "ab"},
|
||||
}
|
||||
|
||||
testdata2 = []kvs{
|
||||
{"aardvark", "c"},
|
||||
{"bar", "b"},
|
||||
{"barb", "bd"},
|
||||
{"bars", "be"},
|
||||
{"fab", "z"},
|
||||
{"foo", "a"},
|
||||
{"foos", "aa"},
|
||||
{"jars", "d"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestSymmetricDifferenceIterator(t *testing.T) {
|
||||
t.Run("empty case", func(t *testing.T) {
|
||||
db := trie.NewDatabase(rawdb.NewMemoryDatabase())
|
||||
tree := trie.NewEmpty(db)
|
||||
di, count := utils.NewSymmetricDifferenceIterator(tree.NodeIterator(nil), tree.NodeIterator(nil))
|
||||
for di.Next(true) {
|
||||
t.Errorf("iterator should not yield any elements")
|
||||
}
|
||||
if *count != 0 {
|
||||
t.Errorf("node count should be 0 for empty trie, got %d", *count)
|
||||
}
|
||||
})
|
||||
|
||||
dba := trie.NewDatabase(rawdb.NewMemoryDatabase())
|
||||
triea := trie.NewEmpty(dba)
|
||||
for _, val := range testdata1 {
|
||||
triea.MustUpdate([]byte(val.k), []byte(val.v))
|
||||
}
|
||||
rootA, nodesA := triea.Commit(false)
|
||||
dba.Update(trie.NewWithNodeSet(nodesA))
|
||||
triea, _ = trie.New(trie.TrieID(rootA), dba)
|
||||
|
||||
dbb := trie.NewDatabase(rawdb.NewMemoryDatabase())
|
||||
trieb := trie.NewEmpty(dbb)
|
||||
for _, val := range testdata2 {
|
||||
trieb.MustUpdate([]byte(val.k), []byte(val.v))
|
||||
}
|
||||
rootB, nodesB := trieb.Commit(false)
|
||||
dbb.Update(trie.NewWithNodeSet(nodesB))
|
||||
trieb, _ = trie.New(trie.TrieID(rootB), dbb)
|
||||
|
||||
onlyA := make(map[string]string)
|
||||
onlyB := make(map[string]string)
|
||||
it, _ := utils.NewSymmetricDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
|
||||
for it.Next(true) {
|
||||
if !it.Leaf() {
|
||||
continue
|
||||
}
|
||||
key, value := string(it.LeafKey()), string(it.LeafBlob())
|
||||
if it.FromA() {
|
||||
onlyA[key] = value
|
||||
} else {
|
||||
onlyB[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
expectedOnlyA := map[string]string{
|
||||
"barb": "ba",
|
||||
"bard": "bc",
|
||||
"bars": "bb",
|
||||
"food": "ab",
|
||||
}
|
||||
expectedOnlyB := map[string]string{
|
||||
"aardvark": "c",
|
||||
"barb": "bd",
|
||||
"bars": "be",
|
||||
"foos": "aa",
|
||||
"jars": "d",
|
||||
}
|
||||
assert.Equal(t, expectedOnlyA, onlyA)
|
||||
assert.Equal(t, expectedOnlyB, onlyB)
|
||||
}
|
Loading…
Reference in New Issue
Block a user