trie: fix for range proof (#21107)
* trie: fix for range proof * trie: fix typo
This commit is contained in:
		
							parent
							
								
									81e9caed7d
								
							
						
					
					
						commit
						070a5e1252
					
				| @ -219,54 +219,69 @@ func unsetInternal(n node, left []byte, right []byte) error { | ||||
| 	if len(left) != len(right) { | ||||
| 		return errors.New("inconsistent edge path") | ||||
| 	} | ||||
| 	// Step down to the fork point
 | ||||
| 	prefix, pos := prefixLen(left, right), 0 | ||||
| 	var parent node | ||||
| 	// Step down to the fork point. There are two scenarios can happen:
 | ||||
| 	// - the fork point is a shortnode: the left proof MUST point to a
 | ||||
| 	//   non-existent key and the key doesn't match with the shortnode
 | ||||
| 	// - the fork point is a fullnode: the left proof can point to an
 | ||||
| 	//   existent key or not.
 | ||||
| 	var ( | ||||
| 		pos    = 0 | ||||
| 		parent node | ||||
| 	) | ||||
| findFork: | ||||
| 	for { | ||||
| 		if pos >= prefix { | ||||
| 			break | ||||
| 		} | ||||
| 		switch rn := (n).(type) { | ||||
| 		case *shortNode: | ||||
| 			// The right proof must point to an existent key.
 | ||||
| 			if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) { | ||||
| 				return errors.New("invalid edge path") | ||||
| 			} | ||||
| 			rn.flags = nodeFlag{dirty: true} | ||||
| 			// Special case, the non-existent proof points to the same path
 | ||||
| 			// as the existent proof, but the path of existent proof is longer.
 | ||||
| 			// In this case, truncate the extra path(it should be recovered
 | ||||
| 			// by node insertion).
 | ||||
| 			// In this case, the fork point is this shortnode.
 | ||||
| 			if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) { | ||||
| 				fn := parent.(*fullNode) | ||||
| 				fn.Children[left[pos-1]] = nil | ||||
| 				return nil | ||||
| 				break findFork | ||||
| 			} | ||||
| 			rn.flags = nodeFlag{dirty: true} | ||||
| 			parent = n | ||||
| 			n, pos = rn.Val, pos+len(rn.Key) | ||||
| 		case *fullNode: | ||||
| 			leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]] | ||||
| 			// The right proof must point to an existent key.
 | ||||
| 			if rightnode == nil { | ||||
| 				return errors.New("invalid edge path") | ||||
| 			} | ||||
| 			rn.flags = nodeFlag{dirty: true} | ||||
| 			if leftnode != rightnode { | ||||
| 				break findFork | ||||
| 			} | ||||
| 			parent = n | ||||
| 			n, pos = rn.Children[right[pos]], pos+1 | ||||
| 			n, pos = rn.Children[left[pos]], pos+1 | ||||
| 		default: | ||||
| 			panic(fmt.Sprintf("%T: invalid node: %v", n, n)) | ||||
| 		} | ||||
| 	} | ||||
| 	fn, ok := n.(*fullNode) | ||||
| 	if !ok { | ||||
| 		return errors.New("the fork point must be a fullnode") | ||||
| 	switch rn := n.(type) { | ||||
| 	case *shortNode: | ||||
| 		if _, ok := rn.Val.(valueNode); ok { | ||||
| 			parent.(*fullNode).Children[right[pos-1]] = nil | ||||
| 			return nil | ||||
| 		} | ||||
| 		return unset(rn, rn.Val, right[pos:], len(rn.Key), true) | ||||
| 	case *fullNode: | ||||
| 		for i := left[pos] + 1; i < right[pos]; i++ { | ||||
| 			rn.Children[i] = nil | ||||
| 		} | ||||
| 		if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return nil | ||||
| 	default: | ||||
| 		panic(fmt.Sprintf("%T: invalid node: %v", n, n)) | ||||
| 	} | ||||
| 	// Find the fork point! Unset all intermediate references
 | ||||
| 	for i := left[prefix] + 1; i < right[prefix]; i++ { | ||||
| 		fn.Children[i] = nil | ||||
| 	} | ||||
| 	fn.flags = nodeFlag{dirty: true} | ||||
| 	if err := unset(fn, fn.Children[left[prefix]], left[prefix:], 1, false); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err := unset(fn, fn.Children[right[prefix]], right[prefix:], 1, true); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // unset removes all internal node references either the left most or right most.
 | ||||
| @ -314,8 +329,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error | ||||
| 				// The key of fork shortnode is less than the
 | ||||
| 				// path(it doesn't belong to the range), keep
 | ||||
| 				// it with the cached hash available.
 | ||||
| 				return nil | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| 		if _, ok := cld.Val.(valueNode); ok { | ||||
| 			fn := parent.(*fullNode) | ||||
|  | ||||
| @ -397,33 +397,35 @@ func TestAllElementsProof(t *testing.T) { | ||||
| 
 | ||||
| // TestSingleSideRangeProof tests the range starts from zero.
 | ||||
| func TestSingleSideRangeProof(t *testing.T) { | ||||
| 	trie := new(Trie) | ||||
| 	var entries entrySlice | ||||
| 	for i := 0; i < 4096; i++ { | ||||
| 		value := &kv{randBytes(32), randBytes(20), false} | ||||
| 		trie.Update(value.k, value.v) | ||||
| 		entries = append(entries, value) | ||||
| 	} | ||||
| 	sort.Sort(entries) | ||||
| 	for i := 0; i < 64; i++ { | ||||
| 		trie := new(Trie) | ||||
| 		var entries entrySlice | ||||
| 		for i := 0; i < 4096; i++ { | ||||
| 			value := &kv{randBytes(32), randBytes(20), false} | ||||
| 			trie.Update(value.k, value.v) | ||||
| 			entries = append(entries, value) | ||||
| 		} | ||||
| 		sort.Sort(entries) | ||||
| 
 | ||||
| 	var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} | ||||
| 	for _, pos := range cases { | ||||
| 		firstProof, lastProof := memorydb.New(), memorydb.New() | ||||
| 		if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { | ||||
| 			t.Fatalf("Failed to prove the first node %v", err) | ||||
| 		} | ||||
| 		if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { | ||||
| 			t.Fatalf("Failed to prove the first node %v", err) | ||||
| 		} | ||||
| 		k := make([][]byte, 0) | ||||
| 		v := make([][]byte, 0) | ||||
| 		for i := 0; i <= pos; i++ { | ||||
| 			k = append(k, entries[i].k) | ||||
| 			v = append(v, entries[i].v) | ||||
| 		} | ||||
| 		err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("Expected no error, got %v", err) | ||||
| 		var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} | ||||
| 		for _, pos := range cases { | ||||
| 			firstProof, lastProof := memorydb.New(), memorydb.New() | ||||
| 			if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { | ||||
| 				t.Fatalf("Failed to prove the first node %v", err) | ||||
| 			} | ||||
| 			if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { | ||||
| 				t.Fatalf("Failed to prove the first node %v", err) | ||||
| 			} | ||||
| 			k := make([][]byte, 0) | ||||
| 			v := make([][]byte, 0) | ||||
| 			for i := 0; i <= pos; i++ { | ||||
| 				k = append(k, entries[i].k) | ||||
| 				v = append(v, entries[i].v) | ||||
| 			} | ||||
| 			err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("Expected no error, got %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user