diff --git a/trie/proof.go b/trie/proof.go index e7102f12b..61c35a842 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -216,7 +216,7 @@ func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyV // // Note we have the assumption here the given boundary keys are different // and right is larger than left. -func unsetInternal(n node, left []byte, right []byte) error { +func unsetInternal(n node, left []byte, right []byte) (bool, error) { left, right = keybytesToHex(left), keybytesToHex(right) // Step down to the fork point. There are two scenarios can happen: @@ -278,45 +278,55 @@ findFork: // - left proof points to the shortnode, but right proof is greater // - right proof points to the shortnode, but left proof is less if shortForkLeft == -1 && shortForkRight == -1 { - return errors.New("empty range") + return false, errors.New("empty range") } if shortForkLeft == 1 && shortForkRight == 1 { - return errors.New("empty range") + return false, errors.New("empty range") } if shortForkLeft != 0 && shortForkRight != 0 { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } parent.(*fullNode).Children[left[pos-1]] = nil - return nil + return false, nil } // Only one proof points to non-existent key. if shortForkRight != 0 { - // Unset left proof's path if _, ok := rn.Val.(valueNode); ok { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } parent.(*fullNode).Children[left[pos-1]] = nil - return nil + return false, nil } - return unset(rn, rn.Val, left[pos:], len(rn.Key), false) + return false, unset(rn, rn.Val, left[pos:], len(rn.Key), false) } if shortForkLeft != 0 { - // Unset right proof's path. if _, ok := rn.Val.(valueNode); ok { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } parent.(*fullNode).Children[right[pos-1]] = nil - return nil + return false, nil } - return unset(rn, rn.Val, right[pos:], len(rn.Key), true) + return false, unset(rn, rn.Val, right[pos:], len(rn.Key), true) } - return nil + return false, nil case *fullNode: // unset all internal nodes in the forkpoint for i := left[pos] + 1; i < right[pos]; i++ { rn.Children[i] = nil } if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil { - return err + return false, err } if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil { - return err + return false, err } - return nil + return false, nil default: panic(fmt.Sprintf("%T: invalid node: %v", n, n)) } @@ -560,7 +570,8 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key } // Remove all internal references. All the removed parts should // be re-filled(or re-constructed) by the given leaves range. - if err := unsetInternal(root, firstKey, lastKey); err != nil { + empty, err := unsetInternal(root, firstKey, lastKey) + if err != nil { return nil, nil, nil, false, err } // Rebuild the trie with the leaf stream, the shape of trie @@ -570,6 +581,9 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key triedb = NewDatabase(diskdb) ) tr := &Trie{root: root, db: triedb} + if empty { + tr.root = nil + } for index, key := range keys { tr.TryUpdate(key, values[index]) } diff --git a/trie/proof_test.go b/trie/proof_test.go index 3ecd31888..304affa9f 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -384,6 +384,25 @@ func TestOneElementRangeProof(t *testing.T) { if err != nil { t.Fatalf("Expected no error, got %v", err) } + + // Test the mini trie with only a single element. + tinyTrie := new(Trie) + entry := &kv{randBytes(32), randBytes(20), false} + tinyTrie.Update(entry.k, entry.v) + + first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() + last = entry.k + proof = memorydb.New() + if err := tinyTrie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := tinyTrie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, _, _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } // TestAllElementsProof tests the range proof with all elements.