From 0573da1982755de45bad210803beb592e1061fa0 Mon Sep 17 00:00:00 2001 From: Roy Crihfield Date: Mon, 3 Apr 2023 14:44:31 +0800 Subject: [PATCH] trie iterator and tests --- bycid/trie/encoding.go | 20 ++ bycid/trie/iterator.go | 458 ++++++++++++++++++++++++++++++++++++ bycid/trie/iterator_test.go | 248 +++++++++++++++++++ bycid/trie/proof.go | 62 ++--- bycid/trie/trie.go | 17 ++ bycid/trie/util_test.go | 32 +++ helper/hasher.go | 31 +++ helper/statediff_helper.go | 81 +++++++ 8 files changed, 918 insertions(+), 31 deletions(-) create mode 100644 bycid/trie/iterator.go create mode 100644 bycid/trie/iterator_test.go create mode 100644 bycid/trie/util_test.go create mode 100644 helper/hasher.go create mode 100644 helper/statediff_helper.go diff --git a/bycid/trie/encoding.go b/bycid/trie/encoding.go index d564701..5871a66 100644 --- a/bycid/trie/encoding.go +++ b/bycid/trie/encoding.go @@ -27,6 +27,26 @@ func keybytesToHex(str []byte) []byte { return nibbles } +// hexToKeyBytes turns hex nibbles into key bytes. +// This can only be used for keys of even length. +func hexToKeyBytes(hex []byte) []byte { + if hasTerm(hex) { + hex = hex[:len(hex)-1] + } + if len(hex)&1 != 0 { + panic("can't convert hex key of odd length") + } + key := make([]byte, len(hex)/2) + decodeNibbles(hex, key) + return key +} + +func decodeNibbles(nibbles []byte, bytes []byte) { + for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 { + bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1] + } +} + // hasTerm returns whether a hex key has the terminator flag. func hasTerm(s []byte) bool { return len(s) > 0 && s[len(s)-1] == 16 diff --git a/bycid/trie/iterator.go b/bycid/trie/iterator.go new file mode 100644 index 0000000..0067bc0 --- /dev/null +++ b/bycid/trie/iterator.go @@ -0,0 +1,458 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "bytes" + "errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/trie" +) + +// NodeIterator is a re-export of the go-ethereum interface +type NodeIterator = trie.NodeIterator + +// Iterator is a key-value trie iterator that traverses a Trie. +type Iterator struct { + nodeIt NodeIterator + + Key []byte // Current data key on which the iterator is positioned on + Value []byte // Current data value on which the iterator is positioned on + Err error +} + +// NewIterator creates a new key-value iterator from a node iterator. +// Note that the value returned by the iterator is raw. If the content is encoded +// (e.g. storage value is RLP-encoded), it's caller's duty to decode it. +func NewIterator(it NodeIterator) *Iterator { + return &Iterator{ + nodeIt: it, + } +} + +// Next moves the iterator forward one key-value entry. +func (it *Iterator) Next() bool { + for it.nodeIt.Next(true) { + if it.nodeIt.Leaf() { + it.Key = it.nodeIt.LeafKey() + it.Value = it.nodeIt.LeafBlob() + return true + } + } + it.Key = nil + it.Value = nil + it.Err = it.nodeIt.Error() + return false +} + +// Prove generates the Merkle proof for the leaf node the iterator is currently +// positioned on. 
func (it *Iterator) Prove() [][]byte {
	return it.nodeIt.LeafProof()
}

// nodeIteratorState represents the iteration state at one particular node of the
// trie, which can be resumed at a later invocation.
type nodeIteratorState struct {
	hash    common.Hash // Hash of the node being iterated (nil if not standalone)
	node    node        // Trie node being iterated
	parent  common.Hash // Hash of the first full ancestor node (nil if current is the root)
	index   int         // Child to be processed next
	pathlen int         // Length of the path to this node
}

// nodeIterator walks the trie depth-first, keeping the chain of states from
// the root to the current node on an explicit stack so iteration can resume.
type nodeIterator struct {
	trie  *Trie                // Trie being iterated
	stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state
	path  []byte               // Path to the current node
	err   error                // Failure set in case of an internal error in the iterator

	resolver ethdb.KeyValueReader // Optional intermediate resolver above the disk layer
}

// errIteratorEnd is stored in nodeIterator.err when iteration is done.
var errIteratorEnd = errors.New("end of iteration")

