plugeth/trie/iterator.go
Martin Holst Swende 4b783c0064
trie: improve the node iterator seek operation (#22470)
This change improves the efficiency of the nodeIterator seek
operation. Previously, seek essentially ran the iterator forward
until it found the matching node. With this change, it skips
over fullnode children and avoids resolving them from the database.
2021-04-21 12:25:26 +02:00

675 lines
19 KiB
Go

// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"bytes"
"container/heap"
"errors"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
)
// Iterator is a key-value trie iterator that traverses a Trie. It wraps a
// NodeIterator and surfaces only the leaf entries that iterator visits.
type Iterator struct {
// nodeIt is the underlying node iterator that is advanced by Next.
nodeIt NodeIterator
Key []byte // Data key of the leaf the iterator is currently positioned on
Value []byte // Data value of the leaf the iterator is currently positioned on
// Err is filled from the node iterator's Error() when Next returns false.
Err error
}
// NewIterator wraps a node iterator into a key-value iterator.
//
// The values handed out by the iterator are raw: if the content is encoded
// (e.g. a storage value that is RLP-encoded), decoding it is the caller's job.
func NewIterator(it NodeIterator) *Iterator {
	iter := new(Iterator)
	iter.nodeIt = it
	return iter
}
// Next advances the iterator to the next key-value entry. It returns true when
// a leaf was found, in which case Key and Value describe it. When the
// underlying node iterator is exhausted (or fails), Key and Value are cleared,
// Err is set to the node iterator's error and false is returned.
func (it *Iterator) Next() bool {
	for {
		if !it.nodeIt.Next(true) {
			break
		}
		// Interior nodes carry no key-value data, keep walking.
		if !it.nodeIt.Leaf() {
			continue
		}
		it.Key, it.Value = it.nodeIt.LeafKey(), it.nodeIt.LeafBlob()
		return true
	}
	it.Key, it.Value = nil, nil
	it.Err = it.nodeIt.Error()
	return false
}
// Prove generates the Merkle proof for the leaf node the iterator is currently
// positioned on. It simply delegates to the LeafProof method of the underlying
// node iterator, which is documented to panic when the iterator is not
// positioned at a leaf.
func (it *Iterator) Prove() [][]byte {
return it.nodeIt.LeafProof()
}
// NodeIterator is an iterator to traverse the trie pre-order.
type NodeIterator interface {
// Next moves the iterator to the next node. If the parameter is false, any child
// nodes of the current node will be skipped. It returns false when there is no
// further node to visit.
Next(bool) bool
// Error returns the error status of the iterator.
Error() error
// Hash returns the hash of the current node.
Hash() common.Hash
// Parent returns the hash of the parent of the current node. The hash may be that
// of a grandparent (or an earlier ancestor) if the immediate parent is an internal
// node with no hash of its own.
Parent() common.Hash
// Path returns the hex-encoded path to the current node.
// Callers must not retain references to the return value after calling Next.
// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
Path() []byte
// Leaf returns true iff the current node is a leaf node.
Leaf() bool
// LeafKey returns the key of the leaf. The method panics if the iterator is not
// positioned at a leaf. Callers must not retain references to the value after
// calling Next.
LeafKey() []byte
// LeafBlob returns the content of the leaf. The method panics if the iterator
// is not positioned at a leaf. Callers must not retain references to the value
// after calling Next.
LeafBlob() []byte
// LeafProof returns the Merkle proof of the leaf. The method panics if the
// iterator is not positioned at a leaf. Callers must not retain references
// to the value after calling Next.
LeafProof() [][]byte
}
// nodeIteratorState represents the iteration state at one particular node of the
// trie, which can be resumed at a later invocation.
type nodeIteratorState struct {
hash common.Hash // Hash of the node being iterated (empty if not standalone)
node node // Trie node being iterated
parent common.Hash // Hash of the first full ancestor node (empty if current is the root)
index int // Index of the most recently pushed child (-1 before the first child); the search for the next child starts at index+1
pathlen int // Length of the iterator path at the time this node was created, i.e. the parent's path; pop truncates the path back to it
}
// nodeIterator walks the nodes of a single Trie. It is the value returned by
// newNodeIterator as a NodeIterator and keeps the chain of nodes leading to the
// current position on an explicit stack.
type nodeIterator struct {
trie *Trie // Trie being iterated
stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state
path []byte // Path to the current node
err error // Iterator status: nil, errIteratorEnd once finished, a seekError after a failed seek, or another internal failure
}
// errIteratorEnd is the sentinel kept in nodeIterator.err once the iteration
// has run to completion.
var errIteratorEnd = errors.New("end of iteration")

// seekError is kept in nodeIterator.err when the initial seek has failed. It
// records the requested key next to the underlying failure, which lets Next
// retry the seek with the same key.
type seekError struct {
	key []byte
	err error
}

// Error implements the error interface, prefixing the message of the wrapped
// failure with "seek error: ".
func (e seekError) Error() string {
	cause := e.err.Error()
	return "seek error: " + cause
}
// newNodeIterator creates a node iterator over trie and positions it by seeking
// to start. A failure of the initial seek is not returned here; it is stored in
// the iterator, so that it surfaces through Error and the seek is retried by
// the next call to Next.
//
// If the trie is recognised as empty, the returned iterator is already marked
// as finished: Next reports false right away and Error reports nil.
//
// NOTE(review): the emptiness check compares against emptyState while init
// compares the root hash against emptyRoot — confirm which constant is the
// intended one here.
func newNodeIterator(trie *Trie, start []byte) NodeIterator {
	if trie.Hash() == emptyState {
		// Do not hand out a zero-value iterator: its trie pointer is nil and its
		// error is unset, so the first Next call would reach init and dereference
		// the nil trie. Mark the iterator as exhausted instead.
		return &nodeIterator{trie: trie, err: errIteratorEnd}
	}
	it := &nodeIterator{trie: trie}
	it.err = it.seek(start)
	return it
}
// Hash returns the hash recorded for the node on top of the stack, or the empty
// hash when the iterator is not positioned on any node.
func (it *nodeIterator) Hash() common.Hash {
	if depth := len(it.stack); depth > 0 {
		return it.stack[depth-1].hash
	}
	return common.Hash{}
}
// Parent returns the ancestor hash recorded for the node on top of the stack,
// or the empty hash when the iterator is not positioned on any node.
func (it *nodeIterator) Parent() common.Hash {
	if depth := len(it.stack); depth > 0 {
		return it.stack[depth-1].parent
	}
	return common.Hash{}
}
// Leaf reports whether the iterator is positioned at a leaf, which is decided
// by calling hasTerm on the current path (presumably a check for the trailing
// terminator symbol — confirm against hasTerm).
func (it *nodeIterator) Leaf() bool {
return hasTerm(it.path)
}
// LeafKey returns the key of the current leaf, obtained by converting the
// current hex path with hexToKeybytes. It panics with "not at leaf" when the
// stack is empty or the node on top of it is not a value node.
func (it *nodeIterator) LeafKey() []byte {
	if len(it.stack) == 0 {
		panic("not at leaf")
	}
	top := it.stack[len(it.stack)-1]
	if _, isValue := top.node.(valueNode); !isValue {
		panic("not at leaf")
	}
	return hexToKeybytes(it.path)
}
// LeafBlob returns the content of the current leaf, i.e. the value node on top
// of the stack. It panics with "not at leaf" when the stack is empty or the
// node on top of it is not a value node.
func (it *nodeIterator) LeafBlob() []byte {
	if len(it.stack) == 0 {
		panic("not at leaf")
	}
	top := it.stack[len(it.stack)-1]
	blob, isValue := top.node.(valueNode)
	if !isValue {
		panic("not at leaf")
	}
	return blob
}
// LeafProof returns the Merkle proof of the current leaf: the RLP encodings of
// the nodes on the stack above the leaf, restricted to the root and to those
// nodes that proofHash reports as hash nodes. It panics with "not at leaf" when
// the stack is empty or the node on top of it is not a value node.
func (it *nodeIterator) LeafProof() [][]byte {
	depth := len(it.stack)
	if depth == 0 {
		panic("not at leaf")
	}
	if _, isValue := it.stack[depth-1].node.(valueNode); !isValue {
		panic("not at leaf")
	}
	hasher := newHasher(false)
	defer returnHasherToPool(hasher)

	proofs := make([][]byte, 0, depth)
	// Walk every ancestor of the leaf; the leaf itself (last entry) is left out.
	for level := 0; level < depth-1; level++ {
		collapsed, hashed := hasher.proofHash(it.stack[level].node)
		// Only nodes that end up as hash nodes (or the root) belong to the proof.
		_, isHash := hashed.(hashNode)
		if !isHash && level != 0 {
			continue
		}
		enc, _ := rlp.EncodeToBytes(collapsed)
		proofs = append(proofs, enc)
	}
	return proofs
}
// Path returns the path to the current node. The iterator's internal slice is
// returned without copying, so callers must not hold on to it across calls
// that advance the iterator.
func (it *nodeIterator) Path() []byte {
return it.path
}
// Error returns the failure the iterator ran into, if any. Reaching the end of
// the iteration is not reported as an error, and a failed seek is reported as
// the underlying cause rather than as the seekError wrapper.
func (it *nodeIterator) Error() error {
	if it.err == errIteratorEnd {
		return nil
	}
	switch failure := it.err.(type) {
	case seekError:
		return failure.err
	default:
		return it.err
	}
}
// Next advances the iterator to the next node and reports whether there is one.
// On an internal failure it returns false and the failure becomes available
// through Error. If descend is false, the subnodes of the current node are
// skipped. A seek that failed earlier is retried first, using the key that was
// recorded with the failure.
func (it *nodeIterator) Next(descend bool) bool {
	if it.err == errIteratorEnd {
		return false
	}
	if pending, failed := it.err.(seekError); failed {
		it.err = it.seek(pending.key)
		if it.err != nil {
			return false
		}
	}
	// Step forward and record whatever the step reported.
	state, parentIndex, path, err := it.peek(descend)
	if it.err = err; err != nil {
		return false
	}
	it.push(state, parentIndex, path)
	return true
}
// seek moves the iterator forward until the next node to be visited has a path
// that is not lower than prefix. It returns nil on success, errIteratorEnd when
// the trie was exhausted first, and a seekError carrying prefix for any other
// failure.
func (it *nodeIterator) seek(prefix []byte) error {
	// The path we're looking for is the hex encoded key without terminator.
	key := keybytesToHex(prefix)
	key = key[:len(key)-1]

	// Push nodes for as long as the upcoming one is still before the target.
	for {
		state, parentIndex, path, err := it.peekSeek(key)
		switch {
		case err == errIteratorEnd:
			return errIteratorEnd
		case err != nil:
			return seekError{prefix, err}
		case bytes.Compare(path, key) >= 0:
			return nil
		}
		it.push(state, parentIndex, path)
	}
}
// init builds the iteration state for the root node of the trie. The state's
// hash is only filled in when the trie's hash differs from emptyRoot. The
// returned error is the outcome of resolving the root node.
func (it *nodeIterator) init() (*nodeIteratorState, error) {
	// NOTE(review): the hash is deliberately requested before trie.root is read,
	// exactly as in the original statement order — confirm whether Hash may
	// update the root it is called on.
	rootHash := it.trie.Hash()

	state := new(nodeIteratorState)
	state.node = it.trie.root
	state.index = -1
	if rootHash != emptyRoot {
		state.hash = rootHash
	}
	err := state.resolve(it.trie, nil)
	return state, err
}
// peek computes the next state of the iterator without pushing it. It returns
// the state to push, a pointer to the parent's child index (nil for the root),
// the path of the new state and an error. When the stack runs empty,
// errIteratorEnd is returned. When resolving the child fails, the parent state
// is returned together with the child's path and the failure.
func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) {
	// A fresh iterator starts out at the root.
	if len(it.stack) == 0 {
		state, err := it.init()
		return state, nil, nil, err
	}
	// Skipping children means leaving the current node before looking further.
	if !descend {
		it.pop()
	}
	for len(it.stack) > 0 {
		top := it.stack[len(it.stack)-1]
		// Nodes without a hash of their own pass on the hash of their ancestor.
		ancestor := top.hash
		if ancestor == (common.Hash{}) {
			ancestor = top.parent
		}
		child, childPath, found := it.nextChild(top, ancestor)
		if !found {
			// This node has no children left, go back up.
			it.pop()
			continue
		}
		if err := child.resolve(it.trie, childPath); err != nil {
			return top, &top.index, childPath, err
		}
		return child, &top.index, childPath, nil
	}
	return nil, nil, nil, errIteratorEnd
}
// peekSeek is the seeking counterpart of peek: children are picked with
// nextChildAt, which lets it skip siblings that do not lead towards seekKey and
// so avoids resolving their hashes. The return values have the same meaning as
// those of peek.
func (it *nodeIterator) peekSeek(seekKey []byte) (*nodeIteratorState, *int, []byte, error) {
	// A fresh iterator starts out at the root.
	if len(it.stack) == 0 {
		state, err := it.init()
		return state, nil, nil, err
	}
	// When the current path is not a prefix of the seek key, leave the current
	// node so that its children are not descended into.
	if !bytes.HasPrefix(seekKey, it.path) {
		it.pop()
	}
	for len(it.stack) > 0 {
		top := it.stack[len(it.stack)-1]
		// Nodes without a hash of their own pass on the hash of their ancestor.
		ancestor := top.hash
		if ancestor == (common.Hash{}) {
			ancestor = top.parent
		}
		child, childPath, found := it.nextChildAt(top, ancestor, seekKey)
		if !found {
			// This node has no children left, go back up.
			it.pop()
			continue
		}
		if err := child.resolve(it.trie, childPath); err != nil {
			return top, &top.index, childPath, err
		}
		return child, &top.index, childPath, nil
	}
	return nil, nil, nil, errIteratorEnd
}
// resolve replaces a hash node held by the state with the node it refers to,
// loaded through tr.resolveHash, and records that hash as the state's hash.
// States holding any other kind of node are left untouched. The returned error
// is the one reported by resolveHash.
func (st *nodeIteratorState) resolve(tr *Trie, path []byte) error {
	hash, isHash := st.node.(hashNode)
	if !isHash {
		return nil
	}
	resolved, err := tr.resolveHash(hash, path)
	if err != nil {
		return err
	}
	st.node, st.hash = resolved, common.BytesToHash(hash)
	return nil
}
// findChild looks for the first non-nil child of n at position index or later.
// On success it returns the child, a fresh iteration state for it, the child's
// path (a copy of path extended by the child's position) and that position.
// When no child is left it returns nil, nil, nil and 0.
func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) {
	for pos := index; pos < len(n.Children); pos++ {
		child := n.Children[pos]
		if child == nil {
			continue
		}
		hash, _ := child.cache()
		state := &nodeIteratorState{
			hash:    common.BytesToHash(hash),
			node:    child,
			parent:  ancestor,
			index:   -1,
			pathlen: len(path),
		}
		// The child path lives in its own backing array, path is not modified.
		childPath := append([]byte(nil), path...)
		childPath = append(childPath, byte(pos))
		return child, state, childPath, pos
	}
	return nil, nil, nil, 0
}
// nextChild picks the next child of parent to visit. On success it returns the
// child's state, the child's path and true. When parent has no further child
// (or is neither a full node nor a short node), it returns parent, the current
// path and false.
func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Hash) (*nodeIteratorState, []byte, bool) {
	switch n := parent.node.(type) {
	case *fullNode:
		// Full node: continue with the first non-nil child after the last one.
		child, state, path, index := findChild(n, parent.index+1, it.path, ancestor)
		if child != nil {
			// push increments the index again once the child is accepted.
			parent.index = index - 1
			return state, path, true
		}
	case *shortNode:
		// Short node: there is a single child, and it is handed out only once.
		if parent.index >= 0 {
			break
		}
		hash, _ := n.Val.cache()
		state := &nodeIteratorState{
			hash:    common.BytesToHash(hash),
			node:    n.Val,
			parent:  ancestor,
			index:   -1,
			pathlen: len(it.path),
		}
		path := append(it.path, n.Key...)
		return state, path, true
	}
	return parent, it.path, false
}
// nextChildAt is similar to nextChild, except that within a full node it picks
// the child that is as close to the target key as possible, thereby skipping
// the siblings before it. The return values have the same meaning as those of
// nextChild.
func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.Hash, key []byte) (*nodeIteratorState, []byte, bool) {
	switch n := parent.node.(type) {
	case *fullNode:
		child, state, path, index := findChild(n, parent.index+1, it.path, ancestor)
		if child == nil {
			// No more children in this fullnode
			return parent, it.path, false
		}
		// As long as the candidate is before the seek position, try to replace
		// it with a later sibling that is still before the seek position. A
		// candidate at or past the seek position is returned as it is.
		for bytes.Compare(path, key) < 0 {
			nextChild, nextState, nextPath, nextIndex := findChild(n, index+1, it.path, ancestor)
			if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
				// Out of children, or the next one overshoots: keep the candidate.
				break
			}
			state, path, index = nextState, nextPath, nextIndex
		}
		// push increments the index again once the child is accepted.
		parent.index = index - 1
		return state, path, true
	case *shortNode:
		// Short node: there is a single child, and it is handed out only once.
		if parent.index >= 0 {
			break
		}
		hash, _ := n.Val.cache()
		state := &nodeIteratorState{
			hash:    common.BytesToHash(hash),
			node:    n.Val,
			parent:  ancestor,
			index:   -1,
			pathlen: len(it.path),
		}
		path := append(it.path, n.Key...)
		return state, path, true
	}
	return parent, it.path, false
}
// push makes state the current node of the iterator and path the current path.
// If parentIndex is non-nil, the child index it points to is incremented to
// account for the child that has just been entered.
func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) {
	it.stack = append(it.stack, state)
	it.path = path
	if parentIndex == nil {
		return
	}
	*parentIndex++
}
// pop removes the current node from the stack and truncates the path back to
// the length recorded in that node's state. The stack must not be empty.
func (it *nodeIterator) pop() {
	last := len(it.stack) - 1
	it.path = it.path[:it.stack[last].pathlen]
	it.stack = it.stack[:last]
}
// compareNodes orders the current positions of two node iterators. Positions
// are compared by path first; on equal paths a leaf sorts before a non-leaf;
// then the node hashes are compared; and two leaves that are still tied are
// compared by their content. The result is negative, zero or positive in the
// manner of bytes.Compare.
func compareNodes(a, b NodeIterator) int {
	if order := bytes.Compare(a.Path(), b.Path()); order != 0 {
		return order
	}
	aLeaf, bLeaf := a.Leaf(), b.Leaf()
	switch {
	case aLeaf && !bLeaf:
		return -1
	case bLeaf && !aLeaf:
		return 1
	}
	if order := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); order != 0 {
		return order
	}
	if aLeaf && bLeaf {
		return bytes.Compare(a.LeafBlob(), b.LeafBlob())
	}
	return 0
}
// differenceIterator is a NodeIterator over the nodes of b that are not
// present in a. All positional accessors report the position of b.
type differenceIterator struct {
a, b NodeIterator // Nodes returned are those in b - a.
eof bool // Indicates a has run out of elements
count int // Number of nodes scanned on either trie
}
// NewDifferenceIterator constructs a NodeIterator that iterates over the
// elements of b that are not in a. It returns the iterator together with a
// pointer to an integer recording the number of nodes seen. The iterator a is
// advanced once up front; the result of that step is not inspected.
func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) {
	a.Next(true)

	diff := new(differenceIterator)
	diff.a, diff.b = a, b
	return diff, &diff.count
}
// Hash returns the hash of the current node of b.
func (it *differenceIterator) Hash() common.Hash {
return it.b.Hash()
}
// Parent returns the parent hash reported by b for its current node.
func (it *differenceIterator) Parent() common.Hash {
return it.b.Parent()
}
// Leaf reports whether b is currently positioned at a leaf.
func (it *differenceIterator) Leaf() bool {
return it.b.Leaf()
}
// LeafKey returns the key of the leaf b is positioned at; it delegates to b
// and therefore behaves like b.LeafKey when b is not at a leaf.
func (it *differenceIterator) LeafKey() []byte {
return it.b.LeafKey()
}
// LeafBlob returns the content of the leaf b is positioned at; it delegates to
// b and therefore behaves like b.LeafBlob when b is not at a leaf.
func (it *differenceIterator) LeafBlob() []byte {
return it.b.LeafBlob()
}
// LeafProof returns the Merkle proof of the leaf b is positioned at; it
// delegates to b and therefore behaves like b.LeafProof when b is not at a leaf.
func (it *differenceIterator) LeafProof() [][]byte {
return it.b.LeafProof()
}
// Path returns the path to the current node of b.
func (it *differenceIterator) Path() []byte {
return it.b.Path()
}
// Next advances to the next node of b that is not present in a and reports
// whether such a node exists. The descend argument is ignored: b is always
// advanced into child nodes, except below nodes that a and b share and that
// carry a hash.
func (it *differenceIterator) Next(bool) bool {
	// Invariants:
	// - We always advance at least one element in b.
	// - At the start of this function, a's path is lexically greater than b's.
	if !it.b.Next(true) {
		return false
	}
	it.count++

	// Once a has run out, everything that is left in b is part of the result.
	if it.eof {
		return true
	}
	for {
		switch compareNodes(it.a, it.b) {
		case 1:
			// b is positioned before a, so its node is missing from a.
			return true
		case -1:
			// a is behind b, let it catch up.
			if !it.a.Next(true) {
				it.eof = true
				return true
			}
			it.count++
		case 0:
			// a and b are identical here. A node carrying a hash has an identical
			// subtree on both sides, which is skipped as a whole; a node without
			// a hash is descended into.
			descend := it.a.Hash() == common.Hash{}
			if !it.b.Next(descend) {
				return false
			}
			it.count++
			if !it.a.Next(descend) {
				it.eof = true
				return true
			}
			it.count++
		}
	}
}
// Error returns the error of a if it has one, and the error of b otherwise.
func (it *differenceIterator) Error() error {
	err := it.a.Error()
	if err == nil {
		err = it.b.Error()
	}
	return err
}
// nodeIteratorHeap is a collection of node iterators ordered by compareNodes.
// Its methods are the ones the container/heap functions operate on.
type nodeIteratorHeap []NodeIterator

