diff --git a/beacon_node/operation_pool/src/lib.rs b/beacon_node/operation_pool/src/lib.rs index f35f12f2c..6609c5ac7 100644 --- a/beacon_node/operation_pool/src/lib.rs +++ b/beacon_node/operation_pool/src/lib.rs @@ -28,8 +28,8 @@ use std::ptr; use types::{ sync_aggregate::Error as SyncAggregateError, typenum::Unsigned, Attestation, AttesterSlashing, BeaconState, BeaconStateError, ChainSpec, Epoch, EthSpec, Fork, ForkVersion, Hash256, - ProposerSlashing, RelativeEpoch, SignedVoluntaryExit, Slot, SyncAggregate, - SyncCommitteeContribution, Validator, + ProposerSlashing, SignedVoluntaryExit, Slot, SyncAggregate, SyncCommitteeContribution, + Validator, }; type SyncContributions = RwLock>>>; @@ -259,11 +259,8 @@ impl OperationPool { let prev_epoch = state.previous_epoch(); let current_epoch = state.current_epoch(); let all_attestations = self.attestations.read(); - let active_indices = state - .get_cached_active_validator_indices(RelativeEpoch::Current) - .map_err(OpPoolError::GetAttestationsTotalBalanceError)?; let total_active_balance = state - .get_total_balance(active_indices, spec) + .get_total_active_balance() .map_err(OpPoolError::GetAttestationsTotalBalanceError)?; // Split attestations for the previous & current epochs, so that we @@ -1143,10 +1140,7 @@ mod release_tests { .expect("should have valid best attestations"); assert_eq!(best_attestations.len(), max_attestations); - let active_indices = state - .get_cached_active_validator_indices(RelativeEpoch::Current) - .unwrap(); - let total_active_balance = state.get_total_balance(active_indices, spec).unwrap(); + let total_active_balance = state.get_total_active_balance().unwrap(); // Set of indices covered by previous attestations in `best_attestations`. let mut seen_indices = BTreeSet::new(); diff --git a/beacon_node/store/src/partial_beacon_state.rs b/beacon_node/store/src/partial_beacon_state.rs index cf3863c93..8cee85e90 100644 --- a/beacon_node/store/src/partial_beacon_state.rs +++ b/beacon_node/store/src/partial_beacon_state.rs @@ -300,6 +300,7 @@ macro_rules! impl_try_into_beacon_state { finalized_checkpoint: $inner.finalized_checkpoint, // Caching + total_active_balance: <_>::default(), committee_caches: <_>::default(), pubkey_cache: <_>::default(), exit_cache: <_>::default(), diff --git a/consensus/state_processing/src/per_block_processing/altair/sync_committee.rs b/consensus/state_processing/src/per_block_processing/altair/sync_committee.rs index ac1e247e3..31386a8fb 100644 --- a/consensus/state_processing/src/per_block_processing/altair/sync_committee.rs +++ b/consensus/state_processing/src/per_block_processing/altair/sync_committee.rs @@ -42,7 +42,7 @@ pub fn process_sync_aggregate( } // Compute participant and proposer rewards - let total_active_balance = state.get_total_active_balance(spec)?; + let total_active_balance = state.get_total_active_balance()?; let total_active_increments = total_active_balance.safe_div(spec.effective_balance_increment)?; let total_base_rewards = get_base_reward_per_increment(total_active_balance, spec)? diff --git a/consensus/state_processing/src/per_block_processing/process_operations.rs b/consensus/state_processing/src/per_block_processing/process_operations.rs index 8ccfd0b26..f2cef47d6 100644 --- a/consensus/state_processing/src/per_block_processing/process_operations.rs +++ b/consensus/state_processing/src/per_block_processing/process_operations.rs @@ -127,7 +127,7 @@ pub mod altair { get_attestation_participation_flag_indices(state, data, inclusion_delay, spec)?; // Update epoch participation flags. - let total_active_balance = state.get_total_active_balance(spec)?; + let total_active_balance = state.get_total_active_balance()?; let mut proposer_reward_numerator = 0; for index in &indexed_attestation.attesting_indices { let index = *index as usize; diff --git a/consensus/state_processing/src/per_epoch_processing/altair.rs b/consensus/state_processing/src/per_epoch_processing/altair.rs index a915db63c..3acece267 100644 --- a/consensus/state_processing/src/per_epoch_processing/altair.rs +++ b/consensus/state_processing/src/per_epoch_processing/altair.rs @@ -72,7 +72,7 @@ pub fn process_epoch( process_sync_committee_updates(state, spec)?; // Rotate the epoch caches to suit the epoch transition. - state.advance_caches()?; + state.advance_caches(spec)?; Ok(EpochProcessingSummary::Altair { participation_cache, diff --git a/consensus/state_processing/src/per_epoch_processing/base.rs b/consensus/state_processing/src/per_epoch_processing/base.rs index 43d96bb93..40eff3b40 100644 --- a/consensus/state_processing/src/per_epoch_processing/base.rs +++ b/consensus/state_processing/src/per_epoch_processing/base.rs @@ -66,7 +66,7 @@ pub fn process_epoch( process_participation_record_updates(state)?; // Rotate the epoch caches to suit the epoch transition. - state.advance_caches()?; + state.advance_caches(spec)?; Ok(EpochProcessingSummary::Base { total_balances: validator_statuses.total_balances, diff --git a/consensus/state_processing/src/upgrade/altair.rs b/consensus/state_processing/src/upgrade/altair.rs index 476279998..5e4fcbcf5 100644 --- a/consensus/state_processing/src/upgrade/altair.rs +++ b/consensus/state_processing/src/upgrade/altair.rs @@ -100,6 +100,7 @@ pub fn upgrade_to_altair( current_sync_committee: temp_sync_committee.clone(), // not read next_sync_committee: temp_sync_committee, // not read // Caches + total_active_balance: pre.total_active_balance, committee_caches: mem::take(&mut pre.committee_caches), pubkey_cache: mem::take(&mut pre.pubkey_cache), exit_cache: mem::take(&mut pre.exit_cache), diff --git a/consensus/types/src/beacon_state.rs b/consensus/types/src/beacon_state.rs index a78b6130c..88d088c93 100644 --- a/consensus/types/src/beacon_state.rs +++ b/consensus/types/src/beacon_state.rs @@ -86,6 +86,11 @@ pub enum Error { }, PreviousCommitteeCacheUninitialized, CurrentCommitteeCacheUninitialized, + TotalActiveBalanceCacheUninitialized, + TotalActiveBalanceCacheInconsistent { + initialized_epoch: Epoch, + current_epoch: Epoch, + }, RelativeEpochError(RelativeEpochError), ExitCacheUninitialized, CommitteeCacheUninitialized(Option), @@ -275,6 +280,13 @@ where #[tree_hash(skip_hashing)] #[test_random(default)] #[derivative(Clone(clone_with = "clone_default"))] + pub total_active_balance: Option<(Epoch, u64)>, + #[serde(skip_serializing, skip_deserializing)] + #[ssz(skip_serializing)] + #[ssz(skip_deserializing)] + #[tree_hash(skip_hashing)] + #[test_random(default)] + #[derivative(Clone(clone_with = "clone_default"))] pub committee_caches: [CommitteeCache; CACHED_EPOCHS], #[serde(skip_serializing, skip_deserializing)] #[ssz(skip_serializing)] @@ -353,6 +365,7 @@ impl BeaconState { finalized_checkpoint: Checkpoint::default(), // Caching (not in spec) + total_active_balance: None, committee_caches: [ CommitteeCache::default(), CommitteeCache::default(), @@ -1226,12 +1239,45 @@ impl BeaconState { } /// Implementation of `get_total_active_balance`, matching the spec. - pub fn get_total_active_balance(&self, spec: &ChainSpec) -> Result { + /// + /// Requires the total active balance cache to be initialised, which is initialised whenever + /// the current committee cache is. + /// + /// Returns minimum `EFFECTIVE_BALANCE_INCREMENT`, to avoid div by 0. + pub fn get_total_active_balance(&self) -> Result { + let (initialized_epoch, balance) = self + .total_active_balance() + .ok_or(Error::TotalActiveBalanceCacheUninitialized)?; + + let current_epoch = self.current_epoch(); + if initialized_epoch == current_epoch { + Ok(balance) + } else { + Err(Error::TotalActiveBalanceCacheInconsistent { + initialized_epoch, + current_epoch, + }) + } + } + + /// Build the total active balance cache. + /// + /// This function requires the current committee cache to be already built. It is called + /// automatically when `build_committee_cache` is called for the current epoch. + fn build_total_active_balance_cache(&mut self, spec: &ChainSpec) -> Result<(), Error> { // Order is irrelevant, so use the cached indices. - self.get_total_balance( + let current_epoch = self.current_epoch(); + let total_active_balance = self.get_total_balance( self.get_cached_active_validator_indices(RelativeEpoch::Current)?, spec, - ) + )?; + *self.total_active_balance_mut() = Some((current_epoch, total_active_balance)); + Ok(()) + } + + /// Set the cached total active balance to `None`, representing no known value. + pub fn drop_total_active_balance_cache(&mut self) { + *self.total_active_balance_mut() = None; } /// Get a mutable reference to the epoch participation flags for `epoch`. @@ -1294,6 +1340,7 @@ impl BeaconState { /// Drop all caches on the state. pub fn drop_all_caches(&mut self) -> Result<(), Error> { + self.drop_total_active_balance_cache(); self.drop_committee_cache(RelativeEpoch::Previous)?; self.drop_committee_cache(RelativeEpoch::Current)?; self.drop_committee_cache(RelativeEpoch::Next)?; @@ -1323,11 +1370,14 @@ impl BeaconState { .committee_cache_at_index(i)? .is_initialized_at(relative_epoch.into_epoch(self.current_epoch())); - if is_initialized { - Ok(()) - } else { - self.force_build_committee_cache(relative_epoch, spec) + if !is_initialized { + self.force_build_committee_cache(relative_epoch, spec)?; } + + if self.total_active_balance().is_none() && relative_epoch == RelativeEpoch::Current { + self.build_total_active_balance_cache(spec)?; + } + Ok(()) } /// Always builds the previous epoch cache, even if it is already initialized. @@ -1359,10 +1409,36 @@ impl BeaconState { /// /// This should be used if the `slot` of this state is advanced beyond an epoch boundary. /// - /// Note: whilst this function will preserve already-built caches, it will not build any. - pub fn advance_caches(&mut self) -> Result<(), Error> { + /// Note: this function will not build any new committee caches, but will build the total + /// balance cache if the (new) current epoch cache is initialized. + pub fn advance_caches(&mut self, spec: &ChainSpec) -> Result<(), Error> { self.committee_caches_mut().rotate_left(1); + // Re-compute total active balance for current epoch. + // + // This can only be computed once the state's effective balances have been updated + // for the current epoch. I.e. it is not possible to know this value with the same + // lookahead as the committee shuffling. + let curr = Self::committee_cache_index(RelativeEpoch::Current); + let curr_cache = mem::take(self.committee_cache_at_index_mut(curr)?); + + // If current epoch cache is initialized, compute the total active balance from its + // indices. We check that the cache is initialized at the _next_ epoch because the slot has + // not yet been advanced. + let new_current_epoch = self.next_epoch()?; + if curr_cache.is_initialized_at(new_current_epoch) { + *self.total_active_balance_mut() = Some(( + new_current_epoch, + self.get_total_balance(curr_cache.active_validator_indices(), spec)?, + )); + } + // If the cache is not initialized, then the previous cached value for the total balance is + // wrong, so delete it. + else { + self.drop_total_active_balance_cache(); + } + *self.committee_cache_at_index_mut(curr)? = curr_cache; + let next = Self::committee_cache_index(RelativeEpoch::Next); *self.committee_cache_at_index_mut(next)? = CommitteeCache::default(); Ok(()) @@ -1504,6 +1580,7 @@ impl BeaconState { }; if config.committee_caches { *res.committee_caches_mut() = self.committee_caches().clone(); + *res.total_active_balance_mut() = *self.total_active_balance(); } if config.pubkey_cache { *res.pubkey_cache_mut() = self.pubkey_cache().clone(); diff --git a/consensus/types/src/beacon_state/tests.rs b/consensus/types/src/beacon_state/tests.rs index 790113768..9e069a8ca 100644 --- a/consensus/types/src/beacon_state/tests.rs +++ b/consensus/types/src/beacon_state/tests.rs @@ -165,6 +165,9 @@ fn test_clone_config(base_state: &BeaconState, clone_config: Clon state .committee_cache(RelativeEpoch::Next) .expect("committee cache exists"); + state + .total_active_balance() + .expect("total active balance exists"); } else { state .committee_cache(RelativeEpoch::Previous) diff --git a/testing/ef_tests/src/cases/rewards.rs b/testing/ef_tests/src/cases/rewards.rs index 0cdde4c32..03444ae76 100644 --- a/testing/ef_tests/src/cases/rewards.rs +++ b/testing/ef_tests/src/cases/rewards.rs @@ -123,7 +123,7 @@ impl Case for RewardsTest { Ok(convert_all_base_deltas(&deltas)) } else { - let total_active_balance = state.get_total_active_balance(spec)?; + let total_active_balance = state.get_total_active_balance()?; let source_deltas = compute_altair_flag_deltas( &state,