// seekError is stored in nodeIterator.err if the initial seek has failed.
// seekError wraps a failure that occurred while fast-forwarding the iterator
// to its start key; the key is retained so the seek can be retried on Next.
type seekError struct {
	key []byte
	err error
}

func (e seekError) Error() string {
	return "seek error: " + e.err.Error()
}

// newNodeIterator creates a node iterator over the given trie, fast-forwarded
// to the given start key. An empty trie yields an already-exhausted iterator.
func newNodeIterator(trie *Trie, start []byte) NodeIterator {
	if trie.Hash() == emptyRoot {
		return &nodeIterator{
			trie: trie,
			err:  errIteratorEnd,
		}
	}
	it := &nodeIterator{trie: trie}
	it.err = it.seek(start)
	return it
}

// AddResolver sets an optional read layer consulted for node blobs before
// falling back to the trie's own database.
func (it *nodeIterator) AddResolver(resolver ethdb.KeyValueReader) {
	it.resolver = resolver
}

// Hash returns the hash of the current node, or the zero hash if the node is
// embedded in its parent (not standalone) or iteration has not started.
func (it *nodeIterator) Hash() common.Hash {
	if len(it.stack) == 0 {
		return common.Hash{}
	}
	return it.stack[len(it.stack)-1].hash
}

// Parent returns the hash of the first full ancestor of the current node.
func (it *nodeIterator) Parent() common.Hash {
	if len(it.stack) == 0 {
		return common.Hash{}
	}
	return it.stack[len(it.stack)-1].parent
}

// Leaf reports whether the current position is a leaf: the hex path carries
// the terminator nibble.
func (it *nodeIterator) Leaf() bool {
	return hasTerm(it.path)
}

// LeafKey returns the key bytes of the current leaf (the hex path converted
// back to key bytes). It panics when not positioned on a value node.
func (it *nodeIterator) LeafKey() []byte {
	if len(it.stack) > 0 {
		if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
			return hexToKeyBytes(it.path)
		}
	}
	panic("not at leaf")
}

// LeafBlob returns the raw value of the current leaf. It panics when not
// positioned on a value node.
func (it *nodeIterator) LeafBlob() []byte {
	if len(it.stack) > 0 {
		if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
			return node
		}
	}
	panic("not at leaf")
}

// LeafProof builds the Merkle proof for the current leaf from the node stack.
// It panics when not positioned on a value node.
func (it *nodeIterator) LeafProof() [][]byte {
	if len(it.stack) > 0 {
		if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
			hasher := newHasher(false)
			defer returnHasherToPool(hasher)
			proofs := make([][]byte, 0, len(it.stack))

			// The value node itself (stack top) is excluded from the proof.
			for i, item := range it.stack[:len(it.stack)-1] {
				// Gather nodes that end up as hash nodes (or the root)
				node, hashed := hasher.proofHash(item.node)
				if _, ok := hashed.(hashNode); ok || i == 0 {
					proofs = append(proofs, nodeToBytes(node))
				}
			}
			return proofs
		}
	}
	panic("not at leaf")
}

// Path returns the hex-encoded path to the current node.
func (it *nodeIterator) Path() []byte {
	return it.path
}

// NodeBlob returns the raw encoded blob of the current node, or nil for nodes
// embedded in their parent. A resolution failure is recorded in it.err.
func (it *nodeIterator) NodeBlob() []byte {
	if it.Hash() == (common.Hash{}) {
		return nil // skip the non-standalone node
	}
	blob, err := it.resolveBlob(it.Hash().Bytes(), it.Path())
	if err != nil {
		it.err = err
		return nil
	}
	return blob
}

// Error returns the failure of the iterator, masking the end-of-iteration
// sentinel and unwrapping seek errors.
func (it *nodeIterator) Error() error {
	if it.err == errIteratorEnd {
		return nil
	}
	if seek, ok := it.err.(seekError); ok {
		return seek.err
	}
	return it.err
}

// Next moves the iterator to the next node, returning whether there are any
// further nodes. In case of an internal error this method returns false and
// sets the Error field to the encountered failure. If `descend` is false,
// skips iterating over any subnodes of the current node.
func (it *nodeIterator) Next(descend bool) bool {
	if it.err == errIteratorEnd {
		return false
	}
	if seek, ok := it.err.(seekError); ok {
		// A previous seek failed; retry it before stepping forward.
		if it.err = it.seek(seek.key); it.err != nil {
			return false
		}
	}
	// Otherwise step forward with the iterator and report any errors.
	state, parentIndex, path, err := it.peek(descend)
	it.err = err
	if it.err != nil {
		return false
	}
	it.push(state, parentIndex, path)
	return true
}