// Len returns the number of iterators held.
func (h nodeIteratorHeap) Len() int {
	return len(h)
}

// Less reports whether the iterator at i is positioned before the one at j.
func (h nodeIteratorHeap) Less(i, j int) bool {
	return compareNodes(h[i], h[j]) < 0
}

// Swap exchanges the iterators at i and j.
func (h nodeIteratorHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
}

// Push appends x, which must be a NodeIterator, to the collection.
func (h *nodeIteratorHeap) Push(x interface{}) {
	*h = append(*h, x.(NodeIterator))
}

// Pop removes the last iterator of the collection and returns it.
func (h *nodeIteratorHeap) Pop() interface{} {
	items := *h
	last := len(items) - 1
	top := items[last]
	*h = items[0:last]
	return top
}
// unionIterator is a NodeIterator over the union of the nodes of several
// iterators, which it keeps in a heap ordered by their current positions.
type unionIterator struct {
items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
count int // Number of nodes scanned across all tries
}
// NewUnionIterator constructs a NodeIterator that iterates over the elements in
// the union of the provided NodeIterators. It returns the iterator together
// with a pointer to an integer recording the number of nodes visited. The
// iters slice itself is copied and is not reordered.
func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
	items := make(nodeIteratorHeap, len(iters))
	copy(items, iters)
	heap.Init(&items)

	union := new(unionIterator)
	union.items = &items
	return union, &union.count
}
// Hash returns the hash of the current node of the iterator at the top of the
// heap. Like the other accessors below it indexes element 0 of the heap, so it
// panics when the heap is empty.
func (it *unionIterator) Hash() common.Hash {
return (*it.items)[0].Hash()
}
// Parent returns the parent hash reported by the iterator at the top of the
// heap. It panics when the heap is empty.
func (it *unionIterator) Parent() common.Hash {
return (*it.items)[0].Parent()
}
// Leaf reports whether the iterator at the top of the heap is positioned at a
// leaf. It panics when the heap is empty.
func (it *unionIterator) Leaf() bool {
return (*it.items)[0].Leaf()
}
// LeafKey returns the leaf key reported by the iterator at the top of the
// heap. It panics when the heap is empty.
func (it *unionIterator) LeafKey() []byte {
return (*it.items)[0].LeafKey()
}
// LeafBlob returns the leaf content reported by the iterator at the top of the
// heap. It panics when the heap is empty.
func (it *unionIterator) LeafBlob() []byte {
return (*it.items)[0].LeafBlob()
}
// LeafProof returns the leaf proof reported by the iterator at the top of the
// heap. It panics when the heap is empty.
func (it *unionIterator) LeafProof() [][]byte {
return (*it.items)[0].LeafProof()
}
// Path returns the path reported by the iterator at the top of the heap. It
// panics when the heap is empty.
func (it *unionIterator) Path() []byte {
return (*it.items)[0].Path()
}
// Next returns the next node in the union of tries being iterated over.
//
// It does this by maintaining a heap of iterators, sorted by the iteration
// order of their next elements, with one entry for each source trie. Each
// time Next() is called, it takes the least element from the heap to return,
// advancing any other iterators that also point to that same element. These
// iterators are called with descend=false, since we know that any nodes under
// these nodes will also be duplicates, found in the currently selected iterator.
// Whenever an iterator is advanced, it is pushed back into the heap if it still
// has elements remaining.
//
// In the case that descend=false - eg, we're asked to ignore all subnodes of the
// current node - we also advance any iterators in the heap that have the current
// path as a prefix.
func (it *unionIterator) Next(descend bool) bool {
	if len(*it.items) == 0 {
		return false
	}
	// Take the iterator holding the least element out of the heap.
	least := heap.Pop(it.items).(NodeIterator)

	// Advance the other iterators for as long as the one at the top is either
	// located under the current node (only relevant when not descending) or
	// positioned on an identical node.
	for len(*it.items) > 0 {
		head := (*it.items)[0]
		covered := !descend && bytes.HasPrefix(head.Path(), least.Path())
		if !covered && compareNodes(least, head) != 0 {
			break
		}
		skipped := heap.Pop(it.items).(NodeIterator)
		// A node carrying a hash is skipped together with its subtree, a node
		// without one is skipped on its own.
		descendSkipped := skipped.Hash() == common.Hash{}
		if skipped.Next(descendSkipped) {
			it.count++
			// The iterator has more elements, it goes back onto the heap.
			heap.Push(it.items, skipped)
		}
	}
	if least.Next(descend) {
		it.count++
		heap.Push(it.items, least)
	}
	return len(*it.items) > 0
}
// Error returns the first error reported by any of the iterators that are
// currently held in the heap, or nil if none of them reports one.
func (it *unionIterator) Error() error {
	for _, sub := range *it.items {
		if err := sub.Error(); err != nil {
			return err
		}
	}
	return nil
}