diff --git a/consensus/types/src/beacon_state/tree_hash_cache.rs b/consensus/types/src/beacon_state/tree_hash_cache.rs index 863970c27..fc14e9b18 100644 --- a/consensus/types/src/beacon_state/tree_hash_cache.rs +++ b/consensus/types/src/beacon_state/tree_hash_cache.rs @@ -3,9 +3,7 @@ #![allow(clippy::indexing_slicing)] use super::Error; -use crate::{ - BeaconState, EthSpec, Hash256, ParticipationFlags, ParticipationList, Slot, Unsigned, Validator, -}; +use crate::{BeaconState, EthSpec, Hash256, ParticipationList, Slot, Unsigned, Validator}; use cached_tree_hash::{int_log, CacheArena, CachedTreeHash, TreeHashCache}; use rayon::prelude::*; use ssz_derive::{Decode, Encode}; @@ -141,9 +139,10 @@ pub struct BeaconTreeHashCacheInner { randao_mixes: TreeHashCache, slashings: TreeHashCache, eth1_data_votes: Eth1DataVotesTreeHashCache, + inactivity_scores: OptionalTreeHashCache, // Participation caches - previous_epoch_participation: ParticipationTreeHashCache, - current_epoch_participation: ParticipationTreeHashCache, + previous_epoch_participation: OptionalTreeHashCache, + current_epoch_participation: OptionalTreeHashCache, } impl BeaconTreeHashCacheInner { @@ -168,10 +167,22 @@ impl BeaconTreeHashCacheInner { let mut slashings_arena = CacheArena::default(); let slashings = state.slashings().new_tree_hash_cache(&mut slashings_arena); - let previous_epoch_participation = - ParticipationTreeHashCache::new(state, BeaconState::previous_epoch_participation); - let current_epoch_participation = - ParticipationTreeHashCache::new(state, BeaconState::current_epoch_participation); + let inactivity_scores = OptionalTreeHashCache::new(state.inactivity_scores().ok()); + + let previous_epoch_participation = OptionalTreeHashCache::new( + state + .previous_epoch_participation() + .ok() + .map(ParticipationList::new) + .as_ref(), + ); + let current_epoch_participation = OptionalTreeHashCache::new( + state + .current_epoch_participation() + .ok() + .map(ParticipationList::new) + .as_ref(), + ); Self { previous_state: None, @@ -185,6 +196,7 @@ impl BeaconTreeHashCacheInner { balances, randao_mixes, slashings, + inactivity_scores, eth1_data_votes: Eth1DataVotesTreeHashCache::new(state), previous_epoch_participation, current_epoch_participation, @@ -287,12 +299,16 @@ impl BeaconTreeHashCacheInner { } else { hasher.write( self.previous_epoch_participation - .recalculate_tree_hash_root(state.previous_epoch_participation()?)? + .recalculate_tree_hash_root(&ParticipationList::new( + state.previous_epoch_participation()?, + ))? .as_bytes(), )?; hasher.write( self.current_epoch_participation - .recalculate_tree_hash_root(state.current_epoch_participation()?)? + .recalculate_tree_hash_root(&ParticipationList::new( + state.current_epoch_participation()?, + ))? .as_bytes(), )?; } @@ -314,8 +330,11 @@ impl BeaconTreeHashCacheInner { // Inactivity & light-client sync committees if let BeaconState::Altair(ref state) = state { - // FIXME(altair): add cache for this field - hasher.write(state.inactivity_scores.tree_hash_root().as_bytes())?; + hasher.write( + self.inactivity_scores + .recalculate_tree_hash_root(&state.inactivity_scores)? + .as_bytes(), + )?; hasher.write(state.current_sync_committee.tree_hash_root().as_bytes())?; hasher.write(state.next_sync_committee.tree_hash_root().as_bytes())?; @@ -513,53 +532,43 @@ impl ParallelValidatorTreeHash { } #[derive(Debug, PartialEq, Clone)] -pub struct ParticipationTreeHashCache { - inner: Option, +pub struct OptionalTreeHashCache { + inner: Option, } #[derive(Debug, PartialEq, Clone)] -pub struct ParticipationTreeHashCacheInner { +pub struct OptionalTreeHashCacheInner { arena: CacheArena, tree_hash_cache: TreeHashCache, } -impl ParticipationTreeHashCache { - /// Initialize a new cache for the participation list returned by `field` (if any). - fn new( - state: &BeaconState, - field: impl FnOnce( - &BeaconState, - ) -> Result< - &VariableList, - Error, - >, - ) -> Self { - let inner = field(state).map(ParticipationTreeHashCacheInner::new).ok(); +impl OptionalTreeHashCache { + /// Initialize a new cache if `item.is_some()`. + fn new>(item: Option<&C>) -> Self { + let inner = item.map(OptionalTreeHashCacheInner::new); Self { inner } } - /// Compute the tree hash root for the given `epoch_participation`. + /// Compute the tree hash root for the given `item`. /// /// This function will initialize the inner cache if necessary (e.g. when crossing the fork). - fn recalculate_tree_hash_root( + fn recalculate_tree_hash_root>( &mut self, - epoch_participation: &VariableList, + item: &C, ) -> Result { let cache = self .inner - .get_or_insert_with(|| ParticipationTreeHashCacheInner::new(epoch_participation)); - ParticipationList::new(epoch_participation) - .recalculate_tree_hash_root(&mut cache.arena, &mut cache.tree_hash_cache) + .get_or_insert_with(|| OptionalTreeHashCacheInner::new(item)); + item.recalculate_tree_hash_root(&mut cache.arena, &mut cache.tree_hash_cache) .map_err(Into::into) } } -impl ParticipationTreeHashCacheInner { - fn new(epoch_participation: &VariableList) -> Self { +impl OptionalTreeHashCacheInner { + fn new>(item: &C) -> Self { let mut arena = CacheArena::default(); - let tree_hash_cache = - ParticipationList::new(epoch_participation).new_tree_hash_cache(&mut arena); - ParticipationTreeHashCacheInner { + let tree_hash_cache = item.new_tree_hash_cache(&mut arena); + OptionalTreeHashCacheInner { arena, tree_hash_cache, } @@ -576,7 +585,7 @@ impl arbitrary::Arbitrary for BeaconTreeHashCache { #[cfg(test)] mod test { use super::*; - use crate::MainnetEthSpec; + use crate::{MainnetEthSpec, ParticipationFlags}; #[test] fn validator_node_count() { @@ -594,13 +603,13 @@ mod test { test_flag.add_flag(0).unwrap(); let epoch_participation = VariableList::<_, N>::new(vec![test_flag; len]).unwrap(); - let mut cache = ParticipationTreeHashCache { inner: None }; + let mut cache = OptionalTreeHashCache { inner: None }; let cache_root = cache - .recalculate_tree_hash_root(&epoch_participation) + .recalculate_tree_hash_root(&ParticipationList::new(&epoch_participation)) .unwrap(); let recalc_root = cache - .recalculate_tree_hash_root(&epoch_participation) + .recalculate_tree_hash_root(&ParticipationList::new(&epoch_participation)) .unwrap(); assert_eq!(cache_root, recalc_root, "recalculated root should match");