From 3412a3ec54c787782717e88dd15d19f9c4494c8b Mon Sep 17 00:00:00 2001 From: Michael Sproul Date: Fri, 25 Sep 2020 05:18:21 +0000 Subject: [PATCH] Remove saturating arith from state_processing (#1644) ## Issue Addressed Resolves #1100 ## Proposed Changes * Implement the `SafeArith` trait for `Slot` and `Epoch`, so that methods like `safe_add` become available. * Tweak the `SafeArith` trait to allow a different `Rhs` type (analagous to `std::ops::Add`, etc). * Add a `legacy-arith` feature to `types` and `state_processing` that conditionally enables implementations of the `std` ops with saturating semantics. * Check compilation of `types` and `state_processing` _without_ `legacy-arith` on CI, thus guaranteeing that they only use the `SafeArith` primitives :tada: ## Additional Info The `legacy-arith` feature gets turned on by all higher-level crates that depend on `state_processing` or `types`, thus allowing the beacon chain, networking, and other components to continue to rely on the availability of ops like `+`, `-`, `*`, etc. **This is a consensus-breaking change**, but brings us in line with the spec, and our incompatibilities shouldn't have been reachable with any valid configuration of Eth2 parameters. --- .github/workflows/test-suite.yml | 8 ++ Makefile | 4 + consensus/safe_arith/src/lib.rs | 30 ++-- consensus/state_processing/Cargo.toml | 4 +- .../src/common/deposit_data_tree.rs | 2 +- .../src/common/initiate_validator_exit.rs | 7 +- .../src/common/slash_validator.rs | 2 +- .../src/per_block_processing.rs | 4 +- .../verify_attestation.rs | 5 +- .../src/per_block_processing/verify_exit.rs | 8 +- .../src/per_epoch_processing.rs | 19 ++- .../src/per_epoch_processing/apply_rewards.rs | 5 +- .../per_epoch_processing/process_slashings.rs | 2 +- .../per_epoch_processing/registry_updates.rs | 8 +- .../src/per_slot_processing.rs | 17 ++- consensus/types/Cargo.toml | 4 +- consensus/types/src/beacon_state.rs | 49 ++++--- .../types/src/beacon_state/committee_cache.rs | 3 +- .../src/beacon_state/committee_cache/tests.rs | 8 +- .../types/src/beacon_state/exit_cache.rs | 2 +- consensus/types/src/relative_epoch.rs | 17 ++- consensus/types/src/slot_epoch.rs | 19 +-- consensus/types/src/slot_epoch_macros.rs | 134 ++++++++++++------ .../testing_attestation_data_builder.rs | 14 +- .../builders/testing_beacon_block_builder.rs | 10 +- .../builders/testing_beacon_state_builder.rs | 6 +- 26 files changed, 250 insertions(+), 141 deletions(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index eb422e039..6c9b19d4f 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -115,6 +115,14 @@ jobs: - uses: actions/checkout@v1 - name: Typecheck benchmark code without running it run: make check-benches + check-consensus: + name: check-consensus + runs-on: ubuntu-latest + needs: cargo-fmt + steps: + - uses: actions/checkout@v1 + - name: Typecheck consensus code in strict mode + run: make check-consensus clippy: name: clippy runs-on: ubuntu-latest diff --git a/Makefile b/Makefile index 6fadf222f..39c2afaaa 100644 --- a/Makefile +++ b/Makefile @@ -93,6 +93,10 @@ cargo-fmt: check-benches: cargo check --all --benches +# Typechecks consensus code *without* allowing deprecated legacy arithmetic +check-consensus: + cargo check --manifest-path=consensus/state_processing/Cargo.toml --no-default-features + # Runs only the ef-test vectors. run-ef-tests: cargo test --release --manifest-path=$(EF_TESTS)/Cargo.toml --features "ef_tests" diff --git a/consensus/safe_arith/src/lib.rs b/consensus/safe_arith/src/lib.rs index 227568210..ab5985a6e 100644 --- a/consensus/safe_arith/src/lib.rs +++ b/consensus/safe_arith/src/lib.rs @@ -28,24 +28,24 @@ macro_rules! assign_method { } /// Trait providing safe arithmetic operations for built-in types. -pub trait SafeArith: Sized + Copy { +pub trait SafeArith: Sized + Copy { const ZERO: Self; const ONE: Self; /// Safe variant of `+` that guards against overflow. - fn safe_add(&self, other: Self) -> Result; + fn safe_add(&self, other: Rhs) -> Result; /// Safe variant of `-` that guards against overflow. - fn safe_sub(&self, other: Self) -> Result; + fn safe_sub(&self, other: Rhs) -> Result; /// Safe variant of `*` that guards against overflow. - fn safe_mul(&self, other: Self) -> Result; + fn safe_mul(&self, other: Rhs) -> Result; /// Safe variant of `/` that guards against division by 0. - fn safe_div(&self, other: Self) -> Result; + fn safe_div(&self, other: Rhs) -> Result; /// Safe variant of `%` that guards against division by 0. - fn safe_rem(&self, other: Self) -> Result; + fn safe_rem(&self, other: Rhs) -> Result; /// Safe variant of `<<` that guards against overflow. fn safe_shl(&self, other: u32) -> Result; @@ -53,18 +53,13 @@ pub trait SafeArith: Sized + Copy { /// Safe variant of `>>` that guards against overflow. fn safe_shr(&self, other: u32) -> Result; - assign_method!(safe_add_assign, safe_add, "+="); - assign_method!(safe_sub_assign, safe_sub, "-="); - assign_method!(safe_mul_assign, safe_mul, "*="); - assign_method!(safe_div_assign, safe_div, "/="); - assign_method!(safe_rem_assign, safe_rem, "%="); + assign_method!(safe_add_assign, safe_add, Rhs, "+="); + assign_method!(safe_sub_assign, safe_sub, Rhs, "-="); + assign_method!(safe_mul_assign, safe_mul, Rhs, "*="); + assign_method!(safe_div_assign, safe_div, Rhs, "/="); + assign_method!(safe_rem_assign, safe_rem, Rhs, "%="); assign_method!(safe_shl_assign, safe_shl, u32, "<<="); assign_method!(safe_shr_assign, safe_shr, u32, ">>="); - - /// Mutate `self` by adding 1, erroring on overflow. - fn increment(&mut self) -> Result<()> { - self.safe_add_assign(Self::ONE) - } } macro_rules! impl_safe_arith { @@ -136,8 +131,7 @@ mod test { #[test] fn mutate() { let mut x = 0u8; - x.increment().unwrap(); - x.increment().unwrap(); + x.safe_add_assign(2).unwrap(); assert_eq!(x, 2); x.safe_sub_assign(1).unwrap(); assert_eq!(x, 1); diff --git a/consensus/state_processing/Cargo.toml b/consensus/state_processing/Cargo.toml index bae0705c8..bd0de6c19 100644 --- a/consensus/state_processing/Cargo.toml +++ b/consensus/state_processing/Cargo.toml @@ -27,14 +27,16 @@ log = "0.4.8" safe_arith = { path = "../safe_arith" } tree_hash = "0.1.0" tree_hash_derive = "0.2.0" -types = { path = "../types" } +types = { path = "../types", default-features = false } rayon = "1.3.0" eth2_hashing = "0.1.0" int_to_bytes = { path = "../int_to_bytes" } arbitrary = { version = "0.4.4", features = ["derive"], optional = true } [features] +default = ["legacy-arith"] fake_crypto = ["bls/fake_crypto"] +legacy-arith = ["types/legacy-arith"] arbitrary-fuzz = [ "arbitrary", "types/arbitrary-fuzz", diff --git a/consensus/state_processing/src/common/deposit_data_tree.rs b/consensus/state_processing/src/common/deposit_data_tree.rs index 319c437ee..46f1ed8cc 100644 --- a/consensus/state_processing/src/common/deposit_data_tree.rs +++ b/consensus/state_processing/src/common/deposit_data_tree.rs @@ -47,7 +47,7 @@ impl DepositDataTree { /// Add a deposit to the merkle tree. pub fn push_leaf(&mut self, leaf: Hash256) -> Result<(), MerkleTreeError> { self.tree.push_leaf(leaf, self.depth)?; - self.mix_in_length.increment()?; + self.mix_in_length.safe_add_assign(1)?; Ok(()) } } diff --git a/consensus/state_processing/src/common/initiate_validator_exit.rs b/consensus/state_processing/src/common/initiate_validator_exit.rs index 00cd02de1..3d2638a35 100644 --- a/consensus/state_processing/src/common/initiate_validator_exit.rs +++ b/consensus/state_processing/src/common/initiate_validator_exit.rs @@ -1,3 +1,4 @@ +use safe_arith::SafeArith; use std::cmp::max; use types::{BeaconStateError as Error, *}; @@ -22,7 +23,7 @@ pub fn initiate_validator_exit( state.exit_cache.build(&state.validators, spec)?; // Compute exit queue epoch - let delayed_epoch = state.compute_activation_exit_epoch(state.current_epoch(), spec); + let delayed_epoch = state.compute_activation_exit_epoch(state.current_epoch(), spec)?; let mut exit_queue_epoch = state .exit_cache .max_epoch()? @@ -30,13 +31,13 @@ pub fn initiate_validator_exit( let exit_queue_churn = state.exit_cache.get_churn_at(exit_queue_epoch)?; if exit_queue_churn >= state.get_churn_limit(spec)? { - exit_queue_epoch += 1; + exit_queue_epoch.safe_add_assign(1)?; } state.exit_cache.record_validator_exit(exit_queue_epoch)?; state.validators[index].exit_epoch = exit_queue_epoch; state.validators[index].withdrawable_epoch = - exit_queue_epoch + spec.min_validator_withdrawability_delay; + exit_queue_epoch.safe_add(spec.min_validator_withdrawability_delay)?; Ok(()) } diff --git a/consensus/state_processing/src/common/slash_validator.rs b/consensus/state_processing/src/common/slash_validator.rs index 754534f0c..0b0874819 100644 --- a/consensus/state_processing/src/common/slash_validator.rs +++ b/consensus/state_processing/src/common/slash_validator.rs @@ -23,7 +23,7 @@ pub fn slash_validator( state.validators[slashed_index].slashed = true; state.validators[slashed_index].withdrawable_epoch = cmp::max( state.validators[slashed_index].withdrawable_epoch, - epoch + Epoch::from(T::EpochsPerSlashingsVector::to_u64()), + epoch.safe_add(T::EpochsPerSlashingsVector::to_u64())?, ); let validator_effective_balance = state.get_effective_balance(slashed_index, spec)?; state.set_slashings( diff --git a/consensus/state_processing/src/per_block_processing.rs b/consensus/state_processing/src/per_block_processing.rs index 9bedba23e..e6c57beb2 100644 --- a/consensus/state_processing/src/per_block_processing.rs +++ b/consensus/state_processing/src/per_block_processing.rs @@ -368,7 +368,7 @@ pub fn process_attestations( let pending_attestation = PendingAttestation { aggregation_bits: attestation.aggregation_bits.clone(), data: attestation.data.clone(), - inclusion_delay: (state.slot - attestation.data.slot).as_u64(), + inclusion_delay: state.slot.safe_sub(attestation.data.slot)?.as_u64(), proposer_index, }; @@ -444,7 +444,7 @@ pub fn process_deposit( .map_err(|e| e.into_with_index(deposit_index))?; } - state.eth1_deposit_index.increment()?; + state.eth1_deposit_index.safe_add_assign(1)?; // Get an `Option` where `u64` is the validator index if this deposit public key // already exists in the beacon_state. diff --git a/consensus/state_processing/src/per_block_processing/verify_attestation.rs b/consensus/state_processing/src/per_block_processing/verify_attestation.rs index 3ab962e2e..678ba28e1 100644 --- a/consensus/state_processing/src/per_block_processing/verify_attestation.rs +++ b/consensus/state_processing/src/per_block_processing/verify_attestation.rs @@ -2,6 +2,7 @@ use super::errors::{AttestationInvalid as Invalid, BlockOperationError}; use super::VerifySignatures; use crate::common::get_indexed_attestation; use crate::per_block_processing::is_valid_indexed_attestation; +use safe_arith::SafeArith; use types::*; type Result = std::result::Result>; @@ -25,7 +26,7 @@ pub fn verify_attestation_for_block_inclusion( let data = &attestation.data; verify!( - data.slot + spec.min_attestation_inclusion_delay <= state.slot, + data.slot.safe_add(spec.min_attestation_inclusion_delay)? <= state.slot, Invalid::IncludedTooEarly { state: state.slot, delay: spec.min_attestation_inclusion_delay, @@ -33,7 +34,7 @@ pub fn verify_attestation_for_block_inclusion( } ); verify!( - state.slot <= data.slot + T::slots_per_epoch(), + state.slot <= data.slot.safe_add(T::slots_per_epoch())?, Invalid::IncludedTooLate { state: state.slot, attestation: data.slot, diff --git a/consensus/state_processing/src/per_block_processing/verify_exit.rs b/consensus/state_processing/src/per_block_processing/verify_exit.rs index c77ffe536..16c4db221 100644 --- a/consensus/state_processing/src/per_block_processing/verify_exit.rs +++ b/consensus/state_processing/src/per_block_processing/verify_exit.rs @@ -3,6 +3,7 @@ use crate::per_block_processing::{ signature_sets::{exit_signature_set, get_pubkey_from_state}, VerifySignatures, }; +use safe_arith::SafeArith; use types::*; type Result = std::result::Result>; @@ -77,11 +78,14 @@ fn verify_exit_parametric( ); // Verify the validator has been active long enough. + let earliest_exit_epoch = validator + .activation_epoch + .safe_add(spec.shard_committee_period)?; verify!( - state.current_epoch() >= validator.activation_epoch + spec.shard_committee_period, + state.current_epoch() >= earliest_exit_epoch, ExitInvalid::TooYoungToExit { current_epoch: state.current_epoch(), - earliest_exit_epoch: validator.activation_epoch + spec.shard_committee_period, + earliest_exit_epoch, } ); diff --git a/consensus/state_processing/src/per_epoch_processing.rs b/consensus/state_processing/src/per_epoch_processing.rs index 0321bce35..19b87aa57 100644 --- a/consensus/state_processing/src/per_epoch_processing.rs +++ b/consensus/state_processing/src/per_epoch_processing.rs @@ -84,7 +84,7 @@ pub fn process_justification_and_finalization( state: &mut BeaconState, total_balances: &TotalBalances, ) -> Result<(), Error> { - if state.current_epoch() <= T::genesis_epoch() + 1 { + if state.current_epoch() <= T::genesis_epoch().safe_add(1)? { return Ok(()); } @@ -126,25 +126,25 @@ pub fn process_justification_and_finalization( // The 2nd/3rd/4th most recent epochs are all justified, the 2nd using the 4th as source. if (1..4).all(|i| bits.get(i).unwrap_or(false)) - && old_previous_justified_checkpoint.epoch + 3 == current_epoch + && old_previous_justified_checkpoint.epoch.safe_add(3)? == current_epoch { state.finalized_checkpoint = old_previous_justified_checkpoint; } // The 2nd/3rd most recent epochs are both justified, the 2nd using the 3rd as source. else if (1..3).all(|i| bits.get(i).unwrap_or(false)) - && old_previous_justified_checkpoint.epoch + 2 == current_epoch + && old_previous_justified_checkpoint.epoch.safe_add(2)? == current_epoch { state.finalized_checkpoint = old_previous_justified_checkpoint; } // The 1st/2nd/3rd most recent epochs are all justified, the 1st using the 3nd as source. if (0..3).all(|i| bits.get(i).unwrap_or(false)) - && old_current_justified_checkpoint.epoch + 2 == current_epoch + && old_current_justified_checkpoint.epoch.safe_add(2)? == current_epoch { state.finalized_checkpoint = old_current_justified_checkpoint; } // The 1st/2nd most recent epochs are both justified, the 1st using the 2nd as source. else if (0..2).all(|i| bits.get(i).unwrap_or(false)) - && old_current_justified_checkpoint.epoch + 1 == current_epoch + && old_current_justified_checkpoint.epoch.safe_add(1)? == current_epoch { state.finalized_checkpoint = old_current_justified_checkpoint; } @@ -160,10 +160,15 @@ pub fn process_final_updates( spec: &ChainSpec, ) -> Result<(), Error> { let current_epoch = state.current_epoch(); - let next_epoch = state.next_epoch(); + let next_epoch = state.next_epoch()?; // Reset eth1 data votes. - if (state.slot + 1) % T::SlotsPerEth1VotingPeriod::to_u64() == 0 { + if state + .slot + .safe_add(1)? + .safe_rem(T::SlotsPerEth1VotingPeriod::to_u64())? + == 0 + { state.eth1_data_votes = VariableList::empty(); } diff --git a/consensus/state_processing/src/per_epoch_processing/apply_rewards.rs b/consensus/state_processing/src/per_epoch_processing/apply_rewards.rs index 18c946520..4115bfef3 100644 --- a/consensus/state_processing/src/per_epoch_processing/apply_rewards.rs +++ b/consensus/state_processing/src/per_epoch_processing/apply_rewards.rs @@ -71,7 +71,10 @@ fn get_attestation_deltas( validator_statuses: &ValidatorStatuses, spec: &ChainSpec, ) -> Result, Error> { - let finality_delay = (state.previous_epoch() - state.finalized_checkpoint.epoch).as_u64(); + let finality_delay = state + .previous_epoch() + .safe_sub(state.finalized_checkpoint.epoch)? + .as_u64(); let mut deltas = vec![Delta::default(); state.validators.len()]; diff --git a/consensus/state_processing/src/per_epoch_processing/process_slashings.rs b/consensus/state_processing/src/per_epoch_processing/process_slashings.rs index 4901d3030..4a8f8120d 100644 --- a/consensus/state_processing/src/per_epoch_processing/process_slashings.rs +++ b/consensus/state_processing/src/per_epoch_processing/process_slashings.rs @@ -14,7 +14,7 @@ pub fn process_slashings( for (index, validator) in state.validators.iter().enumerate() { if validator.slashed - && epoch + T::EpochsPerSlashingsVector::to_u64().safe_div(2)? + && epoch.safe_add(T::EpochsPerSlashingsVector::to_u64().safe_div(2)?)? == validator.withdrawable_epoch { let increment = spec.effective_balance_increment; diff --git a/consensus/state_processing/src/per_epoch_processing/registry_updates.rs b/consensus/state_processing/src/per_epoch_processing/registry_updates.rs index 79ece8c60..26f055ba4 100644 --- a/consensus/state_processing/src/per_epoch_processing/registry_updates.rs +++ b/consensus/state_processing/src/per_epoch_processing/registry_updates.rs @@ -1,6 +1,6 @@ -use super::super::common::initiate_validator_exit; -use super::Error; +use crate::{common::initiate_validator_exit, per_epoch_processing::Error}; use itertools::Itertools; +use safe_arith::SafeArith; use types::*; /// Performs a validator registry update, if required. @@ -31,7 +31,7 @@ pub fn process_registry_updates( for index in indices_to_update { if state.validators[index].is_eligible_for_activation_queue(spec) { - state.validators[index].activation_eligibility_epoch = current_epoch + 1; + state.validators[index].activation_eligibility_epoch = current_epoch.safe_add(1)?; } if is_ejectable(&state.validators[index]) { initiate_validator_exit(state, index, spec)?; @@ -50,7 +50,7 @@ pub fn process_registry_updates( // Dequeue validators for activation up to churn limit let churn_limit = state.get_churn_limit(spec)? as usize; - let delayed_activation_epoch = state.compute_activation_exit_epoch(current_epoch, spec); + let delayed_activation_epoch = state.compute_activation_exit_epoch(current_epoch, spec)?; for index in activation_queue.into_iter().take(churn_limit) { let validator = &mut state.validators[index]; validator.activation_epoch = delayed_activation_epoch; diff --git a/consensus/state_processing/src/per_slot_processing.rs b/consensus/state_processing/src/per_slot_processing.rs index 02acfc825..a818bde52 100644 --- a/consensus/state_processing/src/per_slot_processing.rs +++ b/consensus/state_processing/src/per_slot_processing.rs @@ -1,10 +1,18 @@ use crate::{per_epoch_processing::EpochProcessingSummary, *}; +use safe_arith::{ArithError, SafeArith}; use types::*; #[derive(Debug, PartialEq)] pub enum Error { BeaconStateError(BeaconStateError), EpochProcessingError(EpochProcessingError), + ArithError(ArithError), +} + +impl From for Error { + fn from(e: ArithError) -> Self { + Self::ArithError(e) + } } /// Advances a state forward by one slot, performing per-epoch processing if required. @@ -21,14 +29,15 @@ pub fn per_slot_processing( ) -> Result, Error> { cache_state(state, state_root)?; - let summary = if state.slot > spec.genesis_slot && (state.slot + 1) % T::slots_per_epoch() == 0 + let summary = if state.slot > spec.genesis_slot + && state.slot.safe_add(1)?.safe_rem(T::slots_per_epoch())? == 0 { Some(per_epoch_processing(state, spec)?) } else { None }; - state.slot += 1; + state.slot.safe_add_assign(1)?; Ok(summary) } @@ -48,7 +57,7 @@ fn cache_state( // // This is a bit hacky, however it gets the job safely without lots of code. let previous_slot = state.slot; - state.slot += 1; + state.slot.safe_add_assign(1)?; // Store the previous slot's post state transition root. state.set_state_root(previous_slot, previous_state_root)?; @@ -63,7 +72,7 @@ fn cache_state( state.set_block_root(previous_slot, latest_block_root)?; // Set the state slot back to what it should be. - state.slot -= 1; + state.slot.safe_sub_assign(1)?; Ok(()) } diff --git a/consensus/types/Cargo.toml b/consensus/types/Cargo.toml index d893ff3ad..8f6fed4b4 100644 --- a/consensus/types/Cargo.toml +++ b/consensus/types/Cargo.toml @@ -46,7 +46,9 @@ serde_json = "1.0.52" criterion = "0.3.2" [features] -default = ["sqlite"] +default = ["sqlite", "legacy-arith"] +# Allow saturating arithmetic on slots and epochs. Enabled by default, but deprecated. +legacy-arith = [] sqlite = ["rusqlite"] arbitrary-fuzz = [ "arbitrary", diff --git a/consensus/types/src/beacon_state.rs b/consensus/types/src/beacon_state.rs index a335d0492..a2d923da9 100644 --- a/consensus/types/src/beacon_state.rs +++ b/consensus/types/src/beacon_state.rs @@ -100,10 +100,10 @@ enum AllowNextEpoch { } impl AllowNextEpoch { - fn upper_bound_of(self, current_epoch: Epoch) -> Epoch { + fn upper_bound_of(self, current_epoch: Epoch) -> Result { match self { - AllowNextEpoch::True => current_epoch + 1, - AllowNextEpoch::False => current_epoch, + AllowNextEpoch::True => Ok(current_epoch.safe_add(1)?), + AllowNextEpoch::False => Ok(current_epoch), } } } @@ -323,7 +323,9 @@ impl BeaconState { pub fn previous_epoch(&self) -> Epoch { let current_epoch = self.current_epoch(); if current_epoch > T::genesis_epoch() { - current_epoch - 1 + current_epoch + .safe_sub(1) + .expect("current epoch greater than genesis implies greater than 0") } else { current_epoch } @@ -332,8 +334,8 @@ impl BeaconState { /// The epoch following `self.current_epoch()`. /// /// Spec v0.12.1 - pub fn next_epoch(&self) -> Epoch { - self.current_epoch() + 1 + pub fn next_epoch(&self) -> Result { + Ok(self.current_epoch().safe_add(1)?) } /// Compute the number of committees at `slot`. @@ -378,7 +380,7 @@ impl BeaconState { epoch: Epoch, spec: &ChainSpec, ) -> Result, Error> { - if epoch >= self.compute_activation_exit_epoch(self.current_epoch(), spec) { + if epoch >= self.compute_activation_exit_epoch(self.current_epoch(), spec)? { Err(BeaconStateError::EpochOutOfBounds) } else { Ok(get_active_validator_indices(&self.validators, epoch)) @@ -475,7 +477,7 @@ impl BeaconState { { return Ok(candidate_index); } - i.increment()?; + i.safe_add_assign(1)?; } } @@ -553,7 +555,7 @@ impl BeaconState { /// /// Spec v0.12.1 fn get_latest_block_roots_index(&self, slot: Slot) -> Result { - if slot < self.slot && self.slot <= slot + self.block_roots.len() as u64 { + if slot < self.slot && self.slot <= slot.safe_add(self.block_roots.len() as u64)? { Ok(slot.as_usize().safe_rem(self.block_roots.len())?) } else { Err(BeaconStateError::SlotOutOfBounds) @@ -605,7 +607,9 @@ impl BeaconState { let current_epoch = self.current_epoch(); let len = T::EpochsPerHistoricalVector::to_u64(); - if current_epoch < epoch + len && epoch <= allow_next_epoch.upper_bound_of(current_epoch) { + if current_epoch < epoch.safe_add(len)? + && epoch <= allow_next_epoch.upper_bound_of(current_epoch)? + { Ok(epoch.as_usize().safe_rem(len as usize)?) } else { Err(Error::EpochOutOfBounds) @@ -652,7 +656,7 @@ impl BeaconState { /// /// Spec v0.12.1 fn get_latest_state_roots_index(&self, slot: Slot) -> Result { - if slot < self.slot && self.slot <= slot + self.state_roots.len() as u64 { + if slot < self.slot && self.slot <= slot.safe_add(self.state_roots.len() as u64)? { Ok(slot.as_usize().safe_rem(self.state_roots.len())?) } else { Err(BeaconStateError::SlotOutOfBounds) @@ -672,7 +676,7 @@ impl BeaconState { /// Spec v0.12.1 pub fn get_oldest_state_root(&self) -> Result<&Hash256, Error> { let i = - self.get_latest_state_roots_index(self.slot - Slot::from(self.state_roots.len()))?; + self.get_latest_state_roots_index(self.slot.saturating_sub(self.state_roots.len()))?; Ok(&self.state_roots[i]) } @@ -680,7 +684,9 @@ impl BeaconState { /// /// Spec v0.12.1 pub fn get_oldest_block_root(&self) -> Result<&Hash256, Error> { - let i = self.get_latest_block_roots_index(self.slot - self.block_roots.len() as u64)?; + let i = self.get_latest_block_roots_index( + self.slot.saturating_sub(self.block_roots.len() as u64), + )?; Ok(&self.block_roots[i]) } @@ -712,8 +718,8 @@ impl BeaconState { // We allow the slashings vector to be accessed at any cached epoch at or before // the current epoch, or the next epoch if `AllowNextEpoch::True` is passed. let current_epoch = self.current_epoch(); - if current_epoch < epoch + T::EpochsPerSlashingsVector::to_u64() - && epoch <= allow_next_epoch.upper_bound_of(current_epoch) + if current_epoch < epoch.safe_add(T::EpochsPerSlashingsVector::to_u64())? + && epoch <= allow_next_epoch.upper_bound_of(current_epoch)? { Ok(epoch .as_usize() @@ -775,7 +781,10 @@ impl BeaconState { // Bypass the safe getter for RANDAO so we can gracefully handle the scenario where `epoch // == 0`. let mix = { - let i = epoch + T::EpochsPerHistoricalVector::to_u64() - spec.min_seed_lookahead - 1; + let i = epoch + .safe_add(T::EpochsPerHistoricalVector::to_u64())? + .safe_sub(spec.min_seed_lookahead)? + .safe_sub(1)?; self.randao_mixes[i.as_usize().safe_rem(self.randao_mixes.len())?] }; let domain_bytes = int_to_bytes4(spec.get_domain_constant(domain_type)); @@ -811,8 +820,12 @@ impl BeaconState { /// Return the epoch at which an activation or exit triggered in ``epoch`` takes effect. /// /// Spec v0.12.1 - pub fn compute_activation_exit_epoch(&self, epoch: Epoch, spec: &ChainSpec) -> Epoch { - epoch + 1 + spec.max_seed_lookahead + pub fn compute_activation_exit_epoch( + &self, + epoch: Epoch, + spec: &ChainSpec, + ) -> Result { + Ok(epoch.safe_add(1)?.safe_add(spec.max_seed_lookahead)?) } /// Return the churn limit for the current epoch (number of validators who can leave per epoch). diff --git a/consensus/types/src/beacon_state/committee_cache.rs b/consensus/types/src/beacon_state/committee_cache.rs index f71ad2e89..6ee24cd2b 100644 --- a/consensus/types/src/beacon_state/committee_cache.rs +++ b/consensus/types/src/beacon_state/committee_cache.rs @@ -3,6 +3,7 @@ use super::BeaconState; use crate::*; use core::num::NonZeroUsize; +use safe_arith::SafeArith; use serde_derive::{Deserialize, Serialize}; use ssz_derive::{Decode, Encode}; use std::ops::Range; @@ -197,7 +198,7 @@ impl CommitteeCache { let epoch_start_slot = self.initialized_epoch?.start_slot(self.slots_per_epoch); let slot_offset = global_committee_index / self.committees_per_slot; let index = global_committee_index % self.committees_per_slot; - Some((epoch_start_slot + slot_offset, index)) + Some((epoch_start_slot.safe_add(slot_offset).ok()?, index)) } /// Returns the number of active validators in the initialized epoch. diff --git a/consensus/types/src/beacon_state/committee_cache/tests.rs b/consensus/types/src/beacon_state/committee_cache/tests.rs index ee2ca8eed..e1256cb48 100644 --- a/consensus/types/src/beacon_state/committee_cache/tests.rs +++ b/consensus/types/src/beacon_state/committee_cache/tests.rs @@ -53,8 +53,8 @@ fn initializes_with_the_right_epoch() { let cache = CommitteeCache::initialized(&state, state.previous_epoch(), &spec).unwrap(); assert_eq!(cache.initialized_epoch, Some(state.previous_epoch())); - let cache = CommitteeCache::initialized(&state, state.next_epoch(), &spec).unwrap(); - assert_eq!(cache.initialized_epoch, Some(state.next_epoch())); + let cache = CommitteeCache::initialized(&state, state.next_epoch().unwrap(), &spec).unwrap(); + assert_eq!(cache.initialized_epoch, Some(state.next_epoch().unwrap())); } #[test] @@ -81,7 +81,7 @@ fn shuffles_for_the_right_epoch() { .get_seed(state.current_epoch(), Domain::BeaconAttester, spec) .unwrap(); let next_seed = state - .get_seed(state.next_epoch(), Domain::BeaconAttester, spec) + .get_seed(state.next_epoch().unwrap(), Domain::BeaconAttester, spec) .unwrap(); assert!((previous_seed != current_seed) && (current_seed != next_seed)); @@ -114,7 +114,7 @@ fn shuffles_for_the_right_epoch() { assert_eq!(cache.shuffling, shuffling_with_seed(previous_seed)); assert_shuffling_positions_accurate(&cache); - let cache = CommitteeCache::initialized(&state, state.next_epoch(), spec).unwrap(); + let cache = CommitteeCache::initialized(&state, state.next_epoch().unwrap(), spec).unwrap(); assert_eq!(cache.shuffling, shuffling_with_seed(next_seed)); assert_shuffling_positions_accurate(&cache); } diff --git a/consensus/types/src/beacon_state/exit_cache.rs b/consensus/types/src/beacon_state/exit_cache.rs index 0f75e0f28..364c1daf0 100644 --- a/consensus/types/src/beacon_state/exit_cache.rs +++ b/consensus/types/src/beacon_state/exit_cache.rs @@ -44,7 +44,7 @@ impl ExitCache { self.exit_epoch_counts .entry(exit_epoch) .or_insert(0) - .increment()?; + .safe_add_assign(1)?; Ok(()) } diff --git a/consensus/types/src/relative_epoch.rs b/consensus/types/src/relative_epoch.rs index 381f17308..e681ce15c 100644 --- a/consensus/types/src/relative_epoch.rs +++ b/consensus/types/src/relative_epoch.rs @@ -1,9 +1,17 @@ use crate::*; +use safe_arith::{ArithError, SafeArith}; #[derive(Debug, PartialEq, Clone, Copy)] pub enum Error { EpochTooLow { base: Epoch, other: Epoch }, EpochTooHigh { base: Epoch, other: Epoch }, + ArithError(ArithError), +} + +impl From for Error { + fn from(e: ArithError) -> Self { + Self::ArithError(e) + } } #[cfg(feature = "arbitrary-fuzz")] @@ -32,8 +40,8 @@ impl RelativeEpoch { match self { // Due to saturating nature of epoch, check for current first. RelativeEpoch::Current => base, - RelativeEpoch::Previous => base - 1, - RelativeEpoch::Next => base + 1, + RelativeEpoch::Previous => base.saturating_sub(1u64), + RelativeEpoch::Next => base.saturating_add(1u64), } } @@ -46,12 +54,11 @@ impl RelativeEpoch { /// /// Spec v0.12.1 pub fn from_epoch(base: Epoch, other: Epoch) -> Result { - // Due to saturating nature of epoch, check for current first. if other == base { Ok(RelativeEpoch::Current) - } else if other == base - 1 { + } else if other.safe_add(1)? == base { Ok(RelativeEpoch::Previous) - } else if other == base + 1 { + } else if other == base.safe_add(1)? { Ok(RelativeEpoch::Next) } else if other < base { Err(Error::EpochTooLow { base, other }) diff --git a/consensus/types/src/slot_epoch.rs b/consensus/types/src/slot_epoch.rs index ea4e23375..ffb88c7ee 100644 --- a/consensus/types/src/slot_epoch.rs +++ b/consensus/types/src/slot_epoch.rs @@ -14,21 +14,23 @@ use crate::test_utils::TestRandom; use crate::SignedRoot; use rand::RngCore; +use safe_arith::SafeArith; use serde_derive::{Deserialize, Serialize}; use ssz::{ssz_encode, Decode, DecodeError, Encode}; -use std::cmp::{Ord, Ordering}; use std::fmt; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; use std::iter::Iterator; + +#[cfg(feature = "legacy-arith")] use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, Sub, SubAssign}; #[cfg_attr(feature = "arbitrary-fuzz", derive(arbitrary::Arbitrary))] -#[derive(Eq, Clone, Copy, Default, Serialize, Deserialize)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct Slot(u64); #[cfg_attr(feature = "arbitrary-fuzz", derive(arbitrary::Arbitrary))] -#[derive(Eq, Clone, Copy, Default, Serialize, Deserialize)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct Epoch(u64); @@ -41,7 +43,9 @@ impl Slot { } pub fn epoch(self, slots_per_epoch: u64) -> Epoch { - Epoch::from(self.0) / Epoch::from(slots_per_epoch) + Epoch::new(self.0) + .safe_div(slots_per_epoch) + .expect("slots_per_epoch is not 0") } pub fn max_value() -> Slot { @@ -96,9 +100,6 @@ impl Epoch { } } -impl SignedRoot for Epoch {} -impl SignedRoot for Slot {} - pub struct SlotIter<'a> { current_iteration: u64, epoch: &'a Epoch, @@ -115,7 +116,7 @@ impl<'a> Iterator for SlotIter<'a> { let start_slot = self.epoch.start_slot(self.slots_per_epoch); let previous = self.current_iteration; self.current_iteration = self.current_iteration.checked_add(1)?; - Some(start_slot + previous) + start_slot.safe_add(previous).ok() } } } diff --git a/consensus/types/src/slot_epoch_macros.rs b/consensus/types/src/slot_epoch_macros.rs index 15263f654..26b80692c 100644 --- a/consensus/types/src/slot_epoch_macros.rs +++ b/consensus/types/src/slot_epoch_macros.rs @@ -42,22 +42,84 @@ macro_rules! impl_from_into_usize { }; } +macro_rules! impl_u64_eq_ord { + ($type: ident) => { + impl PartialEq for $type { + fn eq(&self, other: &u64) -> bool { + self.as_u64() == *other + } + } + + impl PartialOrd for $type { + fn partial_cmp(&self, other: &u64) -> Option { + self.as_u64().partial_cmp(other) + } + } + }; +} + +macro_rules! impl_safe_arith { + ($type: ident, $rhs_ty: ident) => { + impl safe_arith::SafeArith<$rhs_ty> for $type { + const ZERO: Self = $type::new(0); + const ONE: Self = $type::new(1); + + fn safe_add(&self, other: $rhs_ty) -> safe_arith::Result { + self.0 + .checked_add(other.into()) + .map(Self::new) + .ok_or(safe_arith::ArithError::Overflow) + } + + fn safe_sub(&self, other: $rhs_ty) -> safe_arith::Result { + self.0 + .checked_sub(other.into()) + .map(Self::new) + .ok_or(safe_arith::ArithError::Overflow) + } + + fn safe_mul(&self, other: $rhs_ty) -> safe_arith::Result { + self.0 + .checked_mul(other.into()) + .map(Self::new) + .ok_or(safe_arith::ArithError::Overflow) + } + + fn safe_div(&self, other: $rhs_ty) -> safe_arith::Result { + self.0 + .checked_div(other.into()) + .map(Self::new) + .ok_or(safe_arith::ArithError::DivisionByZero) + } + + fn safe_rem(&self, other: $rhs_ty) -> safe_arith::Result { + self.0 + .checked_rem(other.into()) + .map(Self::new) + .ok_or(safe_arith::ArithError::DivisionByZero) + } + + fn safe_shl(&self, other: u32) -> safe_arith::Result { + self.0 + .checked_shl(other) + .map(Self::new) + .ok_or(safe_arith::ArithError::Overflow) + } + + fn safe_shr(&self, other: u32) -> safe_arith::Result { + self.0 + .checked_shr(other) + .map(Self::new) + .ok_or(safe_arith::ArithError::Overflow) + } + } + }; +} + +// Deprecated: prefer `SafeArith` methods for new code. +#[cfg(feature = "legacy-arith")] macro_rules! impl_math_between { ($main: ident, $other: ident) => { - impl PartialOrd<$other> for $main { - /// Utilizes `partial_cmp` on the underlying `u64`. - fn partial_cmp(&self, other: &$other) -> Option { - Some(self.0.cmp(&(*other).into())) - } - } - - impl PartialEq<$other> for $main { - fn eq(&self, other: &$other) -> bool { - let other: u64 = (*other).into(); - self.0 == other - } - } - impl Add<$other> for $main { type Output = $main; @@ -144,33 +206,17 @@ macro_rules! impl_math { ($type: ident) => { impl $type { pub fn saturating_sub>(&self, other: T) -> $type { - *self - other.into() + $type::new(self.as_u64().saturating_sub(other.into().as_u64())) } pub fn saturating_add>(&self, other: T) -> $type { - *self + other.into() - } - - pub fn checked_div>(&self, rhs: T) -> Option<$type> { - let rhs: $type = rhs.into(); - if rhs == 0 { - None - } else { - Some(*self / rhs) - } + $type::new(self.as_u64().saturating_add(other.into().as_u64())) } pub fn is_power_of_two(&self) -> bool { self.0.is_power_of_two() } } - - impl Ord for $type { - fn cmp(&self, other: &$type) -> Ordering { - let other: u64 = (*other).into(); - self.0.cmp(&other) - } - } }; } @@ -257,6 +303,8 @@ macro_rules! impl_ssz { } } + impl SignedRoot for $type {} + impl TestRandom for $type { fn random_for_test(rng: &mut impl RngCore) -> Self { $type::from(u64::random_for_test(rng)) @@ -265,29 +313,21 @@ macro_rules! impl_ssz { }; } -macro_rules! impl_hash { - ($type: ident) => { - // Implemented to stop clippy lint: - // https://rust-lang.github.io/rust-clippy/master/index.html#derive_hash_xor_eq - impl Hash for $type { - fn hash(&self, state: &mut H) { - ssz_encode(self).hash(state) - } - } - }; -} - macro_rules! impl_common { ($type: ident) => { impl_from_into_u64!($type); impl_from_into_usize!($type); + impl_u64_eq_ord!($type); + impl_safe_arith!($type, $type); + impl_safe_arith!($type, u64); + #[cfg(feature = "legacy-arith")] impl_math_between!($type, $type); + #[cfg(feature = "legacy-arith")] impl_math_between!($type, u64); impl_math!($type); impl_display!($type); impl_debug!($type); impl_ssz!($type); - impl_hash!($type); }; } @@ -335,6 +375,7 @@ macro_rules! math_between_tests { ($type: ident, $other: ident) => { #[test] fn partial_ord() { + use std::cmp::Ordering; let assert_partial_ord = |a: u64, partial_ord: Ordering, b: u64| { let other: $other = $type(b).into(); assert_eq!($type(a).partial_cmp(&other), Some(partial_ord)); @@ -518,7 +559,7 @@ macro_rules! math_tests { #[test] fn checked_div() { let assert_checked_div = |a: u64, b: u64, result: Option| { - let division_result_as_u64 = match $type(a).checked_div($type(b)) { + let division_result_as_u64 = match $type(a).safe_div($type(b)).ok() { None => None, Some(val) => Some(val.as_u64()), }; @@ -560,6 +601,7 @@ macro_rules! math_tests { #[test] fn ord() { + use std::cmp::Ordering; let assert_ord = |a: u64, ord: Ordering, b: u64| { assert_eq!($type(a).cmp(&$type(b)), ord); }; diff --git a/consensus/types/src/test_utils/builders/testing_attestation_data_builder.rs b/consensus/types/src/test_utils/builders/testing_attestation_data_builder.rs index 9ecef2815..56b3e3bbe 100644 --- a/consensus/types/src/test_utils/builders/testing_attestation_data_builder.rs +++ b/consensus/types/src/test_utils/builders/testing_attestation_data_builder.rs @@ -1,5 +1,6 @@ use crate::test_utils::AttestationTestTask; use crate::*; +use safe_arith::SafeArith; /// Builds an `AttestationData` to be used for testing purposes. /// @@ -49,12 +50,19 @@ impl TestingAttestationDataBuilder { match test_task { AttestationTestTask::IncludedTooEarly => { - slot = state.slot - spec.min_attestation_inclusion_delay + 1 + slot = state + .slot + .safe_sub(spec.min_attestation_inclusion_delay) + .unwrap() + .safe_add(1u64) + .unwrap(); } - AttestationTestTask::IncludedTooLate => slot -= T::SlotsPerEpoch::to_u64(), + AttestationTestTask::IncludedTooLate => slot + .safe_sub_assign(Slot::new(T::SlotsPerEpoch::to_u64())) + .unwrap(), AttestationTestTask::TargetEpochSlotMismatch => { target = Checkpoint { - epoch: current_epoch + 1, + epoch: current_epoch.safe_add(1u64).unwrap(), root: Hash256::zero(), }; assert_ne!(target.epoch, slot.epoch(T::slots_per_epoch())); diff --git a/consensus/types/src/test_utils/builders/testing_beacon_block_builder.rs b/consensus/types/src/test_utils/builders/testing_beacon_block_builder.rs index c396d8c96..97fe62780 100644 --- a/consensus/types/src/test_utils/builders/testing_beacon_block_builder.rs +++ b/consensus/types/src/test_utils/builders/testing_beacon_block_builder.rs @@ -9,6 +9,7 @@ use crate::{ use int_to_bytes::int_to_bytes32; use merkle_proof::MerkleTree; use rayon::prelude::*; +use safe_arith::SafeArith; use tree_hash::TreeHash; /// Builds a beacon block to be used for testing purposes. @@ -172,7 +173,10 @@ impl TestingBeaconBlockBuilder { num_attestations: usize, spec: &ChainSpec, ) -> Result<(), BeaconStateError> { - let mut slot = self.block.slot - spec.min_attestation_inclusion_delay; + let mut slot = self + .block + .slot + .safe_sub(spec.min_attestation_inclusion_delay)?; let mut attestations_added = 0; // Stores the following (in order): @@ -192,7 +196,7 @@ impl TestingBeaconBlockBuilder { // - The slot is too old to be included in a block at this slot. // - The `MAX_ATTESTATIONS`. loop { - if state.slot >= slot + T::slots_per_epoch() { + if state.slot >= slot.safe_add(T::slots_per_epoch())? { break; } @@ -211,7 +215,7 @@ impl TestingBeaconBlockBuilder { attestations_added += 1; } - slot -= 1; + slot.safe_sub_assign(1u64)?; } // Loop through all the committees, splitting each one in half until we have diff --git a/consensus/types/src/test_utils/builders/testing_beacon_state_builder.rs b/consensus/types/src/test_utils/builders/testing_beacon_state_builder.rs index 1dda9de98..67a3dae26 100644 --- a/consensus/types/src/test_utils/builders/testing_beacon_state_builder.rs +++ b/consensus/types/src/test_utils/builders/testing_beacon_state_builder.rs @@ -143,11 +143,11 @@ impl TestingBeaconStateBuilder { state.slot = slot; - state.previous_justified_checkpoint.epoch = epoch - 3; - state.current_justified_checkpoint.epoch = epoch - 2; + state.previous_justified_checkpoint.epoch = epoch.saturating_sub(3u64); + state.current_justified_checkpoint.epoch = epoch.saturating_sub(2u64); state.justification_bits = BitVector::from_bytes(vec![0b0000_1111]).unwrap(); - state.finalized_checkpoint.epoch = epoch - 3; + state.finalized_checkpoint.epoch = state.previous_justified_checkpoint.epoch; } /// Creates a full set of attestations for the `BeaconState`. Each attestation has full