trie: fix for range proof (#21107)
* trie: fix for range proof * trie: fix typo
This commit is contained in:
		
							parent
							
								
									81e9caed7d
								
							
						
					
					
						commit
						070a5e1252
					
				| @ -219,54 +219,69 @@ func unsetInternal(n node, left []byte, right []byte) error { | |||||||
| 	if len(left) != len(right) { | 	if len(left) != len(right) { | ||||||
| 		return errors.New("inconsistent edge path") | 		return errors.New("inconsistent edge path") | ||||||
| 	} | 	} | ||||||
| 	// Step down to the fork point
 | 	// Step down to the fork point. There are two scenarios can happen:
 | ||||||
| 	prefix, pos := prefixLen(left, right), 0 | 	// - the fork point is a shortnode: the left proof MUST point to a
 | ||||||
| 	var parent node | 	//   non-existent key and the key doesn't match with the shortnode
 | ||||||
|  | 	// - the fork point is a fullnode: the left proof can point to an
 | ||||||
|  | 	//   existent key or not.
 | ||||||
|  | 	var ( | ||||||
|  | 		pos    = 0 | ||||||
|  | 		parent node | ||||||
|  | 	) | ||||||
|  | findFork: | ||||||
| 	for { | 	for { | ||||||
| 		if pos >= prefix { |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 		switch rn := (n).(type) { | 		switch rn := (n).(type) { | ||||||
| 		case *shortNode: | 		case *shortNode: | ||||||
|  | 			// The right proof must point to an existent key.
 | ||||||
| 			if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) { | 			if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) { | ||||||
| 				return errors.New("invalid edge path") | 				return errors.New("invalid edge path") | ||||||
| 			} | 			} | ||||||
|  | 			rn.flags = nodeFlag{dirty: true} | ||||||
| 			// Special case, the non-existent proof points to the same path
 | 			// Special case, the non-existent proof points to the same path
 | ||||||
| 			// as the existent proof, but the path of existent proof is longer.
 | 			// as the existent proof, but the path of existent proof is longer.
 | ||||||
| 			// In this case, truncate the extra path(it should be recovered
 | 			// In this case, the fork point is this shortnode.
 | ||||||
| 			// by node insertion).
 |  | ||||||
| 			if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) { | 			if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) { | ||||||
| 				fn := parent.(*fullNode) | 				break findFork | ||||||
| 				fn.Children[left[pos-1]] = nil |  | ||||||
| 				return nil |  | ||||||
| 			} | 			} | ||||||
| 			rn.flags = nodeFlag{dirty: true} |  | ||||||
| 			parent = n | 			parent = n | ||||||
| 			n, pos = rn.Val, pos+len(rn.Key) | 			n, pos = rn.Val, pos+len(rn.Key) | ||||||
| 		case *fullNode: | 		case *fullNode: | ||||||
|  | 			leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]] | ||||||
|  | 			// The right proof must point to an existent key.
 | ||||||
|  | 			if rightnode == nil { | ||||||
|  | 				return errors.New("invalid edge path") | ||||||
|  | 			} | ||||||
| 			rn.flags = nodeFlag{dirty: true} | 			rn.flags = nodeFlag{dirty: true} | ||||||
|  | 			if leftnode != rightnode { | ||||||
|  | 				break findFork | ||||||
|  | 			} | ||||||
| 			parent = n | 			parent = n | ||||||
| 			n, pos = rn.Children[right[pos]], pos+1 | 			n, pos = rn.Children[left[pos]], pos+1 | ||||||
| 		default: | 		default: | ||||||
| 			panic(fmt.Sprintf("%T: invalid node: %v", n, n)) | 			panic(fmt.Sprintf("%T: invalid node: %v", n, n)) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	fn, ok := n.(*fullNode) | 	switch rn := n.(type) { | ||||||
| 	if !ok { | 	case *shortNode: | ||||||
| 		return errors.New("the fork point must be a fullnode") | 		if _, ok := rn.Val.(valueNode); ok { | ||||||
|  | 			parent.(*fullNode).Children[right[pos-1]] = nil | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 		return unset(rn, rn.Val, right[pos:], len(rn.Key), true) | ||||||
|  | 	case *fullNode: | ||||||
|  | 		for i := left[pos] + 1; i < right[pos]; i++ { | ||||||
|  | 			rn.Children[i] = nil | ||||||
|  | 		} | ||||||
|  | 		if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	default: | ||||||
|  | 		panic(fmt.Sprintf("%T: invalid node: %v", n, n)) | ||||||
| 	} | 	} | ||||||
| 	// Find the fork point! Unset all intermediate references
 |  | ||||||
| 	for i := left[prefix] + 1; i < right[prefix]; i++ { |  | ||||||
| 		fn.Children[i] = nil |  | ||||||
| 	} |  | ||||||
| 	fn.flags = nodeFlag{dirty: true} |  | ||||||
| 	if err := unset(fn, fn.Children[left[prefix]], left[prefix:], 1, false); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	if err := unset(fn, fn.Children[right[prefix]], right[prefix:], 1, true); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // unset removes all internal node references either the left most or right most.
 | // unset removes all internal node references either the left most or right most.
 | ||||||
| @ -314,8 +329,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error | |||||||
| 				// The key of fork shortnode is less than the
 | 				// The key of fork shortnode is less than the
 | ||||||
| 				// path(it doesn't belong to the range), keep
 | 				// path(it doesn't belong to the range), keep
 | ||||||
| 				// it with the cached hash available.
 | 				// it with the cached hash available.
 | ||||||
| 				return nil |  | ||||||
| 			} | 			} | ||||||
|  | 			return nil | ||||||
| 		} | 		} | ||||||
| 		if _, ok := cld.Val.(valueNode); ok { | 		if _, ok := cld.Val.(valueNode); ok { | ||||||
| 			fn := parent.(*fullNode) | 			fn := parent.(*fullNode) | ||||||
|  | |||||||
| @ -397,33 +397,35 @@ func TestAllElementsProof(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| // TestSingleSideRangeProof tests the range starts from zero.
 | // TestSingleSideRangeProof tests the range starts from zero.
 | ||||||
| func TestSingleSideRangeProof(t *testing.T) { | func TestSingleSideRangeProof(t *testing.T) { | ||||||
| 	trie := new(Trie) | 	for i := 0; i < 64; i++ { | ||||||
| 	var entries entrySlice | 		trie := new(Trie) | ||||||
| 	for i := 0; i < 4096; i++ { | 		var entries entrySlice | ||||||
| 		value := &kv{randBytes(32), randBytes(20), false} | 		for i := 0; i < 4096; i++ { | ||||||
| 		trie.Update(value.k, value.v) | 			value := &kv{randBytes(32), randBytes(20), false} | ||||||
| 		entries = append(entries, value) | 			trie.Update(value.k, value.v) | ||||||
| 	} | 			entries = append(entries, value) | ||||||
| 	sort.Sort(entries) | 		} | ||||||
|  | 		sort.Sort(entries) | ||||||
| 
 | 
 | ||||||
| 	var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} | 		var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} | ||||||
| 	for _, pos := range cases { | 		for _, pos := range cases { | ||||||
| 		firstProof, lastProof := memorydb.New(), memorydb.New() | 			firstProof, lastProof := memorydb.New(), memorydb.New() | ||||||
| 		if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { | 			if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { | ||||||
| 			t.Fatalf("Failed to prove the first node %v", err) | 				t.Fatalf("Failed to prove the first node %v", err) | ||||||
| 		} | 			} | ||||||
| 		if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { | 			if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { | ||||||
| 			t.Fatalf("Failed to prove the first node %v", err) | 				t.Fatalf("Failed to prove the first node %v", err) | ||||||
| 		} | 			} | ||||||
| 		k := make([][]byte, 0) | 			k := make([][]byte, 0) | ||||||
| 		v := make([][]byte, 0) | 			v := make([][]byte, 0) | ||||||
| 		for i := 0; i <= pos; i++ { | 			for i := 0; i <= pos; i++ { | ||||||
| 			k = append(k, entries[i].k) | 				k = append(k, entries[i].k) | ||||||
| 			v = append(v, entries[i].v) | 				v = append(v, entries[i].v) | ||||||
| 		} | 			} | ||||||
| 		err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) | 			err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) | ||||||
| 		if err != nil { | 			if err != nil { | ||||||
| 			t.Fatalf("Expected no error, got %v", err) | 				t.Fatalf("Expected no error, got %v", err) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user