diff --git a/trie_by_cid/state/database.go b/trie_by_cid/state/database.go index 4c6a182..07fd020 100644 --- a/trie_by_cid/state/database.go +++ b/trie_by_cid/state/database.go @@ -42,6 +42,7 @@ type Database interface { // Trie is a Ethereum Merkle Patricia trie. type Trie interface { TryGet(key []byte) ([]byte, error) + TryGetNode(path []byte) ([]byte, int, error) TryGetAccount(key []byte) (*types.StateAccount, error) Hash() common.Hash NodeIterator(startKey []byte) trie.NodeIterator diff --git a/trie_by_cid/trie/encoding.go b/trie_by_cid/trie/encoding.go index 5871a66..381883a 100644 --- a/trie_by_cid/trie/encoding.go +++ b/trie_by_cid/trie/encoding.go @@ -16,6 +16,21 @@ package trie +// CompactToHex converts a compact encoded path to hex format +func CompactToHex(compact []byte) []byte { + if len(compact) == 0 { + return compact + } + base := keybytesToHex(compact) + // delete terminator flag + if base[0] < 2 { + base = base[:len(base)-1] + } + // apply odd flag + chop := 2 - base[0]&1 + return base[chop:] +} + func keybytesToHex(str []byte) []byte { l := len(str)*2 + 1 var nibbles = make([]byte, l) diff --git a/trie_by_cid/trie/secure_trie.go b/trie_by_cid/trie/secure_trie.go index 731b095..6e4b4ca 100644 --- a/trie_by_cid/trie/secure_trie.go +++ b/trie_by_cid/trie/secure_trie.go @@ -95,6 +95,14 @@ func (t *StateTrie) TryGetAccount(key []byte) (*types.StateAccount, error) { return &ret, err } +// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not +// possible to use keybyte-encoding as the path might contain odd nibbles. +// If the specified trie node is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) { + return t.trie.TryGetNode(path) +} + // Hash returns the root hash of StateTrie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *StateTrie) Hash() common.Hash { diff --git a/trie_by_cid/trie/trie.go b/trie_by_cid/trie/trie.go index 6adb08b..f82c128 100644 --- a/trie_by_cid/trie/trie.go +++ b/trie_by_cid/trie/trie.go @@ -3,6 +3,7 @@ package trie import ( "bytes" + "errors" "fmt" "github.com/ethereum/go-ethereum/common" @@ -125,6 +126,83 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode } } +// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not +// possible to use keybyte-encoding as the path might contain odd nibbles. +func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { + item, newroot, resolved, err := t.tryGetNode(t.root, CompactToHex(path), 0) + if err != nil { + return nil, resolved, err + } + if resolved > 0 { + t.root = newroot + } + if item == nil { + return nil, resolved, nil + } + return item, resolved, err +} + +func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) { + // If non-existent path requested, abort + if origNode == nil { + return nil, nil, 0, nil + } + // If we reached the requested path, return the current node + if pos >= len(path) { + // Although we most probably have the original node expanded, encoding + // that into consensus form can be nasty (needs to cascade down) and + // time consuming. Instead, just pull the hash up from disk directly. + var hash hashNode + if node, ok := origNode.(hashNode); ok { + hash = node + } else { + hash, _ = origNode.cache() + } + if hash == nil { + return nil, origNode, 0, errors.New("non-consensus node") + } + blob, err := t.db.Node(hash) + return blob, origNode, 1, err + } + // Path still needs to be traversed, descend into children + switch n := (origNode).(type) { + case valueNode: + // Path prematurely ended, abort + return nil, nil, 0, nil + + case *shortNode: + if len(path)-pos < len(n.Key) || !bytes.Equal(n.Key, path[pos:pos+len(n.Key)]) { + // Path branches off from short node + return nil, n, 0, nil + } + item, newnode, resolved, err = t.tryGetNode(n.Val, path, pos+len(n.Key)) + if err == nil && resolved > 0 { + n = n.copy() + n.Val = newnode + } + return item, n, resolved, err + + case *fullNode: + item, newnode, resolved, err = t.tryGetNode(n.Children[path[pos]], path, pos+1) + if err == nil && resolved > 0 { + n = n.copy() + n.Children[path[pos]] = newnode + } + return item, n, resolved, err + + case hashNode: + child, err := t.resolveHash(n, path[:pos]) + if err != nil { + return nil, n, 1, err + } + item, newnode, resolved, err := t.tryGetNode(child, path, pos) + return item, newnode, resolved + 1, err + + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) + } +} + // resolveHash loads node from the underlying database with the provided // node hash and path prefix. func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {