package utils import ( "bytes" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/trie" ) type symmDiffIterator struct { a, b iterState // Nodes returned are those in b - a and a - b (keys only) yieldFromA bool // Whether next node comes from a count int // Number of nodes scanned on either trie eqPathIndex int // Count index of last pair of equal paths, to detect an updated key } // NewSymmetricDifferenceIterator constructs a trie.NodeIterator that iterates over the symmetric difference // of elements in a and b, i.e., the elements in a that are not in b, and vice versa. // Returns the iterator, and a pointer to an integer recording the number of nodes seen. func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) { it := &symmDiffIterator{ a: iterState{a, true}, b: iterState{b, true}, // common paths are detected by a distance <=1 from this index, so put it out of reach eqPathIndex: -2, } return it, &it.count } // pairs an iterator with a cache of its valid status type iterState struct { trie.NodeIterator valid bool } func (st *iterState) Next(descend bool) bool { st.valid = st.NodeIterator.Next(descend) return st.valid } func (it *symmDiffIterator) curr() *iterState { if it.yieldFromA { return &it.a } return &it.b } // FromA returns true if the current node is sourced from A. func (it *symmDiffIterator) FromA() bool { return it.yieldFromA } // CommonPath returns true if a node with the current path exists in each sub-iterator - i.e. it // represents an updated node. func (it *symmDiffIterator) CommonPath() bool { return it.count-it.eqPathIndex <= 1 } func (it *symmDiffIterator) Hash() common.Hash { return it.curr().Hash() } func (it *symmDiffIterator) Parent() common.Hash { return it.curr().Parent() } func (it *symmDiffIterator) Leaf() bool { return it.curr().Leaf() } func (it *symmDiffIterator) LeafKey() []byte { return it.curr().LeafKey() } func (it *symmDiffIterator) LeafBlob() []byte { return it.curr().LeafBlob() } func (it *symmDiffIterator) LeafProof() [][]byte { return it.curr().LeafProof() } func (it *symmDiffIterator) Path() []byte { return it.curr().Path() } func (it *symmDiffIterator) NodeBlob() []byte { return it.curr().NodeBlob() } func (it *symmDiffIterator) AddResolver(resolver trie.NodeResolver) { panic("not implemented") } func (it *symmDiffIterator) Next(bool) bool { // NodeIterators start in a "pre-valid" state, so the first Next advances to a valid node. if it.count == 0 { if it.a.Next(true) { it.count++ } if it.b.Next(true) { it.count++ } } else { if it.curr().Next(true) { it.count++ } } it.seek() return it.a.valid || it.b.valid } func (it *symmDiffIterator) seek() { // Invariants: // - At the end of the function, the sub-iterator with the lexically lesser path // points to the next element // - Said sub-iterator never points to an element present in the other for { if !it.b.valid { it.yieldFromA = true return } if !it.a.valid { it.yieldFromA = false return } cmp := bytes.Compare(it.a.Path(), it.b.Path()) if cmp == 0 { it.eqPathIndex = it.count cmp = compareNodes(&it.a, &it.b) } switch cmp { case -1: it.yieldFromA = true return case 1: it.yieldFromA = false return case 0: if it.a.Next(true) { it.count++ } if it.b.Next(true) { it.count++ } } } } func (it *symmDiffIterator) Error() error { if err := it.a.Error(); err != nil { return err } return it.b.Error() } func compareNodes(a, b trie.NodeIterator) int { // if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { // return cmp // } if a.Leaf() && !b.Leaf() { return -1 } else if b.Leaf() && !a.Leaf() { return 1 } if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 { return cmp } if a.Leaf() && b.Leaf() { return bytes.Compare(a.LeafBlob(), b.LeafBlob()) } return 0 }