From 44ff3f3dc98a5bd72e06ff6b05739c2dce8c9b62 Mon Sep 17 00:00:00 2001 From: gary rong Date: Fri, 24 Apr 2020 19:37:56 +0800 Subject: [PATCH] trie: initial implementation for range proof (#20908) * trie: initial implementation for range proof * trie: add benchmark * trie: fix lint * trie: fix minor issue * trie: unset the edge valuenode as well * trie: unset the edge valuenode as nilValuenode --- les/odr_requests.go | 6 +- trie/proof.go | 221 ++++++++++++++++++++++++++++++++++++++++++-- trie/proof_test.go | 188 ++++++++++++++++++++++++++++++++++++- 3 files changed, 400 insertions(+), 15 deletions(-) diff --git a/les/odr_requests.go b/les/odr_requests.go index 146da2213..c4b38060c 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -224,7 +224,7 @@ func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error { // Verify the proof and store if checks out nodeSet := proofs.NodeSet() reads := &readTraceDB{db: nodeSet} - if _, _, err := trie.VerifyProof(r.Id.Root, r.Key, reads); err != nil { + if _, err := trie.VerifyProof(r.Id.Root, r.Key, reads); err != nil { return fmt.Errorf("merkle proof verification failed: %v", err) } // check if all nodes have been read by VerifyProof @@ -378,7 +378,7 @@ func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error { binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) reads := &readTraceDB{db: nodeSet} - value, _, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) + value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) if err != nil { return fmt.Errorf("merkle proof verification failed: %v", err) } @@ -470,7 +470,7 @@ func (r *BloomRequest) Validate(db ethdb.Database, msg *Msg) error { for i, idx := range r.SectionIndexList { binary.BigEndian.PutUint64(encNumber[2:], idx) - value, _, err := trie.VerifyProof(r.BloomTrieRoot, encNumber[:], reads) + value, err := trie.VerifyProof(r.BloomTrieRoot, encNumber[:], reads) if err != nil { return err } diff --git a/trie/proof.go b/trie/proof.go index 58ca69c68..07ce8e6d8 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -18,10 +18,12 @@ package trie import ( "bytes" + "errors" "fmt" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" ) @@ -101,33 +103,232 @@ func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWri // VerifyProof checks merkle proofs. The given proof must contain the value for // key in a trie with the given root hash. VerifyProof returns an error if the // proof contains invalid trie nodes or the wrong value. -func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, nodes int, err error) { +func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { key = keybytesToHex(key) wantHash := rootHash for i := 0; ; i++ { buf, _ := proofDb.Get(wantHash[:]) if buf == nil { - return nil, i, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) + return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) } n, err := decodeNode(wantHash[:], buf) if err != nil { - return nil, i, fmt.Errorf("bad proof node %d: %v", i, err) + return nil, fmt.Errorf("bad proof node %d: %v", i, err) } - keyrest, cld := get(n, key) + keyrest, cld := get(n, key, true) switch cld := cld.(type) { case nil: // The trie doesn't contain the key. - return nil, i, nil + return nil, nil case hashNode: key = keyrest copy(wantHash[:], cld) case valueNode: - return cld, i + 1, nil + return cld, nil } } } -func get(tn node, key []byte) ([]byte, node) { +// proofToPath converts a merkle proof to trie node path. +// The main purpose of this function is recovering a node +// path from the merkle proof stream. All necessary nodes +// will be resolved and leave the remaining as hashnode. +func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader) (node, error) { + // resolveNode retrieves and resolves trie node from merkle proof stream + resolveNode := func(hash common.Hash) (node, error) { + buf, _ := proofDb.Get(hash[:]) + if buf == nil { + return nil, fmt.Errorf("proof node (hash %064x) missing", hash) + } + n, err := decodeNode(hash[:], buf) + if err != nil { + return nil, fmt.Errorf("bad proof node %v", err) + } + return n, err + } + // If the root node is empty, resolve it first + if root == nil { + n, err := resolveNode(rootHash) + if err != nil { + return nil, err + } + root = n + } + var ( + err error + child, parent node + keyrest []byte + terminate bool + ) + key, parent = keybytesToHex(key), root + for { + keyrest, child = get(parent, key, false) + switch cld := child.(type) { + case nil: + // The trie doesn't contain the key. + return nil, errors.New("the node is not contained in trie") + case *shortNode: + key, parent = keyrest, child // Already resolved + continue + case *fullNode: + key, parent = keyrest, child // Already resolved + continue + case hashNode: + child, err = resolveNode(common.BytesToHash(cld)) + if err != nil { + return nil, err + } + case valueNode: + terminate = true + } + // Link the parent and child. + switch pnode := parent.(type) { + case *shortNode: + pnode.Val = child + case *fullNode: + pnode.Children[key[0]] = child + default: + panic(fmt.Sprintf("%T: invalid node: %v", pnode, pnode)) + } + if terminate { + return root, nil // The whole path is resolved + } + key, parent = keyrest, child + } +} + +// unsetInternal removes all internal node references(hashnode, embedded node). +// It should be called after a trie is constructed with two edge proofs. Also +// the given boundary keys must be the one used to construct the edge proofs. +// +// It's the key step for range proof. All visited nodes should be marked dirty +// since the node content might be modified. Besides it can happen that some +// fullnodes only have one child which is disallowed. But if the proof is valid, +// the missing children will be filled, otherwise it will be thrown anyway. +func unsetInternal(node node, left []byte, right []byte) error { + left, right = keybytesToHex(left), keybytesToHex(right) + + // todo(rjl493456442) different length edge keys should be supported + if len(left) != len(right) { + return errors.New("inconsistent edge path") + } + // Step down to the fork point + prefix, pos := prefixLen(left, right), 0 + for { + if pos >= prefix { + break + } + switch n := (node).(type) { + case *shortNode: + if len(left)-pos < len(n.Key) || !bytes.Equal(n.Key, left[pos:pos+len(n.Key)]) { + return errors.New("invalid edge path") + } + n.flags = nodeFlag{dirty: true} + node, pos = n.Val, pos+len(n.Key) + case *fullNode: + n.flags = nodeFlag{dirty: true} + node, pos = n.Children[left[pos]], pos+1 + default: + panic(fmt.Sprintf("%T: invalid node: %v", node, node)) + } + } + fn, ok := node.(*fullNode) + if !ok { + return errors.New("the fork point must be a fullnode") + } + // Find the fork point! Unset all intermediate references + for i := left[prefix] + 1; i < right[prefix]; i++ { + fn.Children[i] = nil + } + fn.flags = nodeFlag{dirty: true} + unset(fn.Children[left[prefix]], left[prefix+1:], false) + unset(fn.Children[right[prefix]], right[prefix+1:], true) + return nil +} + +// unset removes all internal node references either the left most or right most. +func unset(root node, rest []byte, removeLeft bool) { + switch rn := root.(type) { + case *fullNode: + if removeLeft { + for i := 0; i < int(rest[0]); i++ { + rn.Children[i] = nil + } + rn.flags = nodeFlag{dirty: true} + } else { + for i := rest[0] + 1; i < 16; i++ { + rn.Children[i] = nil + } + rn.flags = nodeFlag{dirty: true} + } + unset(rn.Children[rest[0]], rest[1:], removeLeft) + case *shortNode: + rn.flags = nodeFlag{dirty: true} + if _, ok := rn.Val.(valueNode); ok { + rn.Val = nilValueNode + return + } + unset(rn.Val, rest[len(rn.Key):], removeLeft) + case hashNode, nil, valueNode: + panic("it shouldn't happen") + } +} + +// VerifyRangeProof checks whether the given leave nodes and edge proofs +// can prove the given trie leaves range is matched with given root hash +// and the range is consecutive(no gap inside). +func VerifyRangeProof(rootHash common.Hash, keys [][]byte, values [][]byte, firstProof ethdb.KeyValueReader, lastProof ethdb.KeyValueReader) error { + if len(keys) != len(values) { + return fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)) + } + if len(keys) == 0 { + return fmt.Errorf("nothing to verify") + } + if len(keys) == 1 { + value, err := VerifyProof(rootHash, keys[0], firstProof) + if err != nil { + return err + } + if !bytes.Equal(value, values[0]) { + return fmt.Errorf("correct proof but invalid data") + } + return nil + } + // Convert the edge proofs to edge trie paths. Then we can + // have the same tree architecture with the original one. + root, err := proofToPath(rootHash, nil, keys[0], firstProof) + if err != nil { + return err + } + // Pass the root node here, the second path will be merged + // with the first one. + root, err = proofToPath(rootHash, root, keys[len(keys)-1], lastProof) + if err != nil { + return err + } + // Remove all internal references. All the removed parts should + // be re-filled(or re-constructed) by the given leaves range. + if err := unsetInternal(root, keys[0], keys[len(keys)-1]); err != nil { + return err + } + // Rebuild the trie with the leave stream, the shape of trie + // should be same with the original one. + newtrie := &Trie{root: root, db: NewDatabase(memorydb.New())} + for index, key := range keys { + newtrie.TryUpdate(key, values[index]) + } + if newtrie.Hash() != rootHash { + return fmt.Errorf("invalid proof, wanthash %x, got %x", rootHash, newtrie.Hash()) + } + return nil +} + +// get returns the child of the given node. Return nil if the +// node with specified key doesn't exist at all. +// +// There is an additional flag `skipResolved`. If it's set then +// all resolved nodes won't be returned. +func get(tn node, key []byte, skipResolved bool) ([]byte, node) { for { switch n := tn.(type) { case *shortNode: @@ -136,9 +337,15 @@ func get(tn node, key []byte) ([]byte, node) { } tn = n.Val key = key[len(n.Key):] + if !skipResolved { + return key, tn + } case *fullNode: tn = n.Children[key[0]] key = key[1:] + if !skipResolved { + return key, tn + } case hashNode: return key, n case nil: diff --git a/trie/proof_test.go b/trie/proof_test.go index 4caae7338..781702b87 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -20,6 +20,7 @@ import ( "bytes" crand "crypto/rand" mrand "math/rand" + "sort" "testing" "time" @@ -65,7 +66,7 @@ func TestProof(t *testing.T) { if proof == nil { t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) } - val, _, err := VerifyProof(root, kv.k, proof) + val, err := VerifyProof(root, kv.k, proof) if err != nil { t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) } @@ -87,7 +88,7 @@ func TestOneElementProof(t *testing.T) { if proof.Len() != 1 { t.Errorf("prover %d: proof should have one element", i) } - val, _, err := VerifyProof(trie.Hash(), []byte("k"), proof) + val, err := VerifyProof(trie.Hash(), []byte("k"), proof) if err != nil { t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) } @@ -97,6 +98,145 @@ func TestOneElementProof(t *testing.T) { } } +type entrySlice []*kv + +func (p entrySlice) Len() int { return len(p) } +func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 } +func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +func TestRangeProof(t *testing.T) { + trie, vals := randomTrie(4096) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, kv) + } + sort.Sort(entries) + for i := 0; i < 500; i++ { + start := mrand.Intn(len(entries)) + end := mrand.Intn(len(entries)-start) + start + if start == end { + continue + } + firstProof, lastProof := memorydb.New(), memorydb.New() + if err := trie.Prove(entries[start].k, 0, firstProof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof) + if err != nil { + t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) + } + } +} + +func TestBadRangeProof(t *testing.T) { + trie, vals := randomTrie(4096) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, kv) + } + sort.Sort(entries) + + for i := 0; i < 500; i++ { + start := mrand.Intn(len(entries)) + end := mrand.Intn(len(entries)-start) + start + if start == end { + continue + } + firstProof, lastProof := memorydb.New(), memorydb.New() + if err := trie.Prove(entries[start].k, 0, firstProof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + testcase := mrand.Intn(6) + var index int + switch testcase { + case 0: + // Modified key + index = mrand.Intn(end - start) + keys[index] = randBytes(32) // In theory it can't be same + case 1: + // Modified val + index = mrand.Intn(end - start) + vals[index] = randBytes(20) // In theory it can't be same + case 2: + // Gapped entry slice + index = mrand.Intn(end - start) + keys = append(keys[:index], keys[index+1:]...) + vals = append(vals[:index], vals[index+1:]...) + if len(keys) <= 1 { + continue + } + case 3: + // Switched entry slice, same effect with gapped + index = mrand.Intn(end - start) + keys[index] = entries[len(entries)-1].k + vals[index] = entries[len(entries)-1].v + case 4: + // Set random key to nil + index = mrand.Intn(end - start) + keys[index] = nil + case 5: + // Set random value to nil + index = mrand.Intn(end - start) + vals[index] = nil + } + err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof) + if err == nil { + t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) + } + } +} + +// TestGappedRangeProof focuses on the small trie with embedded nodes. +// If the gapped node is embedded in the trie, it should be detected too. +func TestGappedRangeProof(t *testing.T) { + trie := new(Trie) + var entries []*kv // Sorted entries + for i := byte(0); i < 10; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + entries = append(entries, value) + } + first, last := 2, 8 + firstProof, lastProof := memorydb.New(), memorydb.New() + if err := trie.Prove(entries[first].k, 0, firstProof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[last-1].k, 0, lastProof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var vals [][]byte + for i := first; i < last; i++ { + if i == (first+last)/2 { + continue + } + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof) + if err == nil { + t.Fatal("expect error, got nil") + } +} + func TestBadProof(t *testing.T) { trie, vals := randomTrie(800) root := trie.Hash() @@ -118,7 +258,7 @@ func TestBadProof(t *testing.T) { mutateByte(val) proof.Put(crypto.Keccak256(val), val) - if _, _, err := VerifyProof(root, kv.k, proof); err == nil { + if _, err := VerifyProof(root, kv.k, proof); err == nil { t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k) } } @@ -138,7 +278,7 @@ func TestMissingKeyProof(t *testing.T) { if proof.Len() != 1 { t.Errorf("test %d: proof should have one element", i) } - val, _, err := VerifyProof(trie.Hash(), []byte(key), proof) + val, err := VerifyProof(trie.Hash(), []byte(key), proof) if err != nil { t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) } @@ -191,12 +331,50 @@ func BenchmarkVerifyProof(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { im := i % len(keys) - if _, _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { + if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil { b.Fatalf("key %x: %v", keys[im], err) } } } +func BenchmarkVerifyRangeProof10(b *testing.B) { benchmarkVerifyRangeProof(b, 10) } +func BenchmarkVerifyRangeProof100(b *testing.B) { benchmarkVerifyRangeProof(b, 100) } +func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b, 1000) } +func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) } + +func benchmarkVerifyRangeProof(b *testing.B, size int) { + trie, vals := randomTrie(8192) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, kv) + } + sort.Sort(entries) + + start := 2 + end := start + size + firstProof, lastProof := memorydb.New(), memorydb.New() + if err := trie.Prove(entries[start].k, 0, firstProof); err != nil { + b.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil { + b.Fatalf("Failed to prove the last node %v", err) + } + var keys [][]byte + var values [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + values = append(values, entries[i].v) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := VerifyRangeProof(trie.Hash(), keys, values, firstProof, lastProof) + if err != nil { + b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) + } + } +} + func randomTrie(n int) (*Trie, map[string]*kv) { trie := new(Trie) vals := make(map[string]*kv)