From b6ef6d4e121d39d8a455913b6bf5357e7d911ef2 Mon Sep 17 00:00:00 2001 From: Roy Crihfield Date: Mon, 24 Apr 2023 18:04:13 +0800 Subject: [PATCH] Proof tests (#17) * Port proof tests from geth * Scale trie size, seed rand --- trie_by_cid/helper/statediff_helper.go | 23 +- trie_by_cid/trie/database_test.go | 2 +- trie_by_cid/trie/iterator_test.go | 98 +-- trie_by_cid/trie/proof.go | 377 +-------- trie_by_cid/trie/proof_test.go | 1078 ++++++++++++++++++++++++ trie_by_cid/trie/util_test.go | 139 +++ 6 files changed, 1228 insertions(+), 489 deletions(-) create mode 100644 trie_by_cid/trie/proof_test.go diff --git a/trie_by_cid/helper/statediff_helper.go b/trie_by_cid/helper/statediff_helper.go index 947199a..b03abf6 100644 --- a/trie_by_cid/helper/statediff_helper.go +++ b/trie_by_cid/helper/statediff_helper.go @@ -15,28 +15,14 @@ import ( ) var ( - // ChainDB = rawdb.NewMemoryDatabase() ChainConfig = params.TestChainConfig - // BankFunds = new(big.Int).Mul(big.NewInt(1e4), big.NewInt(params.Ether)) // i.e. 10,000eth mockTD = big.NewInt(1) - // ctx = context.Background() - // signer = types.NewLondonSigner(ChainConfig.ChainID) ) func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { _, indexer, err := indexer.NewStateDiffIndexer( - context.Background(), - ChainConfig, - node.Info{}, - // node.Info{ - // GenesisBlock: Genesis.Hash().String(), - // NetworkID: "test_network", - // ID: "test_node", - // ClientName: "geth", - // ChainID: ChainConfig.ChainID.Uint64(), - // }, - dbConfig) + context.Background(), ChainConfig, node.Info{}, dbConfig) if err != nil { return err } @@ -50,8 +36,6 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root args := statediff.Args{ OldStateRoot: rootA, NewStateRoot: rootB, - // BlockNumber: block.Number(), - // BlockHash: block.Hash(), } diff, err := builder.BuildStateDiffObject(args, statediff.Params{}) if err != nil { @@ -73,9 +57,4 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root } } return tx.Submit(err) - - // if err = tx.Submit(err); err != nil { - // return err - // } - // return nil } diff --git a/trie_by_cid/trie/database_test.go b/trie_by_cid/trie/database_test.go index 156ff18..fce88db 100644 --- a/trie_by_cid/trie/database_test.go +++ b/trie_by_cid/trie/database_test.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package trie +package trie_test import ( "testing" diff --git a/trie_by_cid/trie/iterator_test.go b/trie_by_cid/trie/iterator_test.go index 1078b38..4df2ec4 100644 --- a/trie_by_cid/trie/iterator_test.go +++ b/trie_by_cid/trie/iterator_test.go @@ -20,33 +20,17 @@ import ( "bytes" "context" "fmt" - "math/big" "testing" "time" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - geth_state "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" - "github.com/ethereum/go-ethereum/statediff/indexer/ipld" - "github.com/ethereum/go-ethereum/statediff/test_helpers" geth_trie "github.com/ethereum/go-ethereum/trie" pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/state" "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) -type kvs struct { - k string - v int64 -} -type kvMap map[string]int64 - var ( cacheConfig = pgipfsethdb.CacheConfig{ Name: "db", @@ -70,18 +54,6 @@ var testdata1 = []kvs{ {"foo", 7}, } -var testdata2 = []kvs{ - {"aardvark", 8}, - {"bar", 9}, - {"barb", 10}, - {"bars", 11}, - {"fab", 12}, - {"foo", 13}, - {"foos", 14}, - {"food", 15}, - {"jars", 16}, -} - func TestEmptyIterator(t *testing.T) { trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) iter := trie.NodeIterator(nil) @@ -95,62 +67,6 @@ func TestEmptyIterator(t *testing.T) { } } -func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) { - all := kvMap{} - for _, val := range vals { - all[val.k] = val.v - acct := &types.StateAccount{ - Balance: big.NewInt(val.v), - CodeHash: test_helpers.NullCodeHash.Bytes(), - Root: test_helpers.EmptyContractRoot, - } - acct_rlp, err := rlp.EncodeToBytes(acct) - if err != nil { - return nil, err - } - tr.Update([]byte(val.k), acct_rlp) - } - return all, nil -} - -func commitTrie(t *testing.T, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash { - root, nodes := tr.Commit(false) - if err := db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil { - t.Fatal(err) - } - if err := db.Commit(root, false); err != nil { - t.Fatal(err) - } - return root -} - -// commit a LevelDB state trie, index to IPLD and return new trie -func indexTrie(t *testing.T, edb ethdb.Database, root common.Hash) *trie.Trie { - dbConfig.Driver = postgres.PGX - err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root) - if err != nil { - t.Fatal(err) - } - - pg_db, err := postgres.ConnectSQLX(ctx, dbConfig) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - if err := TearDownDB(pg_db); err != nil { - t.Fatal(err) - } - }) - - ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t)) - sdb_db := state.NewDatabase(ipfs_db) - tr, err := trie.New(common.Hash{}, root, sdb_db.TrieDB(), ipld.MEthStateTrie) - if err != nil { - t.Fatal(err) - } - return tr -} - func TestIterator(t *testing.T) { edb := rawdb.NewMemoryDatabase() db := geth_trie.NewDatabase(edb) @@ -174,19 +90,15 @@ func TestIterator(t *testing.T) { found := make(map[string]int64) it := trie.NewIterator(tr.NodeIterator(nil)) for it.Next() { - var acct types.StateAccount - if err := rlp.DecodeBytes(it.Value, &acct); err != nil { - t.Fatal(err) - } - found[string(it.Key)] = acct.Balance.Int64() + found[string(it.Key)] = unpackValue(it.Value) } if len(found) != len(all) { t.Errorf("number of iterated values do not match: want %d, found %d", len(all), len(found)) } - for k, v := range all { - if found[k] != v { - t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v) + for k, kv := range all { + if found[k] != kv.v { + t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], kv.v) } } } @@ -237,7 +149,7 @@ func TestIteratorSeek(t *testing.T) { } // returns a cache config with unique name (groupcache names are global) -func makeCacheConfig(t *testing.T) pgipfsethdb.CacheConfig { +func makeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig { return pgipfsethdb.CacheConfig{ Name: t.Name(), Size: 3000000, // 3MB diff --git a/trie_by_cid/trie/proof.go b/trie_by_cid/trie/proof.go index d8416e1..e48eda5 100644 --- a/trie_by_cid/trie/proof.go +++ b/trie_by_cid/trie/proof.go @@ -18,14 +18,16 @@ package trie import ( "bytes" - "errors" "fmt" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/trie" ) +var VerifyProof = trie.VerifyProof +var VerifyRangeProof = trie.VerifyRangeProof + // Prove constructs a merkle proof for key. The result contains all encoded nodes // on the path to the value at key. The value itself is also included in the last // node and can be retrieved by verifying the proof. @@ -102,374 +104,3 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e func (t *StateTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { return t.trie.Prove(key, fromLevel, proofDb) } - -// // VerifyProof checks merkle proofs. The given proof must contain the value for -// // key in a trie with the given root hash. VerifyProof returns an error if the -// // proof contains invalid trie nodes or the wrong value. -// func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { -// key = keybytesToHex(key) -// wantHash := rootHash -// for i := 0; ; i++ { -// buf, _ := proofDb.Get(wantHash[:]) -// if buf == nil { -// return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) -// } -// n, err := decodeNode(wantHash[:], buf) -// if err != nil { -// return nil, fmt.Errorf("bad proof node %d: %v", i, err) -// } -// keyrest, cld := get(n, key, true) -// switch cld := cld.(type) { -// case nil: -// // The trie doesn't contain the key. -// return nil, nil -// case hashNode: -// key = keyrest -// copy(wantHash[:], cld) -// case valueNode: -// return cld, nil -// } -// } -// } - -// proofToPath converts a merkle proof to trie node path. The main purpose of -// this function is recovering a node path from the merkle proof stream. All -// necessary nodes will be resolved and leave the remaining as hashnode. -// -// The given edge proof is allowed to be an existent or non-existent proof. -func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader, allowNonExistent bool) (node, []byte, error) { - // resolveNode retrieves and resolves trie node from merkle proof stream - resolveNode := func(hash common.Hash) (node, error) { - buf, _ := proofDb.Get(hash[:]) - if buf == nil { - return nil, fmt.Errorf("proof node (hash %064x) missing", hash) - } - n, err := decodeNode(hash[:], buf) - if err != nil { - return nil, fmt.Errorf("bad proof node %v", err) - } - return n, err - } - // If the root node is empty, resolve it first. - // Root node must be included in the proof. - if root == nil { - n, err := resolveNode(rootHash) - if err != nil { - return nil, nil, err - } - root = n - } - var ( - err error - child, parent node - keyrest []byte - valnode []byte - ) - key, parent = keybytesToHex(key), root - for { - keyrest, child = get(parent, key, false) - switch cld := child.(type) { - case nil: - // The trie doesn't contain the key. It's possible - // the proof is a non-existing proof, but at least - // we can prove all resolved nodes are correct, it's - // enough for us to prove range. - if allowNonExistent { - return root, nil, nil - } - return nil, nil, errors.New("the node is not contained in trie") - case *shortNode: - key, parent = keyrest, child // Already resolved - continue - case *fullNode: - key, parent = keyrest, child // Already resolved - continue - case hashNode: - child, err = resolveNode(common.BytesToHash(cld)) - if err != nil { - return nil, nil, err - } - case valueNode: - valnode = cld - } - // Link the parent and child. - switch pnode := parent.(type) { - case *shortNode: - pnode.Val = child - case *fullNode: - pnode.Children[key[0]] = child - default: - panic(fmt.Sprintf("%T: invalid node: %v", pnode, pnode)) - } - if len(valnode) > 0 { - return root, valnode, nil // The whole path is resolved - } - key, parent = keyrest, child - } -} - -// unsetInternal removes all internal node references(hashnode, embedded node). -// It should be called after a trie is constructed with two edge paths. Also -// the given boundary keys must be the one used to construct the edge paths. -// -// It's the key step for range proof. All visited nodes should be marked dirty -// since the node content might be modified. Besides it can happen that some -// fullnodes only have one child which is disallowed. But if the proof is valid, -// the missing children will be filled, otherwise it will be thrown anyway. -// -// Note we have the assumption here the given boundary keys are different -// and right is larger than left. -func unsetInternal(n node, left []byte, right []byte) (bool, error) { - left, right = keybytesToHex(left), keybytesToHex(right) - - // Step down to the fork point. There are two scenarios can happen: - // - the fork point is a shortnode: either the key of left proof or - // right proof doesn't match with shortnode's key. - // - the fork point is a fullnode: both two edge proofs are allowed - // to point to a non-existent key. - var ( - pos = 0 - parent node - - // fork indicator, 0 means no fork, -1 means proof is less, 1 means proof is greater - shortForkLeft, shortForkRight int - ) -findFork: - for { - switch rn := (n).(type) { - case *shortNode: - rn.flags = nodeFlag{dirty: true} - - // If either the key of left proof or right proof doesn't match with - // shortnode, stop here and the forkpoint is the shortnode. - if len(left)-pos < len(rn.Key) { - shortForkLeft = bytes.Compare(left[pos:], rn.Key) - } else { - shortForkLeft = bytes.Compare(left[pos:pos+len(rn.Key)], rn.Key) - } - if len(right)-pos < len(rn.Key) { - shortForkRight = bytes.Compare(right[pos:], rn.Key) - } else { - shortForkRight = bytes.Compare(right[pos:pos+len(rn.Key)], rn.Key) - } - if shortForkLeft != 0 || shortForkRight != 0 { - break findFork - } - parent = n - n, pos = rn.Val, pos+len(rn.Key) - case *fullNode: - rn.flags = nodeFlag{dirty: true} - - // If either the node pointed by left proof or right proof is nil, - // stop here and the forkpoint is the fullnode. - leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]] - if leftnode == nil || rightnode == nil || leftnode != rightnode { - break findFork - } - parent = n - n, pos = rn.Children[left[pos]], pos+1 - default: - panic(fmt.Sprintf("%T: invalid node: %v", n, n)) - } - } - switch rn := n.(type) { - case *shortNode: - // There can have these five scenarios: - // - both proofs are less than the trie path => no valid range - // - both proofs are greater than the trie path => no valid range - // - left proof is less and right proof is greater => valid range, unset the shortnode entirely - // - left proof points to the shortnode, but right proof is greater - // - right proof points to the shortnode, but left proof is less - if shortForkLeft == -1 && shortForkRight == -1 { - return false, errors.New("empty range") - } - if shortForkLeft == 1 && shortForkRight == 1 { - return false, errors.New("empty range") - } - if shortForkLeft != 0 && shortForkRight != 0 { - // The fork point is root node, unset the entire trie - if parent == nil { - return true, nil - } - parent.(*fullNode).Children[left[pos-1]] = nil - return false, nil - } - // Only one proof points to non-existent key. - if shortForkRight != 0 { - if _, ok := rn.Val.(valueNode); ok { - // The fork point is root node, unset the entire trie - if parent == nil { - return true, nil - } - parent.(*fullNode).Children[left[pos-1]] = nil - return false, nil - } - return false, unset(rn, rn.Val, left[pos:], len(rn.Key), false) - } - if shortForkLeft != 0 { - if _, ok := rn.Val.(valueNode); ok { - // The fork point is root node, unset the entire trie - if parent == nil { - return true, nil - } - parent.(*fullNode).Children[right[pos-1]] = nil - return false, nil - } - return false, unset(rn, rn.Val, right[pos:], len(rn.Key), true) - } - return false, nil - case *fullNode: - // unset all internal nodes in the forkpoint - for i := left[pos] + 1; i < right[pos]; i++ { - rn.Children[i] = nil - } - if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil { - return false, err - } - if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil { - return false, err - } - return false, nil - default: - panic(fmt.Sprintf("%T: invalid node: %v", n, n)) - } -} - -// unset removes all internal node references either the left most or right most. -// It can meet these scenarios: -// -// - The given path is existent in the trie, unset the associated nodes with the -// specific direction -// - The given path is non-existent in the trie -// - the fork point is a fullnode, the corresponding child pointed by path -// is nil, return -// - the fork point is a shortnode, the shortnode is included in the range, -// keep the entire branch and return. -// - the fork point is a shortnode, the shortnode is excluded in the range, -// unset the entire branch. -func unset(parent node, child node, key []byte, pos int, removeLeft bool) error { - switch cld := child.(type) { - case *fullNode: - if removeLeft { - for i := 0; i < int(key[pos]); i++ { - cld.Children[i] = nil - } - cld.flags = nodeFlag{dirty: true} - } else { - for i := key[pos] + 1; i < 16; i++ { - cld.Children[i] = nil - } - cld.flags = nodeFlag{dirty: true} - } - return unset(cld, cld.Children[key[pos]], key, pos+1, removeLeft) - case *shortNode: - if len(key[pos:]) < len(cld.Key) || !bytes.Equal(cld.Key, key[pos:pos+len(cld.Key)]) { - // Find the fork point, it's an non-existent branch. - if removeLeft { - if bytes.Compare(cld.Key, key[pos:]) < 0 { - // The key of fork shortnode is less than the path - // (it belongs to the range), unset the entrie - // branch. The parent must be a fullnode. - fn := parent.(*fullNode) - fn.Children[key[pos-1]] = nil - } - //else { - // The key of fork shortnode is greater than the - // path(it doesn't belong to the range), keep - // it with the cached hash available. - //} - } else { - if bytes.Compare(cld.Key, key[pos:]) > 0 { - // The key of fork shortnode is greater than the - // path(it belongs to the range), unset the entrie - // branch. The parent must be a fullnode. - fn := parent.(*fullNode) - fn.Children[key[pos-1]] = nil - } - //else { - // The key of fork shortnode is less than the - // path(it doesn't belong to the range), keep - // it with the cached hash available. - //} - } - return nil - } - if _, ok := cld.Val.(valueNode); ok { - fn := parent.(*fullNode) - fn.Children[key[pos-1]] = nil - return nil - } - cld.flags = nodeFlag{dirty: true} - return unset(cld, cld.Val, key, pos+len(cld.Key), removeLeft) - case nil: - // If the node is nil, then it's a child of the fork point - // fullnode(it's a non-existent branch). - return nil - default: - panic("it shouldn't happen") // hashNode, valueNode - } -} - -// hasRightElement returns the indicator whether there exists more elements -// on the right side of the given path. The given path can point to an existent -// key or a non-existent one. This function has the assumption that the whole -// path should already be resolved. -func hasRightElement(node node, key []byte) bool { - pos, key := 0, keybytesToHex(key) - for node != nil { - switch rn := node.(type) { - case *fullNode: - for i := key[pos] + 1; i < 16; i++ { - if rn.Children[i] != nil { - return true - } - } - node, pos = rn.Children[key[pos]], pos+1 - case *shortNode: - if len(key)-pos < len(rn.Key) || !bytes.Equal(rn.Key, key[pos:pos+len(rn.Key)]) { - return bytes.Compare(rn.Key, key[pos:]) > 0 - } - node, pos = rn.Val, pos+len(rn.Key) - case valueNode: - return false // We have resolved the whole path - default: - panic(fmt.Sprintf("%T: invalid node: %v", node, node)) // hashnode - } - } - return false -} - -// get returns the child of the given node. Return nil if the -// node with specified key doesn't exist at all. -// -// There is an additional flag `skipResolved`. If it's set then -// all resolved nodes won't be returned. -func get(tn node, key []byte, skipResolved bool) ([]byte, node) { - for { - switch n := tn.(type) { - case *shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { - return nil, nil - } - tn = n.Val - key = key[len(n.Key):] - if !skipResolved { - return key, tn - } - case *fullNode: - tn = n.Children[key[0]] - key = key[1:] - if !skipResolved { - return key, tn - } - case hashNode: - return key, n - case nil: - return key, nil - case valueNode: - return nil, n - default: - panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) - } - } -} diff --git a/trie_by_cid/trie/proof_test.go b/trie_by_cid/trie/proof_test.go new file mode 100644 index 0000000..3fd508a --- /dev/null +++ b/trie_by_cid/trie/proof_test.go @@ -0,0 +1,1078 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie_test + +import ( + "bytes" + mrand "math/rand" + "sort" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + geth_trie "github.com/ethereum/go-ethereum/trie" + + . "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" +) + +// tweakable trie size +var scaleFactor = 512 + +func init() { + mrand.Seed(time.Now().UnixNano()) +} + +// makeProvers creates Merkle trie provers based on different implementations to +// test all variations. +func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database { + var provers []func(key []byte) *memorydb.Database + + // Create a direct trie based Merkle prover + provers = append(provers, func(key []byte) *memorydb.Database { + proof := memorydb.New() + trie.Prove(key, 0, proof) + return proof + }) + // Create a leaf iterator based Merkle prover + provers = append(provers, func(key []byte) *memorydb.Database { + proof := memorydb.New() + if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) { + for _, p := range it.Prove() { + proof.Put(crypto.Keccak256(p), p) + } + } + return proof + }) + return provers +} + +func TestProof(t *testing.T) { + trie, vals := randomTrie(t, scaleFactor) + root := trie.Hash() + for i, prover := range makeProvers(trie) { + for _, kv := range vals { + proof := prover(kv.k) + if proof == nil { + t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) + } + val, err := geth_trie.VerifyProof(root, kv.k, proof) + if err != nil { + t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) + } + if kv.v != unpackValue(val) { + t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v) + } + } + } +} + +func TestOneElementProof(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + orig.Update([]byte("k"), packValue(42)) + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + + for i, prover := range makeProvers(trie) { + proof := prover([]byte("k")) + if proof == nil { + t.Fatalf("prover %d: nil proof", i) + } + if proof.Len() != 1 { + t.Errorf("prover %d: proof should have one element", i) + } + val, err := geth_trie.VerifyProof(trie.Hash(), []byte("k"), proof) + if err != nil { + t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) + } + if 42 != unpackValue(val) { + t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val) + } + } +} + +func TestBadProof(t *testing.T) { + trie, vals := randomTrie(t, 2*scaleFactor) + root := trie.Hash() + for i, prover := range makeProvers(trie) { + for _, kv := range vals { + proof := prover(kv.k) + if proof == nil { + t.Fatalf("prover %d: nil proof", i) + } + it := proof.NewIterator(nil, nil) + for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ { + it.Next() + } + key := it.Key() + val, _ := proof.Get(key) + proof.Delete(key) + it.Release() + + mutateByte(val) + proof.Put(crypto.Keccak256(val), val) + + if _, err := VerifyProof(root, kv.k, proof); err == nil { + t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k) + } + } + } +} + +// Tests that missing keys can also be proven. The test explicitly uses a single +// entry trie and checks for missing keys both before and after the single entry. +func TestMissingKeyProof(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + orig.Update([]byte("k"), packValue(42)) + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + + for i, key := range []string{"a", "j", "l", "z"} { + proof := memorydb.New() + trie.Prove([]byte(key), 0, proof) + + if proof.Len() != 1 { + t.Errorf("test %d: proof should have one element", i) + } + val, err := VerifyProof(trie.Hash(), []byte(key), proof) + if err != nil { + t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) + } + if val != nil { + t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val) + } + } +} + +type entry struct { + k, v []byte +} + +func packEntry(kv *kv) *entry { + return &entry{kv.k, packValue(kv.v)} +} + +type entrySlice []*entry + +func (p entrySlice) Len() int { return len(p) } +func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 } +func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// TestRangeProof tests normal range proof with both edge proofs +// as the existent proof. The test cases are generated randomly. +func TestRangeProof(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + for i := 0; i < 500; i++ { + start := mrand.Intn(len(entries)) + end := mrand.Intn(len(entries)-start) + start + 1 + + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + if err != nil { + t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) + } + } +} + +// TestRangeProof tests normal range proof with two non-existent proofs. +// The test cases are generated randomly. +func TestRangeProofWithNonExistentProof(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + for i := 0; i < 500; i++ { + start := mrand.Intn(len(entries)) + end := mrand.Intn(len(entries)-start) + start + 1 + proof := memorydb.New() + + // Short circuit if the decreased key is same with the previous key + first := decreaseKey(common.CopyBytes(entries[start].k)) + if start != 0 && bytes.Equal(first, entries[start-1].k) { + continue + } + // Short circuit if the decreased key is underflow + if bytes.Compare(first, entries[start].k) > 0 { + continue + } + // Short circuit if the increased key is same with the next key + last := increaseKey(common.CopyBytes(entries[end-1].k)) + if end != len(entries) && bytes.Equal(last, entries[end].k) { + continue + } + // Short circuit if the increased key is overflow + if bytes.Compare(last, entries[end-1].k) < 0 { + continue + } + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) + if err != nil { + t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) + } + } + // Special case, two edge proofs for two edge key. + proof := memorydb.New() + first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() + last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var k [][]byte + var v [][]byte + for i := 0; i < len(entries); i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof) + if err != nil { + t.Fatal("Failed to verify whole rang with non-existent edges") + } +} + +// TestRangeProofWithInvalidNonExistentProof tests such scenarios: +// - There exists a gap between the first element and the left edge proof +// - There exists a gap between the last element and the right edge proof +func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + // Case 1 + start, end := 100, 200 + first := decreaseKey(common.CopyBytes(entries[start].k)) + + proof := memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + start = 105 // Gap created + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := start; i < end; i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) + if err == nil { + t.Fatalf("Expected to detect the error, got nil") + } + + // Case 2 + start, end = 100, 200 + last := increaseKey(common.CopyBytes(entries[end-1].k)) + proof = memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + end = 195 // Capped slice + k = make([][]byte, 0) + v = make([][]byte, 0) + for i := start; i < end; i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof) + if err == nil { + t.Fatalf("Expected to detect the error, got nil") + } +} + +// TestOneElementRangeProof tests the proof with only one +// element. The first edge proof can be existent one or +// non-existent one. +func TestOneElementRangeProof(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + // One element with existent edge proof, both edge proofs + // point to the SAME key. + start := 1000 + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // One element with left non-existent edge proof + start = 1000 + first := decreaseKey(common.CopyBytes(entries[start].k)) + proof = memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // One element with right non-existent edge proof + start = 1000 + last := increaseKey(common.CopyBytes(entries[start].k)) + proof = memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // One element with two non-existent edge proofs + start = 1000 + first, last = decreaseKey(common.CopyBytes(entries[start].k)), increaseKey(common.CopyBytes(entries[start].k)) + proof = memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Test the mini trie with only a single element. + t.Run("single element", func(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + entry := &entry{randBytes(32), packValue(mrand.Int63())} + orig.Update(entry.k, entry.v) + root := commitTrie(t, db, orig) + tinyTrie := indexTrie(t, edb, root) + + first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() + last = entry.k + proof = memorydb.New() + if err := tinyTrie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := tinyTrie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + }) +} + +// TestAllElementsProof tests the range proof with all elements. +// The edge proofs can be nil. +func TestAllElementsProof(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + var k [][]byte + var v [][]byte + for i := 0; i < len(entries); i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // With edge proofs, it should still work. + proof := memorydb.New() + if err := trie.Prove(entries[0].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Even with non-existent edge proofs, it should still work. + proof = memorydb.New() + first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() + last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } +} + +func positionCases(size int) []int { + var cases []int + for _, pos := range []int{0, 1, 50, 100, 1000, 2000, size - 1} { + if pos >= size { + break + } + cases = append(cases, pos) + } + return cases +} + +// TestSingleSideRangeProof tests the range starts from zero. +func TestSingleSideRangeProof(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + var entries entrySlice + for i := 0; i < 8*scaleFactor; i++ { + value := &entry{randBytes(32), packValue(mrand.Int63())} + orig.Update(value.k, value.v) + entries = append(entries, value) + } + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + sort.Sort(entries) + + for _, pos := range positionCases(len(entries)) { + proof := memorydb.New() + if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := 0; i <= pos; i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + } +} + +// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. +func TestReverseSingleSideRangeProof(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + var entries entrySlice + for i := 0; i < 8*scaleFactor; i++ { + value := &entry{randBytes(32), packValue(mrand.Int63())} + orig.Update(value.k, value.v) + entries = append(entries, value) + } + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + sort.Sort(entries) + + for _, pos := range positionCases(len(entries)) { + proof := memorydb.New() + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + if err := trie.Prove(last.Bytes(), 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := pos; i < len(entries); i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + } +} + +// TestBadRangeProof tests a few cases which the proof is wrong. +// The prover is expected to detect the error. +func TestBadRangeProof(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + for i := 0; i < 500; i++ { + start := mrand.Intn(len(entries)) + end := mrand.Intn(len(entries)-start) + start + 1 + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + var first, last = keys[0], keys[len(keys)-1] + testcase := mrand.Intn(6) + var index int + switch testcase { + case 0: + // Modified key + index = mrand.Intn(end - start) + keys[index] = randBytes(32) // In theory it can't be same + case 1: + // Modified val + index = mrand.Intn(end - start) + vals[index] = randBytes(20) // In theory it can't be same + case 2: + // Gapped entry slice + index = mrand.Intn(end - start) + if (index == 0 && start < 100) || (index == end-start-1 && end <= 100) { + continue + } + keys = append(keys[:index], keys[index+1:]...) + vals = append(vals[:index], vals[index+1:]...) + case 3: + // Out of order + index1 := mrand.Intn(end - start) + index2 := mrand.Intn(end - start) + if index1 == index2 { + continue + } + keys[index1], keys[index2] = keys[index2], keys[index1] + vals[index1], vals[index2] = vals[index2], vals[index1] + case 4: + // Set random key to nil, do nothing + index = mrand.Intn(end - start) + keys[index] = nil + case 5: + // Set random value to nil, deletion + index = mrand.Intn(end - start) + vals[index] = nil + } + _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) + if err == nil { + t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) + } + } +} + +// TestGappedRangeProof focuses on the small trie with embedded nodes. +// If the gapped node is embedded in the trie, it should be detected too. +func TestGappedRangeProof(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + var entries entrySlice + for i := byte(0); i < 10; i++ { + value := &entry{common.LeftPadBytes([]byte{i}, 32), packValue(int64(i))} + orig.Update(value.k, value.v) + entries = append(entries, value) + } + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + first, last := 2, 8 + proof := memorydb.New() + if err := trie.Prove(entries[first].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[last-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := first; i < last; i++ { + if i == (first+last)/2 { + continue + } + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + if err == nil { + t.Fatal("expect error, got nil") + } +} + +// TestSameSideProofs tests the element is not in the range covered by proofs +func TestSameSideProofs(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + pos := 1000 + first := decreaseKey(common.CopyBytes(entries[pos].k)) + first = decreaseKey(first) + last := decreaseKey(common.CopyBytes(entries[pos].k)) + + proof := memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + if err == nil { + t.Fatalf("Expected error, got nil") + } + + first = increaseKey(common.CopyBytes(entries[pos].k)) + last = increaseKey(common.CopyBytes(entries[pos].k)) + last = increaseKey(last) + + proof = memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + if err == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestHasRightElement(t *testing.T) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + var entries entrySlice + for i := 0; i < 8*scaleFactor; i++ { + value := &entry{randBytes(32), packValue(int64(i))} + orig.Update(value.k, value.v) + entries = append(entries, value) + } + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + sort.Sort(entries) + + var cases = []struct { + start int + end int + hasMore bool + }{ + {-1, 1, true}, // single element with non-existent left proof + {0, 1, true}, // single element with existent left proof + {0, 10, true}, + {50, 100, true}, + {50, len(entries), false}, // No more element expected + {len(entries) - 1, len(entries), false}, // Single last element with two existent proofs(point to same key) + {len(entries) - 1, -1, false}, // Single last element with non-existent right proof + {0, len(entries), false}, // The whole set with existent left proof + {-1, len(entries), false}, // The whole set with non-existent left proof + {-1, -1, false}, // The whole set with non-existent left/right proof + } + for _, c := range cases { + var ( + firstKey []byte + lastKey []byte + start = c.start + end = c.end + proof = memorydb.New() + ) + if c.start == -1 { + firstKey, start = common.Hash{}.Bytes(), 0 + if err := trie.Prove(firstKey, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + } else { + firstKey = entries[c.start].k + if err := trie.Prove(entries[c.start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + } + if c.end == -1 { + lastKey, end = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes(), len(entries) + if err := trie.Prove(lastKey, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + } else { + lastKey = entries[c.end-1].k + if err := trie.Prove(entries[c.end-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := start; i < end; i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if hasMore != c.hasMore { + t.Fatalf("Wrong hasMore indicator, want %t, got %t", c.hasMore, hasMore) + } + } +} + +// TestEmptyRangeProof tests the range proof with "no" element. +// The first edge proof must be a non-existent proof. +func TestEmptyRangeProof(t *testing.T) { + trie, vals := randomTrie(t, 8*scaleFactor) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + var cases = []struct { + pos int + err bool + }{ + {len(entries) - 1, false}, + {500, true}, + } + for _, c := range cases { + proof := memorydb.New() + first := increaseKey(common.CopyBytes(entries[c.pos].k)) + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) + if c.err && err == nil { + t.Fatalf("Expected error, got nil") + } + if !c.err && err != nil { + t.Fatalf("Expected no error, got %v", err) + } + } +} + +// TestEmptyValueRangeProof tests normal range proof with both edge proofs +// as the existent proof, but with an extra empty value included, which is a +// noop technically, but practically should be rejected. +func TestEmptyValueRangeProof(t *testing.T) { + trie, values := randomTrie(t, scaleFactor) + var entries entrySlice + for _, kv := range values { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + // Create a new entry with a slightly modified key + mid := len(entries) / 2 + key := common.CopyBytes(entries[mid-1].k) + for n := len(key) - 1; n >= 0; n-- { + if key[n] < 0xff { + key[n]++ + break + } + } + noop := &entry{key, []byte{}} + entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) + + start, end := 1, len(entries)-1 + + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + if err == nil { + t.Fatalf("Expected failure on noop entry") + } +} + +// TestAllElementsEmptyValueRangeProof tests the range proof with all elements, +// but with an extra empty value included, which is a noop technically, but +// practically should be rejected. +func TestAllElementsEmptyValueRangeProof(t *testing.T) { + trie, values := randomTrie(t, scaleFactor) + var entries entrySlice + for _, kv := range values { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + // Create a new entry with a slightly modified key + mid := len(entries) / 2 + key := common.CopyBytes(entries[mid-1].k) + for n := len(key) - 1; n >= 0; n-- { + if key[n] < 0xff { + key[n]++ + break + } + } + noop := &entry{key, nil} + entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) + + var keys [][]byte + var vals [][]byte + for i := 0; i < len(entries); i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), nil, nil, keys, vals, nil) + if err == nil { + t.Fatalf("Expected failure on noop entry") + } +} + +// mutateByte changes one byte in b. +func mutateByte(b []byte) { + for r := mrand.Intn(len(b)); ; { + new := byte(mrand.Intn(255)) + if new != b[r] { + b[r] = new + break + } + } +} + +func increaseKey(key []byte) []byte { + for i := len(key) - 1; i >= 0; i-- { + key[i]++ + if key[i] != 0x0 { + break + } + } + return key +} + +func decreaseKey(key []byte) []byte { + for i := len(key) - 1; i >= 0; i-- { + key[i]-- + if key[i] != 0xff { + break + } + } + return key +} + +func BenchmarkProve(b *testing.B) { + trie, vals := randomTrie(b, 100) + var keys []string + for k := range vals { + keys = append(keys, k) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + kv := vals[keys[i%len(keys)]] + proofs := memorydb.New() + if trie.Prove(kv.k, 0, proofs); proofs.Len() == 0 { + b.Fatalf("zero length proof for %x", kv.k) + } + } +} + +func BenchmarkVerifyProof(b *testing.B) { + trie, vals := randomTrie(b, 100) + root := trie.Hash() + var keys []string + var proofs []*memorydb.Database + for k := range vals { + keys = append(keys, k) + proof := memorydb.New() + trie.Prove([]byte(k), 0, proof) + proofs = append(proofs, proof) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + im := i % len(keys) + if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { + b.Fatalf("key %x: %v", keys[im], err) + } + } +} + +func BenchmarkVerifyRangeProof10(b *testing.B) { benchmarkVerifyRangeProof(b, 10) } +func BenchmarkVerifyRangeProof100(b *testing.B) { benchmarkVerifyRangeProof(b, 100) } +func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b, 1000) } +func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) } + +func benchmarkVerifyRangeProof(b *testing.B, size int) { + trie, vals := randomTrie(b, 8192) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + start := 2 + end := start + size + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + b.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { + b.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var values [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + values = append(values, entries[i].v) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) + if err != nil { + b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) + } + } +} + +func BenchmarkVerifyRangeNoProof10(b *testing.B) { benchmarkVerifyRangeNoProof(b, 100) } +func BenchmarkVerifyRangeNoProof500(b *testing.B) { benchmarkVerifyRangeNoProof(b, 500) } +func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) } + +func benchmarkVerifyRangeNoProof(b *testing.B, size int) { + trie, vals := randomTrie(b, size) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, packEntry(kv)) + } + sort.Sort(entries) + + var keys [][]byte + var values [][]byte + for _, entry := range entries { + keys = append(keys, entry.k) + values = append(values, entry.v) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil) + if err != nil { + b.Fatalf("Expected no error, got %v", err) + } + } +} + +func TestRangeProofKeysWithSharedPrefix(t *testing.T) { + keys := [][]byte{ + common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"), + common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"), + } + vals := [][]byte{ + packValue(2), + packValue(3), + } + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig := geth_trie.NewEmpty(db) + for i, key := range keys { + orig.Update(key, vals[i]) + } + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + + proof := memorydb.New() + start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") + end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + if err := trie.Prove(start, 0, proof); err != nil { + t.Fatalf("failed to prove start: %v", err) + } + if err := trie.Prove(end, 0, proof); err != nil { + t.Fatalf("failed to prove end: %v", err) + } + + more, err := VerifyRangeProof(root, start, end, keys, vals, proof) + if err != nil { + t.Fatalf("failed to verify range proof: %v", err) + } + if more != false { + t.Error("expected more to be false") + } +} diff --git a/trie_by_cid/trie/util_test.go b/trie_by_cid/trie/util_test.go index 8ed1fe9..4756314 100644 --- a/trie_by_cid/trie/util_test.go +++ b/trie_by_cid/trie/util_test.go @@ -2,10 +2,149 @@ package trie_test import ( "fmt" + "math/big" + "math/rand" + "testing" "github.com/jmoiron/sqlx" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + geth_state "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" + geth_trie "github.com/ethereum/go-ethereum/trie" + + pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" + "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper" + "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/state" + "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" + "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" + "github.com/ethereum/go-ethereum/statediff/indexer/ipld" + "github.com/ethereum/go-ethereum/statediff/test_helpers" ) +type kv struct { + k []byte + v int64 +} + +type kvMap map[string]*kv + +type kvs struct { + k string + v int64 +} + +func packValue(val int64) []byte { + acct := &types.StateAccount{ + Balance: big.NewInt(val), + CodeHash: test_helpers.NullCodeHash.Bytes(), + Root: test_helpers.EmptyContractRoot, + } + acct_rlp, err := rlp.EncodeToBytes(acct) + if err != nil { + panic(err) + } + return acct_rlp +} + +func unpackValue(val []byte) int64 { + var acct types.StateAccount + if err := rlp.DecodeBytes(val, &acct); err != nil { + panic(err) + } + return acct.Balance.Int64() +} + +func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) { + all := kvMap{} + for _, val := range vals { + all[string(val.k)] = &kv{[]byte(val.k), val.v} + tr.Update([]byte(val.k), packValue(val.v)) + } + return all, nil +} + +func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash { + t.Helper() + root, nodes := tr.Commit(false) + if err := db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil { + t.Fatal(err) + } + if err := db.Commit(root, false); err != nil { + t.Fatal(err) + } + return root +} + +// commit a LevelDB state trie, index to IPLD and return new trie +func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie { + t.Helper() + dbConfig.Driver = postgres.PGX + err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root) + if err != nil { + t.Fatal(err) + } + + pg_db, err := postgres.ConnectSQLX(ctx, dbConfig) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := TearDownDB(pg_db); err != nil { + t.Fatal(err) + } + }) + + ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t)) + sdb_db := state.NewDatabase(ipfs_db) + tr, err := trie.New(common.Hash{}, root, sdb_db.TrieDB(), ipld.MEthStateTrie) + if err != nil { + t.Fatal(err) + } + return tr +} + +// generates a random Geth LevelDB trie of n key-value pairs and corresponding value map +func randomGethTrie(n int, db *geth_trie.Database) (*geth_trie.Trie, kvMap) { + trie := geth_trie.NewEmpty(db) + var vals []*kv + for i := byte(0); i < 100; i++ { + e := &kv{common.LeftPadBytes([]byte{i}, 32), int64(i)} + e2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)} + vals = append(vals, e, e2) + } + for i := 0; i < n; i++ { + k := randBytes(32) + v := rand.Int63() + vals = append(vals, &kv{k, v}) + } + all := kvMap{} + for _, val := range vals { + all[string(val.k)] = &kv{[]byte(val.k), val.v} + trie.Update([]byte(val.k), packValue(val.v)) + } + return trie, all +} + +// generates a random IPLD-indexed trie +func randomTrie(t testing.TB, n int) (*trie.Trie, kvMap) { + edb := rawdb.NewMemoryDatabase() + db := geth_trie.NewDatabase(edb) + orig, vals := randomGethTrie(n, db) + root := commitTrie(t, db, orig) + trie := indexTrie(t, edb, root) + return trie, vals +} + +func randBytes(n int) []byte { + r := make([]byte, n) + rand.Read(r) + return r +} + // TearDownDB is used to tear down the watcher dbs after tests func TearDownDB(db *sqlx.DB) error { tx, err := db.Beginx()