// seek positions the iterator just before the closest match to the given
// key prefix, so the next step lands on (or after) it.
func (it *nodeIterator) seek(prefix []byte) error {
	// The path we're looking for is the hex encoded key without terminator.
	key := keybytesToHex(prefix)
	key = key[:len(key)-1]
	// Move forward until we're just before the closest match to key.
	for {
		state, parentIndex, path, err := it.peekSeek(key)
		if err == errIteratorEnd {
			return errIteratorEnd
		} else if err != nil {
			return seekError{prefix, err}
		} else if bytes.Compare(path, key) >= 0 {
			return nil
		}
		it.push(state, parentIndex, path)
	}
}

// init initializes the iterator: it builds the root state and resolves the
// root node in case it is only referenced by hash.
func (it *nodeIterator) init() (*nodeIteratorState, error) {
	root := it.trie.Hash()
	state := &nodeIteratorState{node: it.trie.root, index: -1}
	if root != emptyRoot {
		state.hash = root
	}
	return state, state.resolve(it, nil)
}

// peek creates the next state of the iterator.
func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) {
	// Initialize the iterator if we've just started.
	if len(it.stack) == 0 {
		state, err := it.init()
		return state, nil, nil, err
	}
	if !descend {
		// If we're skipping children, pop the current node first
		it.pop()
	}

	// Continue iteration to the next child
	for len(it.stack) > 0 {
		parent := it.stack[len(it.stack)-1]
		ancestor := parent.hash
		if (ancestor == common.Hash{}) {
			// Embedded node: inherit the first full ancestor's hash.
			ancestor = parent.parent
		}
		state, path, ok := it.nextChild(parent, ancestor)
		if ok {
			if err := state.resolve(it, path); err != nil {
				return parent, &parent.index, path, err
			}
			return state, &parent.index, path, nil
		}
		// No more child nodes, move back up.
		it.pop()
	}
	return nil, nil, nil, errIteratorEnd
}

// peekSeek is like peek, but it also tries to skip resolving hashes by skipping
// over the siblings that do not lead towards the desired seek position.
func (it *nodeIterator) peekSeek(seekKey []byte) (*nodeIteratorState, *int, []byte, error) {
	// Initialize the iterator if we've just started.
	if len(it.stack) == 0 {
		state, err := it.init()
		return state, nil, nil, err
	}
	if !bytes.HasPrefix(seekKey, it.path) {
		// If we're skipping children, pop the current node first
		it.pop()
	}

	// Continue iteration to the next child
	for len(it.stack) > 0 {
		parent := it.stack[len(it.stack)-1]
		ancestor := parent.hash
		if (ancestor == common.Hash{}) {
			ancestor = parent.parent
		}
		state, path, ok := it.nextChildAt(parent, ancestor, seekKey)
		if ok {
			if err := state.resolve(it, path); err != nil {
				return parent, &parent.index, path, err
			}
			return state, &parent.index, path, nil
		}
		// No more child nodes, move back up.
		it.pop()
	}
	return nil, nil, nil, errIteratorEnd
}

// resolveHash loads and decodes the node with the given hash, trying the
// optional resolver first and falling back to the trie's own database.
func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) {
	if it.resolver != nil {
		if blob, err := it.resolver.Get(hash); err == nil && len(blob) > 0 {
			if resolved, err := decodeNode(hash, blob); err == nil {
				return resolved, nil
			}
		}
	}
	return it.trie.resolveHash(hash, path)
}

// resolveBlob loads the raw encoded blob for the given hash, trying the
// optional resolver first and falling back to the trie's own database.
func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) {
	if it.resolver != nil {
		if blob, err := it.resolver.Get(hash); err == nil && len(blob) > 0 {
			return blob, nil
		}
	}
	return it.trie.resolveBlob(hash, path)
}

// resolve expands the state's node in place when it is only a hash reference,
// recording the resolved hash on the state.
func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
	if hash, ok := st.node.(hashNode); ok {
		resolved, err := it.resolveHash(hash, path)
		if err != nil {
			return err
		}
		st.node = resolved
		st.hash = common.BytesToHash(hash)
	}
	return nil
}

// findChild returns the first non-nil child of the full node at or after the
// given index, together with its iterator state, extended path and position.
// All return values are nil/zero when no further child exists.
func findChild(n *fullNode, index int, path []byte, ancestor common.Hash) (node, *nodeIteratorState, []byte, int) {
	var (
		child     node
		state     *nodeIteratorState
		childPath []byte
	)
	for ; index < len(n.Children); index++ {
		if n.Children[index] != nil {
			child = n.Children[index]
			hash, _ := child.cache()
			state = &nodeIteratorState{
				hash:    common.BytesToHash(hash),
				node:    child,
				parent:  ancestor,
				index:   -1,
				pathlen: len(path),
			}
			childPath = append(childPath, path...)
			childPath = append(childPath, byte(index))
			return child, state, childPath, index
		}
	}
	return nil, nil, nil, 0
}

// nextChild returns the state of the next child of parent, or the parent
// itself with ok=false when its children are exhausted.
func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Hash) (*nodeIteratorState, []byte, bool) {
	switch node := parent.node.(type) {
	case *fullNode:
		// Full node, move to the first non-nil child.
		if child, state, path, index := findChild(node, parent.index+1, it.path, ancestor); child != nil {
			parent.index = index - 1
			return state, path, true
		}
	case *shortNode:
		// Short node, return the pointer singleton child
		if parent.index < 0 {
			hash, _ := node.Val.cache()
			state := &nodeIteratorState{
				hash:    common.BytesToHash(hash),
				node:    node.Val,
				parent:  ancestor,
				index:   -1,
				pathlen: len(it.path),
			}
			path := append(it.path, node.Key...)
			return state, path, true
		}
	}
	return parent, it.path, false
}

// nextChildAt is similar to nextChild, except that it targets a child as close to the
// target key as possible, thus skipping siblings.
func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.Hash, key []byte) (*nodeIteratorState, []byte, bool) {
	switch n := parent.node.(type) {
	case *fullNode:
		// Full node, move to the first non-nil child before the desired key position
		child, state, path, index := findChild(n, parent.index+1, it.path, ancestor)
		if child == nil {
			// No more children in this fullnode
			return parent, it.path, false
		}
		// If the child we found is already past the seek position, just return it.
		if bytes.Compare(path, key) >= 0 {
			parent.index = index - 1
			return state, path, true
		}
		// The child is before the seek position. Try advancing
		for {
			nextChild, nextState, nextPath, nextIndex := findChild(n, index+1, it.path, ancestor)
			// If we run out of children, or skipped past the target, return the
			// previous one
			if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
				parent.index = index - 1
				return state, path, true
			}
			// We found a better child closer to the target
			state, path, index = nextState, nextPath, nextIndex
		}
	case *shortNode:
		// Short node, return the pointer singleton child
		if parent.index < 0 {
			hash, _ := n.Val.cache()
			state := &nodeIteratorState{
				hash:    common.BytesToHash(hash),
				node:    n.Val,
				parent:  ancestor,
				index:   -1,
				pathlen: len(it.path),
			}
			path := append(it.path, n.Key...)
			return state, path, true
		}
	}
	return parent, it.path, false
}

// push descends the iterator into the given child state, extending the current
// path and bumping the parent's child index when one is supplied.
func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) {
	it.path = path
	it.stack = append(it.stack, state)
	if parentIndex != nil {
		*parentIndex++
	}
}

