diff --git a/trie/iterator.go b/trie/iterator.go index 9f6dc3af7..9b7d97a5f 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -86,6 +86,10 @@ type NodeIterator interface { // For leaf nodes, the last element of the path is the 'terminator symbol' 0x10. Path() []byte + // NodeBlob returns the rlp-encoded value of the current iterated node. + // If the node is an embedded node in its parent, nil is returned then. + NodeBlob() []byte + // Leaf returns true iff the current node is a leaf node. Leaf() bool @@ -224,6 +228,18 @@ func (it *nodeIterator) Path() []byte { return it.path } +func (it *nodeIterator) NodeBlob() []byte { + if it.Hash() == (common.Hash{}) { + return nil // skip the non-standalone node + } + blob, err := it.resolveBlob(it.Hash().Bytes(), it.Path()) + if err != nil { + it.err = err + return nil + } + return blob +} + func (it *nodeIterator) Error() error { if it.err == errIteratorEnd { return nil @@ -362,6 +378,15 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) { return resolved, err } +func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) { + if it.resolver != nil { + if blob, err := it.resolver.Get(hash); err == nil && len(blob) > 0 { + return blob, nil + } + } + return it.trie.resolveBlob(hash, path) +} + func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { if hash, ok := st.node.(hashNode); ok { resolved, err := it.resolveHash(hash, path) @@ -549,6 +574,10 @@ func (it *differenceIterator) Path() []byte { return it.b.Path() } +func (it *differenceIterator) NodeBlob() []byte { + return it.b.NodeBlob() +} + func (it *differenceIterator) AddResolver(resolver ethdb.KeyValueReader) { panic("not implemented") } @@ -660,6 +689,10 @@ func (it *unionIterator) Path() []byte { return (*it.items)[0].Path() } +func (it *unionIterator) NodeBlob() []byte { + return (*it.items)[0].NodeBlob() +} + func (it *unionIterator) AddResolver(resolver ethdb.KeyValueReader) { panic("not implemented") } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 95cafdd3b..162f781c5 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -521,3 +521,54 @@ func TestNodeIteratorLargeTrie(t *testing.T) { t.Fatalf("Too many lookups during seek, have %d want %d", have, want) } } + +func TestIteratorNodeBlob(t *testing.T) { + var ( + db = memorydb.New() + triedb = NewDatabase(db) + trie, _ = New(common.Hash{}, triedb) + ) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + all := make(map[string]string) + for _, val := range vals { + all[val.k] = val.v + trie.Update([]byte(val.k), []byte(val.v)) + } + trie.Commit(nil) + triedb.Cap(0) + + found := make(map[common.Hash][]byte) + it := trie.NodeIterator(nil) + for it.Next(true) { + if it.Hash() == (common.Hash{}) { + continue + } + found[it.Hash()] = it.NodeBlob() + } + + dbIter := db.NewIterator(nil, nil) + defer dbIter.Release() + + var count int + for dbIter.Next() { + got, present := found[common.BytesToHash(dbIter.Key())] + if !present { + t.Fatalf("Miss trie node %v", dbIter.Key()) + } + if !bytes.Equal(got, dbIter.Value()) { + t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got) + } + count += 1 + } + if count != len(found) { + t.Fatal("Find extra trie node via iterator") + } +} diff --git a/trie/trie.go b/trie/trie.go index 13343112b..e40b03be3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -514,6 +514,15 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { return nil, &MissingNodeError{NodeHash: hash, Path: prefix} } +func (t *Trie) resolveBlob(n hashNode, prefix []byte) ([]byte, error) { + hash := common.BytesToHash(n) + blob, _ := t.db.Node(hash) + if len(blob) != 0 { + return blob, nil + } + return nil, &MissingNodeError{NodeHash: hash, Path: prefix} +} + // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash {