diff --git a/trie/proof.go b/trie/proof.go index 0f7d56a64..d4a1916be 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -219,54 +219,69 @@ func unsetInternal(n node, left []byte, right []byte) error { if len(left) != len(right) { return errors.New("inconsistent edge path") } - // Step down to the fork point - prefix, pos := prefixLen(left, right), 0 - var parent node + // Step down to the fork point. There are two scenarios can happen: + // - the fork point is a shortnode: the left proof MUST point to a + // non-existent key and the key doesn't match with the shortnode + // - the fork point is a fullnode: the left proof can point to an + // existent key or not. + var ( + pos = 0 + parent node + ) +findFork: for { - if pos >= prefix { - break - } switch rn := (n).(type) { case *shortNode: + // The right proof must point to an existent key. if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) { return errors.New("invalid edge path") } + rn.flags = nodeFlag{dirty: true} // Special case, the non-existent proof points to the same path // as the existent proof, but the path of existent proof is longer. - // In this case, truncate the extra path(it should be recovered - // by node insertion). + // In this case, the fork point is this shortnode. if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) { - fn := parent.(*fullNode) - fn.Children[left[pos-1]] = nil - return nil + break findFork } - rn.flags = nodeFlag{dirty: true} parent = n n, pos = rn.Val, pos+len(rn.Key) case *fullNode: + leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]] + // The right proof must point to an existent key. + if rightnode == nil { + return errors.New("invalid edge path") + } rn.flags = nodeFlag{dirty: true} + if leftnode != rightnode { + break findFork + } parent = n - n, pos = rn.Children[right[pos]], pos+1 + n, pos = rn.Children[left[pos]], pos+1 default: panic(fmt.Sprintf("%T: invalid node: %v", n, n)) } } - fn, ok := n.(*fullNode) - if !ok { - return errors.New("the fork point must be a fullnode") + switch rn := n.(type) { + case *shortNode: + if _, ok := rn.Val.(valueNode); ok { + parent.(*fullNode).Children[right[pos-1]] = nil + return nil + } + return unset(rn, rn.Val, right[pos:], len(rn.Key), true) + case *fullNode: + for i := left[pos] + 1; i < right[pos]; i++ { + rn.Children[i] = nil + } + if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil { + return err + } + if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil { + return err + } + return nil + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) } - // Find the fork point! Unset all intermediate references - for i := left[prefix] + 1; i < right[prefix]; i++ { - fn.Children[i] = nil - } - fn.flags = nodeFlag{dirty: true} - if err := unset(fn, fn.Children[left[prefix]], left[prefix:], 1, false); err != nil { - return err - } - if err := unset(fn, fn.Children[right[prefix]], right[prefix:], 1, true); err != nil { - return err - } - return nil } // unset removes all internal node references either the left most or right most. @@ -314,8 +329,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error // The key of fork shortnode is less than the // path(it doesn't belong to the range), keep // it with the cached hash available. - return nil } + return nil } if _, ok := cld.Val.(valueNode); ok { fn := parent.(*fullNode) diff --git a/trie/proof_test.go b/trie/proof_test.go index a68503f7d..9c11d5bc5 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -397,33 +397,35 @@ func TestAllElementsProof(t *testing.T) { // TestSingleSideRangeProof tests the range starts from zero. func TestSingleSideRangeProof(t *testing.T) { - trie := new(Trie) - var entries entrySlice - for i := 0; i < 4096; i++ { - value := &kv{randBytes(32), randBytes(20), false} - trie.Update(value.k, value.v) - entries = append(entries, value) - } - sort.Sort(entries) + for i := 0; i < 64; i++ { + trie := new(Trie) + var entries entrySlice + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + entries = append(entries, value) + } + sort.Sort(entries) - var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} - for _, pos := range cases { - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - k := make([][]byte, 0) - v := make([][]byte, 0) - for i := 0; i <= pos; i++ { - k = append(k, entries[i].k) - v = append(v, entries[i].v) - } - err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) - if err != nil { - t.Fatalf("Expected no error, got %v", err) + var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} + for _, pos := range cases { + firstProof, lastProof := memorydb.New(), memorydb.New() + if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := 0; i <= pos; i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } } }