trie: implement unionIterator (#14312)
This commit is contained in:
		
							parent
							
								
									409b61fe3c
								
							
						
					
					
						commit
						b35aa21f9f
					
				
							
								
								
									
										138
									
								
								trie/iterator.go
									
									
									
									
									
								
							
							
						
						
									
										138
									
								
								trie/iterator.go
									
									
									
									
									
								
							| @ -18,7 +18,7 @@ package trie | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 
 | ||||
| 	"container/heap" | ||||
| 	"github.com/ethereum/go-ethereum/common" | ||||
| ) | ||||
| 
 | ||||
| @ -268,6 +268,26 @@ outer: | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func compareNodes(a, b NodeIterator) int { | ||||
| 	cmp := bytes.Compare(a.Path(), b.Path()) | ||||
| 	if cmp != 0 { | ||||
| 		return cmp | ||||
| 	} | ||||
| 
 | ||||
| 	if a.Leaf() && !b.Leaf() { | ||||
| 		return -1 | ||||
| 	} else if b.Leaf() && !a.Leaf() { | ||||
| 		return 1 | ||||
| 	} | ||||
| 
 | ||||
| 	cmp = bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()) | ||||
| 	if cmp != 0 { | ||||
| 		return cmp | ||||
| 	} | ||||
| 
 | ||||
| 	return bytes.Compare(a.LeafBlob(), b.LeafBlob()) | ||||
| } | ||||
| 
 | ||||
