diff --git a/eth2/utils/ssz/src/cached_tree_hash/resize.rs b/eth2/utils/ssz/src/cached_tree_hash/resize.rs index 3c2d2c407..0b492770f 100644 --- a/eth2/utils/ssz/src/cached_tree_hash/resize.rs +++ b/eth2/utils/ssz/src/cached_tree_hash/resize.rs @@ -62,6 +62,52 @@ pub fn grow_merkle_cache( Some((bytes, flags)) } +/// New vec is smaller than old vec. +pub fn shrink_merkle_cache( + from_bytes: &[u8], + from_flags: &[bool], + from_height: usize, + to_height: usize, + to_nodes: usize, +) -> Option<(Vec, Vec)> { + let mut bytes = vec![0; to_nodes * HASHSIZE]; + let mut flags = vec![true; to_nodes]; + + let leaf_level = to_height; + + for i in 0..=leaf_level as usize { + let from_i = i + from_height - to_height; + let (from_byte_slice, from_flag_slice) = if from_i == leaf_level { + ( + from_bytes.get(first_byte_at_height(from_i)..)?, + from_flags.get(first_node_at_height(from_i)..)?, + ) + } else { + ( + from_bytes.get(byte_range_at_height(from_i))?, + from_flags.get(node_range_at_height(from_i))?, + ) + }; + + let (to_byte_slice, to_flag_slice) = if i == leaf_level { + ( + bytes.get_mut(first_byte_at_height(i)..)?, + flags.get_mut(first_node_at_height(i)..)?, + ) + } else { + ( + bytes.get_mut(byte_range_at_height(i))?, + flags.get_mut(node_range_at_height(i))?, + ) + }; + + to_byte_slice.copy_from_slice(from_byte_slice.get(0..to_byte_slice.len())?); + to_flag_slice.copy_from_slice(from_flag_slice.get(0..to_flag_slice.len())?); + } + + Some((bytes, flags)) +} + fn nodes_in_tree_of_height(h: usize) -> usize { 2 * (1 << h) - 1 } @@ -92,18 +138,18 @@ mod test { use super::*; #[test] - fn can_grow_three_levels() { - let from: usize = 1; - let to: usize = 15; + fn can_grow_and_shrink_three_levels() { + let small: usize = 1; + let big: usize = 15; - let old_bytes = vec![42; from * HASHSIZE]; - let old_flags = vec![false; from]; + let original_bytes = vec![42; small * HASHSIZE]; + let original_flags = vec![false; small]; - let (new_bytes, new_flags) = grow_merkle_cache( - &old_bytes, - &old_flags, - (from + 1).trailing_zeros() as usize - 1, - (to + 1).trailing_zeros() as usize - 1, + let (grown_bytes, grown_flags) = grow_merkle_cache( + &original_bytes, + &original_flags, + (small + 1).trailing_zeros() as usize - 1, + (big + 1).trailing_zeros() as usize - 1, ) .unwrap(); @@ -144,23 +190,35 @@ mod test { expected_flags.push(true); expected_flags.push(true); - assert_eq!(expected_bytes, new_bytes); - assert_eq!(expected_flags, new_flags); + assert_eq!(expected_bytes, grown_bytes); + assert_eq!(expected_flags, grown_flags); + + let (shrunk_bytes, shrunk_flags) = shrink_merkle_cache( + &grown_bytes, + &grown_flags, + (big + 1).trailing_zeros() as usize - 1, + (small + 1).trailing_zeros() as usize - 1, + small, + ) + .unwrap(); + + assert_eq!(original_bytes, shrunk_bytes); + assert_eq!(original_flags, shrunk_flags); } #[test] - fn can_grow_one_level() { - let from: usize = 7; - let to: usize = 15; + fn can_grow_and_shrink_one_level() { + let small: usize = 7; + let big: usize = 15; - let old_bytes = vec![42; from * HASHSIZE]; - let old_flags = vec![false; from]; + let original_bytes = vec![42; small * HASHSIZE]; + let original_flags = vec![false; small]; - let (new_bytes, new_flags) = grow_merkle_cache( - &old_bytes, - &old_flags, - (from + 1).trailing_zeros() as usize - 1, - (to + 1).trailing_zeros() as usize - 1, + let (grown_bytes, grown_flags) = grow_merkle_cache( + &original_bytes, + &original_flags, + (small + 1).trailing_zeros() as usize - 1, + (big + 1).trailing_zeros() as usize - 1, ) .unwrap(); @@ -201,7 +259,19 @@ mod test { expected_flags.push(true); expected_flags.push(true); - assert_eq!(expected_bytes, new_bytes); - assert_eq!(expected_flags, new_flags); + assert_eq!(expected_bytes, grown_bytes); + assert_eq!(expected_flags, grown_flags); + + let (shrunk_bytes, shrunk_flags) = shrink_merkle_cache( + &grown_bytes, + &grown_flags, + (big + 1).trailing_zeros() as usize - 1, + (small + 1).trailing_zeros() as usize - 1, + small, + ) + .unwrap(); + + assert_eq!(original_bytes, shrunk_bytes); + assert_eq!(original_flags, shrunk_flags); } }