trie: fix range prover (#22210)

Fixes a special case when the trie only has a single trie node and the range proof only contains a single element.
This commit is contained in:
gary rong 2021-01-22 17:11:24 +08:00 committed by GitHub
parent 231040c633
commit 9e1bd0f367
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 15 deletions

View File

@ -216,7 +216,7 @@ func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyV
// //
// Note we have the assumption here the given boundary keys are different // Note we have the assumption here the given boundary keys are different
// and right is larger than left. // and right is larger than left.
func unsetInternal(n node, left []byte, right []byte) error { func unsetInternal(n node, left []byte, right []byte) (bool, error) {
left, right = keybytesToHex(left), keybytesToHex(right) left, right = keybytesToHex(left), keybytesToHex(right)
// Step down to the fork point. There are two scenarios can happen: // Step down to the fork point. There are two scenarios can happen:
@ -278,45 +278,55 @@ findFork:
// - left proof points to the shortnode, but right proof is greater // - left proof points to the shortnode, but right proof is greater
// - right proof points to the shortnode, but left proof is less // - right proof points to the shortnode, but left proof is less
if shortForkLeft == -1 && shortForkRight == -1 { if shortForkLeft == -1 && shortForkRight == -1 {
return errors.New("empty range") return false, errors.New("empty range")
} }
if shortForkLeft == 1 && shortForkRight == 1 { if shortForkLeft == 1 && shortForkRight == 1 {
return errors.New("empty range") return false, errors.New("empty range")
} }
if shortForkLeft != 0 && shortForkRight != 0 { if shortForkLeft != 0 && shortForkRight != 0 {
// The fork point is root node, unset the entire trie
if parent == nil {
return true, nil
}
parent.(*fullNode).Children[left[pos-1]] = nil parent.(*fullNode).Children[left[pos-1]] = nil
return nil return false, nil
} }
// Only one proof points to non-existent key. // Only one proof points to non-existent key.
if shortForkRight != 0 { if shortForkRight != 0 {
// Unset left proof's path
if _, ok := rn.Val.(valueNode); ok { if _, ok := rn.Val.(valueNode); ok {
// The fork point is root node, unset the entire trie
if parent == nil {
return true, nil
}
parent.(*fullNode).Children[left[pos-1]] = nil parent.(*fullNode).Children[left[pos-1]] = nil
return nil return false, nil
} }
return unset(rn, rn.Val, left[pos:], len(rn.Key), false) return false, unset(rn, rn.Val, left[pos:], len(rn.Key), false)
} }
if shortForkLeft != 0 { if shortForkLeft != 0 {
// Unset right proof's path.
if _, ok := rn.Val.(valueNode); ok { if _, ok := rn.Val.(valueNode); ok {
// The fork point is root node, unset the entire trie
if parent == nil {
return true, nil
}
parent.(*fullNode).Children[right[pos-1]] = nil parent.(*fullNode).Children[right[pos-1]] = nil
return nil return false, nil
} }
return unset(rn, rn.Val, right[pos:], len(rn.Key), true) return false, unset(rn, rn.Val, right[pos:], len(rn.Key), true)
} }
return nil return false, nil
case *fullNode: case *fullNode:
// unset all internal nodes in the forkpoint // unset all internal nodes in the forkpoint
for i := left[pos] + 1; i < right[pos]; i++ { for i := left[pos] + 1; i < right[pos]; i++ {
rn.Children[i] = nil rn.Children[i] = nil
} }
if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil { if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
return err return false, err
} }
if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil { if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
return err return false, err
} }
return nil return false, nil
default: default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n)) panic(fmt.Sprintf("%T: invalid node: %v", n, n))
} }
@ -560,7 +570,8 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
} }
// Remove all internal references. All the removed parts should // Remove all internal references. All the removed parts should
// be re-filled(or re-constructed) by the given leaves range. // be re-filled(or re-constructed) by the given leaves range.
if err := unsetInternal(root, firstKey, lastKey); err != nil { empty, err := unsetInternal(root, firstKey, lastKey)
if err != nil {
return nil, nil, nil, false, err return nil, nil, nil, false, err
} }
// Rebuild the trie with the leaf stream, the shape of trie // Rebuild the trie with the leaf stream, the shape of trie
@ -570,6 +581,9 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
triedb = NewDatabase(diskdb) triedb = NewDatabase(diskdb)
) )
tr := &Trie{root: root, db: triedb} tr := &Trie{root: root, db: triedb}
if empty {
tr.root = nil
}
for index, key := range keys { for index, key := range keys {
tr.TryUpdate(key, values[index]) tr.TryUpdate(key, values[index])
} }

View File

@ -384,6 +384,25 @@ func TestOneElementRangeProof(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
// Test the mini trie with only a single element.
tinyTrie := new(Trie)
entry := &kv{randBytes(32), randBytes(20), false}
tinyTrie.Update(entry.k, entry.v)
first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last = entry.k
proof = memorydb.New()
if err := tinyTrie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err)
}
if err := tinyTrie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err)
}
_, _, _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
} }
// TestAllElementsProof tests the range proof with all elements. // TestAllElementsProof tests the range proof with all elements.