| type differenceIterator struct { | ||||
| 	a, b  NodeIterator // Nodes returned are those in b - a.
 | ||||
| 	eof   bool         // Indicates a has run out of elements
 | ||||
| @ -321,8 +341,7 @@ func (it *differenceIterator) Next(bool) bool { | ||||
| 	} | ||||
| 
 | ||||
| 	for { | ||||
| 		apath, bpath := it.a.Path(), it.b.Path() | ||||
| 		switch bytes.Compare(apath, bpath) { | ||||
| 		switch compareNodes(it.a, it.b) { | ||||
| 		case -1: | ||||
| 			// b jumped past a; advance a
 | ||||
| 			if !it.a.Next(true) { | ||||
| @ -334,15 +353,6 @@ func (it *differenceIterator) Next(bool) bool { | ||||
| 			// b is before a
 | ||||
| 			return true | ||||
| 		case 0: | ||||
| 			if it.a.Hash() != it.b.Hash() || it.a.Leaf() != it.b.Leaf() { | ||||
| 				// Keys are identical, but hashes or leaf status differs
 | ||||
| 				return true | ||||
| 			} | ||||
| 			if it.a.Leaf() && it.b.Leaf() && !bytes.Equal(it.a.LeafBlob(), it.b.LeafBlob()) { | ||||
| 				// Both are leaf nodes, but with different values
 | ||||
| 				return true | ||||
| 			} | ||||
| 
 | ||||
| 			// a and b are identical; skip this whole subtree if the nodes have hashes
 | ||||
| 			hasHash := it.a.Hash() == common.Hash{} | ||||
| 			if !it.b.Next(hasHash) { | ||||
| @ -364,3 +374,107 @@ func (it *differenceIterator) Error() error { | ||||
| 	} | ||||
| 	return it.b.Error() | ||||
| } | ||||
| 
 | ||||
| type nodeIteratorHeap []NodeIterator | ||||
| 
 | ||||
| func (h nodeIteratorHeap) Len() int            { return len(h) } | ||||
| func (h nodeIteratorHeap) Less(i, j int) bool  { return compareNodes(h[i], h[j]) < 0 } | ||||
| func (h nodeIteratorHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] } | ||||
| func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) } | ||||
| func (h *nodeIteratorHeap) Pop() interface{} { | ||||
| 	n := len(*h) | ||||
| 	x := (*h)[n-1] | ||||
| 	*h = (*h)[0 : n-1] | ||||
| 	return x | ||||
| } | ||||
| 
 | ||||
| type unionIterator struct { | ||||
| 	items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
 | ||||
| 	count int               // Number of nodes scanned across all tries
 | ||||
| 	err   error             // The error, if one has been encountered
 | ||||
| } | ||||
| 
 | ||||
| // NewUnionIterator constructs a NodeIterator that iterates over elements in the union
 | ||||
| // of the provided NodeIterators. Returns the iterator, and a pointer to an integer
 | ||||
| // recording the number of nodes visited.
 | ||||
| func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) { | ||||
| 	h := make(nodeIteratorHeap, len(iters)) | ||||
| 	copy(h, iters) | ||||
| 	heap.Init(&h) | ||||
| 
 | ||||
| 	ui := &unionIterator{ | ||||
| 		items: &h, | ||||
| 	} | ||||
| 	return ui, &ui.count | ||||
| } | ||||
| 
 | ||||
| func (it *unionIterator) Hash() common.Hash { | ||||
| 	return (*it.items)[0].Hash() | ||||
| } | ||||
| 
 | ||||
| func (it *unionIterator) Parent() common.Hash { | ||||
| 	return (*it.items)[0].Parent() | ||||
| } | ||||
| 
 | ||||
| func (it *unionIterator) Leaf() bool { | ||||
| 	return (*it.items)[0].Leaf() | ||||
| } | ||||
| 
 | ||||
| func (it *unionIterator) LeafBlob() []byte { | ||||
| 	return (*it.items)[0].LeafBlob() | ||||
| } | ||||
| 
 | ||||
| func (it *unionIterator) Path() []byte { | ||||
| 	return (*it.items)[0].Path() | ||||
| } | ||||
| 
 | ||||
| // Next returns the next node in the union of tries being iterated over.
 | ||||
| //
 | ||||
| // It does this by maintaining a heap of iterators, sorted by the iteration
 | ||||
| // order of their next elements, with one entry for each source trie. Each
 | ||||
| // time Next() is called, it takes the least element from the heap to return,
 | ||||
| // advancing any other iterators that also point to that same element. These
 | ||||
| // iterators are called with descend=false, since we know that any nodes under
 | ||||
| // these nodes will also be duplicates, found in the currently selected iterator.
 | ||||
| // Whenever an iterator is advanced, it is pushed back into the heap if it still
 | ||||
| // has elements remaining.
 | ||||
| //
 | ||||
| // In the case that descend=false - eg, we're asked to ignore all subnodes of the
 | ||||
| // current node - we also advance any iterators in the heap that have the current
 | ||||
| // path as a prefix.
 | ||||
| func (it *unionIterator) Next(descend bool) bool { | ||||
| 	if len(*it.items) == 0 { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	// Get the next key from the union
 | ||||
| 	least := heap.Pop(it.items).(NodeIterator) | ||||
| 
 | ||||
| 	// Skip over other nodes as long as they're identical, or, if we're not descending, as
 | ||||
| 	// long as they have the same prefix as the current node.
 | ||||
| 	for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) { | ||||
| 		skipped := heap.Pop(it.items).(NodeIterator) | ||||
| 		// Skip the whole subtree if the nodes have hashes; otherwise just skip this node
 | ||||
| 		if skipped.Next(skipped.Hash() == common.Hash{}) { | ||||
| 			it.count += 1 | ||||
| 			// If there are more elements, push the iterator back on the heap
 | ||||
| 			heap.Push(it.items, skipped) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if least.Next(descend) { | ||||
| 		it.count += 1 | ||||
| 		heap.Push(it.items, least) | ||||
| 	} | ||||
| 
 | ||||
| 	return len(*it.items) > 0 | ||||
| } | ||||
| 
 | ||||
| func (it *unionIterator) Error() error { | ||||
| 	for i := 0; i < len(*it.items); i++ { | ||||
| 		if err := (*it.items)[i].Error(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @ -117,36 +117,38 @@ func TestNodeIteratorCoverage(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| var testdata1 = []struct{ k, v string }{ | ||||
| 	{"bar", "b"}, | ||||
| 	{"barb", "ba"}, | ||||
| 	{"bars", "bb"}, | ||||
| 	{"bard", "bc"}, | ||||
| 	{"fab", "z"}, | ||||
| 	{"foo", "a"}, | ||||
| 	{"food", "ab"}, | ||||
| 	{"foos", "aa"}, | ||||
| } | ||||
| 
 | ||||
| var testdata2 = []struct{ k, v string }{ | ||||
| 	{"aardvark", "c"}, | ||||
| 	{"bar", "b"}, | ||||
| 	{"barb", "bd"}, | ||||
| 	{"bars", "be"}, | ||||
| 	{"fab", "z"}, | ||||
| 	{"foo", "a"}, | ||||
| 	{"foos", "aa"}, | ||||
| 	{"food", "ab"}, | ||||
| 	{"jars", "d"}, | ||||
| } | ||||
| 
 | ||||
| func TestDifferenceIterator(t *testing.T) { | ||||
| 	triea := newEmpty() | ||||
| 	valsa := []struct{ k, v string }{ | ||||
| 		{"bar", "b"}, | ||||
| 		{"barb", "ba"}, | ||||
| 		{"bars", "bb"}, | ||||
| 		{"bard", "bc"}, | ||||
| 		{"fab", "z"}, | ||||
| 		{"foo", "a"}, | ||||
| 		{"food", "ab"}, | ||||
| 		{"foos", "aa"}, | ||||
| 	} | ||||
| 	for _, val := range valsa { | ||||
| 	for _, val := range testdata1 { | ||||
| 		triea.Update([]byte(val.k), []byte(val.v)) | ||||
| 	} | ||||
| 	triea.Commit() | ||||
| 
 | ||||
| 	trieb := newEmpty() | ||||
| 	valsb := []struct{ k, v string }{ | ||||
| 		{"aardvark", "c"}, | ||||
| 		{"bar", "b"}, | ||||
| 		{"barb", "bd"}, | ||||
| 		{"bars", "be"}, | ||||
| 		{"fab", "z"}, | ||||
| 		{"foo", "a"}, | ||||
| 		{"foos", "aa"}, | ||||
| 		{"food", "ab"}, | ||||
| 		{"jars", "d"}, | ||||
| 	} | ||||
| 	for _, val := range valsb { | ||||
| 	for _, val := range testdata2 { | ||||
| 		trieb.Update([]byte(val.k), []byte(val.v)) | ||||
| 	} | ||||
| 	trieb.Commit() | ||||
| @ -166,10 +168,57 @@ func TestDifferenceIterator(t *testing.T) { | ||||
| 	} | ||||
| 	for _, item := range all { | ||||
| 		if found[item.k] != item.v { | ||||
| 			t.Errorf("iterator value mismatch for %s: got %q want %q", item.k, found[item.k], item.v) | ||||
| 			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v) | ||||
| 		} | ||||
| 	} | ||||
| 	if len(found) != len(all) { | ||||
| 		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUnionIterator(t *testing.T) { | ||||
| 	triea := newEmpty() | ||||
| 	for _, val := range testdata1 { | ||||
| 		triea.Update([]byte(val.k), []byte(val.v)) | ||||
| 	} | ||||
| 	triea.Commit() | ||||
| 
 | ||||
| 	trieb := newEmpty() | ||||
| 	for _, val := range testdata2 { | ||||
| 		trieb.Update([]byte(val.k), []byte(val.v)) | ||||
| 	} | ||||
| 	trieb.Commit() | ||||
| 
 | ||||
| 	di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)}) | ||||
| 	it := NewIteratorFromNodeIterator(di) | ||||
| 
 | ||||
| 	all := []struct{ k, v string }{ | ||||
| 		{"aardvark", "c"}, | ||||
| 		{"barb", "bd"}, | ||||
| 		{"barb", "ba"}, | ||||
| 		{"bard", "bc"}, | ||||
| 		{"bars", "bb"}, | ||||
| 		{"bars", "be"}, | ||||
| 		{"bar", "b"}, | ||||
| 		{"fab", "z"}, | ||||
| 		{"food", "ab"}, | ||||
| 		{"foos", "aa"}, | ||||
| 		{"foo", "a"}, | ||||
| 		{"jars", "d"}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, kv := range all { | ||||
| 		if !it.Next() { | ||||
| 			t.Errorf("Iterator ends prematurely at element %d", i) | ||||
| 		} | ||||
| 		if kv.k != string(it.Key) { | ||||
| 			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k) | ||||
| 		} | ||||
| 		if kv.v != string(it.Value) { | ||||
| 			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v) | ||||
| 		} | ||||
| 	} | ||||
| 	if it.Next() { | ||||
| 		t.Errorf("Iterator returned extra values.") | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user