// pop unwinds the iterator to the parent node, truncating the path back to
// the parent's length.
func (it *nodeIterator) pop() {
	last := it.stack[len(it.stack)-1]
	it.path = it.path[:last.pathlen]
	it.stack[len(it.stack)-1] = nil // release the reference for GC
	it.stack = it.stack[:len(it.stack)-1]
}
diff --git a/bycid/trie/iterator_test.go b/bycid/trie/iterator_test.go
new file mode 100644
index 0000000..9017d04
--- /dev/null
+++ b/bycid/trie/iterator_test.go
@@ -0,0 +1,248 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie_test + +import ( + "bytes" + "context" + "fmt" + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + geth_state "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" + "github.com/ethereum/go-ethereum/statediff/indexer/ipld" + "github.com/ethereum/go-ethereum/statediff/test_helpers" + geth_trie "github.com/ethereum/go-ethereum/trie" + + "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" + "github.com/cerc-io/ipld-eth-utils/bycid/state" + "github.com/cerc-io/ipld-eth-utils/bycid/trie" + "github.com/cerc-io/ipld-eth-utils/helper" +) + +type kvs struct { + k string + v int64 +} +type kvMap map[string]int64 + +var ( + cacheConfig = pgipfsethdb.CacheConfig{ + Name: "db", + Size: 3000000, // 3MB + ExpiryDuration: time.Hour, + } + dbConfig, _ = postgres.DefaultConfig.WithEnv() + trieConfig = trie.Config{Cache: 256} + + ctx = context.Background() +) + +var testdata1 = []kvs{ + {"barb", 0}, + {"bard", 1}, + {"bars", 2}, + {"bar", 3}, + {"fab", 4}, + {"food", 5}, + {"foos", 6}, + {"foo", 7}, +} + +var testdata2 = []kvs{ + {"aardvark", 8}, + {"bar", 9}, + {"barb", 10}, + {"bars", 11}, + {"fab", 12}, + {"foo", 13}, + {"foos", 14}, + {"food", 15}, + {"jars", 16}, +} + +func TestEmptyIterator(t *testing.T) { + trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) + iter := trie.NodeIterator(nil) + + seen := make(map[string]struct{}) + for iter.Next(true) { + seen[string(iter.Path())] = struct{}{} + } + if len(seen) != 0 { + t.Fatal("Unexpected trie node iterated") + } +} + +func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) { + all := kvMap{} + 
for _, val := range vals { + all[val.k] = val.v + acct := &types.StateAccount{ + Balance: big.NewInt(val.v), + CodeHash: test_helpers.NullCodeHash.Bytes(), + Root: test_helpers.EmptyContractRoot, + } + acct_rlp, err := rlp.EncodeToBytes(acct) + if err != nil { + return nil, err + } + tr.Update([]byte(val.k), acct_rlp) + } + return all, nil +} + +func commitTrie(t *testing.T, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash { + root, nodes, err := tr.Commit(false) + if err != nil { + t.Fatalf("Failed to commit trie %v", err) + } + if err = db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil { + t.Fatal(err) + } + if err = db.Commit(root, false, nil); err != nil { + t.Fatal(err) + } + return root +} + +// commit a LevelDB state trie, index to IPLD and return new trie +func indexTrie(t *testing.T, edb ethdb.Database, root common.Hash) *trie.Trie { + err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root) + if err != nil { + t.Fatal(err) + } + + pg_db, err := postgres.ConnectSQLX(ctx, dbConfig) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := TearDownDB(pg_db); err != nil { + t.Fatal(err) + } + }) + + ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t)) + sdb_db := state.NewDatabase(ipfs_db) + tr, err := trie.New(common.Hash{}, root, sdb_db.TrieDB(), ipld.MEthStateTrie) + if err != nil { + t.Fatal(err) + } + return tr +} + +func TestIterator(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + origtrie := geth_trie.NewEmpty(db) + vals := []kvs{ + {"one", 1}, + {"two", 2}, + {"three", 3}, + {"four", 4}, + {"five", 5}, + {"ten", 10}, + } + all, err := updateTrie(origtrie, vals) + if err != nil { + t.Fatal(err) + } + // commit and index data + root := commitTrie(t, db, origtrie) + tr := indexTrie(t, edb, root) + + found := make(map[string]int64) + it := trie.NewIterator(tr.NodeIterator(nil)) + for it.Next() { + var acct types.StateAccount + if err := 
rlp.DecodeBytes(it.Value, &acct); err != nil { + t.Fatal(err) + } + found[string(it.Key)] = acct.Balance.Int64() + } + + if len(found) != len(all) { + t.Errorf("number of iterated values do not match: want %d, found %d", len(all), len(found)) + } + for k, v := range all { + if found[k] != v { + t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v) + } + } +} + +func checkIteratorOrder(want []kvs, it *trie.Iterator) error { + for it.Next() { + if len(want) == 0 { + return fmt.Errorf("didn't expect any more values, got key %q", it.Key) + } + if !bytes.Equal(it.Key, []byte(want[0].k)) { + return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k) + } + want = want[1:] + } + if len(want) > 0 { + return fmt.Errorf("iterator ended early, want key %q", want[0]) + } + return nil +} + +func TestIteratorSeek(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) + if _, err := updateTrie(orig, testdata1); err != nil { + t.Fatal(err) + } + root := commitTrie(t, db, orig) + tr := indexTrie(t, edb, root) + + // Seek to the middle. + it := trie.NewIterator(tr.NodeIterator([]byte("fab"))) + if err := checkIteratorOrder(testdata1[4:], it); err != nil { + t.Fatal(err) + } + + // Seek to a non-existent key. + it = trie.NewIterator(tr.NodeIterator([]byte("barc"))) + if err := checkIteratorOrder(testdata1[1:], it); err != nil { + t.Fatal(err) + } + + // Seek beyond the end. 
+ it = trie.NewIterator(tr.NodeIterator([]byte("z"))) + if err := checkIteratorOrder(nil, it); err != nil { + t.Fatal(err) + } +} + +// returns a cache config with unique name (groupcache names are global) +func makeCacheConfig(t *testing.T) pgipfsethdb.CacheConfig { + return pgipfsethdb.CacheConfig{ + Name: t.Name(), + Size: 3000000, // 3MB + ExpiryDuration: time.Hour, + } +} diff --git a/bycid/trie/proof.go b/bycid/trie/proof.go index 0a4eea9..d8416e1 100644 --- a/bycid/trie/proof.go +++ b/bycid/trie/proof.go @@ -103,34 +103,34 @@ func (t *StateTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWrit return t.trie.Prove(key, fromLevel, proofDb) } -// VerifyProof checks merkle proofs. The given proof must contain the value for -// key in a trie with the given root hash. VerifyProof returns an error if the -// proof contains invalid trie nodes or the wrong value. -func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { - key = keybytesToHex(key) - wantHash := rootHash - for i := 0; ; i++ { - buf, _ := proofDb.Get(wantHash[:]) - if buf == nil { - return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) - } - n, err := decodeNode(wantHash[:], buf) - if err != nil { - return nil, fmt.Errorf("bad proof node %d: %v", i, err) - } - keyrest, cld := get(n, key, true) - switch cld := cld.(type) { - case nil: - // The trie doesn't contain the key. - return nil, nil - case hashNode: - key = keyrest - copy(wantHash[:], cld) - case valueNode: - return cld, nil - } - } -} +// // VerifyProof checks merkle proofs. The given proof must contain the value for +// // key in a trie with the given root hash. VerifyProof returns an error if the +// // proof contains invalid trie nodes or the wrong value. 
+// func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { +// key = keybytesToHex(key) +// wantHash := rootHash +// for i := 0; ; i++ { +// buf, _ := proofDb.Get(wantHash[:]) +// if buf == nil { +// return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) +// } +// n, err := decodeNode(wantHash[:], buf) +// if err != nil { +// return nil, fmt.Errorf("bad proof node %d: %v", i, err) +// } +// keyrest, cld := get(n, key, true) +// switch cld := cld.(type) { +// case nil: +// // The trie doesn't contain the key. +// return nil, nil +// case hashNode: +// key = keyrest +// copy(wantHash[:], cld) +// case valueNode: +// return cld, nil +// } +// } +// } // proofToPath converts a merkle proof to trie node path. The main purpose of // this function is recovering a node path from the merkle proof stream. All @@ -338,9 +338,9 @@ findFork: // unset removes all internal node references either the left most or right most. // It can meet these scenarios: // -// - The given path is existent in the trie, unset the associated nodes with the -// specific direction -// - The given path is non-existent in the trie +// - The given path is existent in the trie, unset the associated nodes with the +// specific direction +// - The given path is non-existent in the trie // - the fork point is a fullnode, the corresponding child pointed by path // is nil, return // - the fork point is a shortnode, the shortnode is included in the range, diff --git a/bycid/trie/trie.go b/bycid/trie/trie.go index 4ee793d..6adb08b 100644 --- a/bycid/trie/trie.go +++ b/bycid/trie/trie.go @@ -72,6 +72,12 @@ func NewEmpty(db *Database) *Trie { return tr } +// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at +// the key after the given start key. +func (t *Trie) NodeIterator(start []byte) NodeIterator { + return newNodeIterator(t, start) +} + // TryGet returns the value for key stored in the trie. 
// The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. @@ -133,6 +139,17 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix} } +// resolveHash loads rlp-encoded node blob from the underlying database +// with the provided node hash and path prefix. +func (t *Trie) resolveBlob(n hashNode, prefix []byte) ([]byte, error) { + cid := ipld.Keccak256ToCid(t.codec, n) + blob, _ := t.db.Node(cid.Bytes()) + if len(blob) != 0 { + return blob, nil + } + return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix} +} + // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { diff --git a/bycid/trie/util_test.go b/bycid/trie/util_test.go new file mode 100644 index 0000000..9a26162 --- /dev/null +++ b/bycid/trie/util_test.go @@ -0,0 +1,32 @@ +package trie_test + +import ( + "fmt" + "github.com/jmoiron/sqlx" +) + +// TearDownDB is used to tear down the watcher dbs after tests +func TearDownDB(db *sqlx.DB) error { + tx, err := db.Beginx() + if err != nil { + return err + } + statements := []string{ + `DELETE FROM nodes`, + `DELETE FROM ipld.blocks`, + `DELETE FROM eth.header_cids`, + `DELETE FROM eth.uncle_cids`, + `DELETE FROM eth.transaction_cids`, + `DELETE FROM eth.receipt_cids`, + `DELETE FROM eth.state_cids`, + `DELETE FROM eth.storage_cids`, + `DELETE FROM eth.log_cids`, + `DELETE FROM eth_meta.watched_addresses`, + } + for _, stm := range statements { + if _, err = tx.Exec(stm); err != nil { + return fmt.Errorf("error executing `%s`: %w", stm, err) + } + } + return tx.Commit() +} diff --git a/helper/hasher.go b/helper/hasher.go new file mode 100644 index 0000000..e914fc0 --- /dev/null +++ b/helper/hasher.go @@ -0,0 +1,31 @@ +package helper + +import ( + "hash" + + 
"github.com/ethereum/go-ethereum/common" + "golang.org/x/crypto/sha3" +) + +// testHasher (copied from go-ethereum/core/types/block_test.go) +// satisfies types.TrieHasher +type testHasher struct { + hasher hash.Hash +} + +func NewHasher() *testHasher { + return &testHasher{hasher: sha3.NewLegacyKeccak256()} +} + +func (h *testHasher) Reset() { + h.hasher.Reset() +} + +func (h *testHasher) Update(key, val []byte) { + h.hasher.Write(key) + h.hasher.Write(val) +} + +func (h *testHasher) Hash() common.Hash { + return common.BytesToHash(h.hasher.Sum(nil)) +} diff --git a/helper/statediff_helper.go b/helper/statediff_helper.go new file mode 100644 index 0000000..947199a --- /dev/null +++ b/helper/statediff_helper.go @@ -0,0 +1,81 @@ +package helper + +import ( + "context" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/statediff" + "github.com/ethereum/go-ethereum/statediff/indexer" + "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" + "github.com/ethereum/go-ethereum/statediff/indexer/node" +) + +var ( + // ChainDB = rawdb.NewMemoryDatabase() + ChainConfig = params.TestChainConfig + // BankFunds = new(big.Int).Mul(big.NewInt(1e4), big.NewInt(params.Ether)) // i.e. 
10,000eth + + mockTD = big.NewInt(1) + // ctx = context.Background() + // signer = types.NewLondonSigner(ChainConfig.ChainID) +) + +func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { + _, indexer, err := indexer.NewStateDiffIndexer( + context.Background(), + ChainConfig, + node.Info{}, + // node.Info{ + // GenesisBlock: Genesis.Hash().String(), + // NetworkID: "test_network", + // ID: "test_node", + // ClientName: "geth", + // ChainID: ChainConfig.ChainID.Uint64(), + // }, + dbConfig) + if err != nil { + return err + } + defer indexer.Close() // fixme: hangs when using PGX driver + + // generating statediff payload for block, and transform the data into Postgres + builder := statediff.NewBuilder(stateCache) + block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher()) + + // todo: use dummy block hashes to just produce trie structure for testing + args := statediff.Args{ + OldStateRoot: rootA, + NewStateRoot: rootB, + // BlockNumber: block.Number(), + // BlockHash: block.Hash(), + } + diff, err := builder.BuildStateDiffObject(args, statediff.Params{}) + if err != nil { + return err + } + tx, err := indexer.PushBlock(block, nil, mockTD) + if err != nil { + return err + } + // for _, node := range diff.Nodes { + // err := indexer.PushStateNode(tx, node, block.Hash().String()) + // if err != nil { + // return err + // } + // } + for _, ipld := range diff.IPLDs { + if err := indexer.PushIPLD(tx, ipld); err != nil { + return err + } + } + return tx.Submit(err) + + // if err = tx.Submit(err); err != nil { + // return err + // } + // return nil +}