diff --git a/restricted/trie/iterator.go b/restricted/trie/iterator.go new file mode 100644 index 0000000..67080d6 --- /dev/null +++ b/restricted/trie/iterator.go @@ -0,0 +1,169 @@ +package trie + +import ( + "bytes" + + "github.com/openrelayxyz/plugeth-utils/core" +) + +type NodeIterator = core.NodeIterator + +// Iterator is a key-value trie iterator that traverses a Trie. +type Iterator struct { + nodeIt core.NodeIterator + + Key []byte // Current data key on which the iterator is positioned on + Value []byte // Current data value on which the iterator is positioned on + Err error +} + +// NewIterator creates a new key-value iterator from a node iterator. +// Note that the value returned by the iterator is raw. If the content is encoded +// (e.g. storage value is RLP-encoded), it's caller's duty to decode it. +func NewIterator(it core.NodeIterator) *Iterator { + return &Iterator{ + nodeIt: it, + } +} + +// Next moves the iterator forward one key-value entry. +func (it *Iterator) Next() bool { + for it.nodeIt.Next(true) { + if it.nodeIt.Leaf() { + it.Key = it.nodeIt.LeafKey() + it.Value = it.nodeIt.LeafBlob() + return true + } + } + it.Key = nil + it.Value = nil + it.Err = it.nodeIt.Error() + return false +} + +// Prove generates the Merkle proof for the leaf node the iterator is currently +// positioned on. +func (it *Iterator) Prove() [][]byte { + return it.nodeIt.LeafProof() +} + +type differenceIterator struct { + a, b core.NodeIterator // Nodes returned are those in b - a. + eof bool // Indicates a has run out of elements + count int // Number of nodes scanned on either trie +} + +// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that +// are not in a. Returns the iterator, and a pointer to an integer recording the number +// of nodes seen. +func NewDifferenceIterator(a, b core.NodeIterator) (core.NodeIterator, *int) { + a.Next(true) + it := &differenceIterator{ + a: a, + b: b, + } + return it, &it.count +} + +func (it *differenceIterator) Hash() core.Hash { + return it.b.Hash() +} + +func (it *differenceIterator) Parent() core.Hash { + return it.b.Parent() +} + +func (it *differenceIterator) Leaf() bool { + return it.b.Leaf() +} + +func (it *differenceIterator) LeafKey() []byte { + return it.b.LeafKey() +} + +func (it *differenceIterator) LeafBlob() []byte { + return it.b.LeafBlob() +} + +func (it *differenceIterator) LeafProof() [][]byte { + return it.b.LeafProof() +} + +func (it *differenceIterator) Path() []byte { + return it.b.Path() +} + +func (it *differenceIterator) NodeBlob() []byte { + return it.b.NodeBlob() +} + +func (it *differenceIterator) AddResolver(resolver core.NodeResolver) { + panic("not implemented") +} + +func (it *differenceIterator) Next(bool) bool { + // Invariants: + // - We always advance at least one element in b. + // - At the start of this function, a's path is lexically greater than b's. + if !it.b.Next(true) { + return false + } + it.count++ + + if it.eof { + // a has reached eof, so we just return all elements from b + return true + } + + for { + switch compareNodes(it.a, it.b) { + case -1: + // b jumped past a; advance a + if !it.a.Next(true) { + it.eof = true + return true + } + it.count++ + case 1: + // b is before a + return true + case 0: + // a and b are identical; skip this whole subtree if the nodes have hashes + hasHash := it.a.Hash() == core.Hash{} + if !it.b.Next(hasHash) { + return false + } + it.count++ + if !it.a.Next(hasHash) { + it.eof = true + return true + } + it.count++ + } + } +} + +func (it *differenceIterator) Error() error { + if err := it.a.Error(); err != nil { + return err + } + return it.b.Error() +} + +func compareNodes(a, b core.NodeIterator) int { + if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { + return cmp + } + if a.Leaf() && !b.Leaf() { + return -1 + } else if b.Leaf() && !a.Leaf() { + return 1 + } + if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 { + return cmp + } + if a.Leaf() && b.Leaf() { + return bytes.Compare(a.LeafBlob(), b.LeafBlob()) + } + return 0 +}