diff --git a/utils/iterator.go b/utils/iterator.go index a2486df..27214b3 100644 --- a/utils/iterator.go +++ b/utils/iterator.go @@ -8,23 +8,8 @@ import ( ) type symmDiffIterator struct { - a, b iterState // Nodes returned are those in b - a and a - b (keys only) - yieldFromA bool // Whether next node comes from a - count int // Number of nodes scanned on either trie - eqPathIndex int // Count index of last pair of equal paths, to detect an updated key -} - -// NewSymmetricDifferenceIterator constructs a trie.NodeIterator that iterates over the symmetric difference -// of elements in a and b, i.e., the elements in a that are not in b, and vice versa. -// Returns the iterator, and a pointer to an integer recording the number of nodes seen. -func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (*symmDiffIterator, *int) { - it := &symmDiffIterator{ - a: iterState{a, true}, - b: iterState{b, true}, - // common paths are detected by a distance <=1 from this index, so put it out of reach - eqPathIndex: -2, - } - return it, &it.count + a, b iterState // Nodes returned are those in b - a and a - b (keys only) + SymmDiffAux } // pairs an iterator with a cache of its valid status @@ -33,11 +18,49 @@ type iterState struct { valid bool } +// SymmDiffAux exposes state specific to symmetric difference iteration, which is not accessible +// from the NodeIterator interface. This includes the number of nodes seen, whether the current key +// is common to both A and B, and whether the current node is sourced from A or B. +type SymmDiffAux struct { + yieldFromA bool // Whether next node comes from a + count int // Number of nodes scanned on either trie + eqPathIndex int // Count index of last pair of equal paths, to detect an updated key +} + +// NewSymmetricDifferenceIterator constructs a trie.NodeIterator that iterates over the symmetric difference +// of elements in a and b, i.e., the elements in a that are not in b, and vice versa. +// Returns the iterator, and a pointer to an auxiliary object for accessing the state not exposed by the NodeIterator interface recording the number of nodes seen. +func NewSymmetricDifferenceIterator(a, b trie.NodeIterator) (trie.NodeIterator, *SymmDiffAux) { + it := &symmDiffIterator{ + a: iterState{a, true}, + b: iterState{b, true}, + // common paths are detected by a distance <=1 between count and this index, so we start at -2 + SymmDiffAux: SymmDiffAux{eqPathIndex: -2}, + } + return it, &it.SymmDiffAux +} + func (st *iterState) Next(descend bool) bool { st.valid = st.NodeIterator.Next(descend) return st.valid } +// FromA returns true if the current node is sourced from A. +func (it *SymmDiffAux) FromA() bool { + return it.yieldFromA +} + +// CommonPath returns true if a node with the current path exists in each sub-iterator - i.e. it +// represents an updated node. +func (it *SymmDiffAux) CommonPath() bool { + return it.count-it.eqPathIndex <= 1 +} + +// Count returns the number of nodes seen. +func (it *SymmDiffAux) Count() int { + return it.count +} + func (it *symmDiffIterator) curr() *iterState { if it.yieldFromA { return &it.a @@ -45,17 +68,6 @@ func (it *symmDiffIterator) curr() *iterState { return &it.b } -// FromA returns true if the current node is sourced from A. -func (it *symmDiffIterator) FromA() bool { - return it.yieldFromA -} - -// CommonPath returns true if a node with the current path exists in each sub-iterator - i.e. it -// represents an updated node. -func (it *symmDiffIterator) CommonPath() bool { - return it.count-it.eqPathIndex <= 1 -} - func (it *symmDiffIterator) Hash() common.Hash { return it.curr().Hash() }