Add safe_sum
and use it in state_processing (#1620)
## Issue Addressed Closes #1098 ## Proposed Changes Add a `SafeArithIter` trait with a `safe_sum` method, and use it in `state_processing`. This seems to be the only place in `consensus` where it is relevant -- i.e. where we were using `sum` and the integer_arith lint is enabled. ## Additional Info This PR doesn't include any Clippy linting to prevent `sum` from being called. It seems there is no existing Clippy lint that suits our purpose, but I'm going to look into that and maybe schedule writing one as a lower-priority task. This theoretically _is_ a consensus breaking change, but it shouldn't impact Medalla (or any other testnet) because `slashings` shouldn't overflow!
This commit is contained in:
parent
4fca306397
commit
7aceff4d13
70
consensus/safe_arith/src/iter.rs
Normal file
70
consensus/safe_arith/src/iter.rs
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
use crate::{Result, SafeArith};
|
||||||
|
|
||||||
|
/// Extension trait for iterators, providing a safe replacement for `sum`.
|
||||||
|
pub trait SafeArithIter<T> {
|
||||||
|
fn safe_sum(self) -> Result<T>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I, T> SafeArithIter<T> for I
|
||||||
|
where
|
||||||
|
I: Iterator<Item = T> + Sized,
|
||||||
|
T: SafeArith,
|
||||||
|
{
|
||||||
|
fn safe_sum(mut self) -> Result<T> {
|
||||||
|
self.try_fold(T::ZERO, |acc, x| acc.safe_add(x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::*;
|
||||||
|
use crate::ArithError;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_sum() {
|
||||||
|
let v: Vec<u64> = vec![];
|
||||||
|
assert_eq!(v.into_iter().safe_sum(), Ok(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unsigned_sum_small() {
|
||||||
|
let v = vec![400u64, 401, 402, 403, 404, 405, 406];
|
||||||
|
assert_eq!(
|
||||||
|
v.iter().copied().safe_sum().unwrap(),
|
||||||
|
v.iter().copied().sum()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unsigned_sum_overflow() {
|
||||||
|
let v = vec![u64::MAX, 1];
|
||||||
|
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn signed_sum_small() {
|
||||||
|
let v = vec![-1i64, -2i64, -3i64, 3, 2, 1];
|
||||||
|
assert_eq!(v.into_iter().safe_sum(), Ok(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn signed_sum_overflow_above() {
|
||||||
|
let v = vec![1, 2, 3, 4, i16::MAX, 0, 1, 2, 3];
|
||||||
|
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn signed_sum_overflow_below() {
|
||||||
|
let v = vec![i16::MIN, -1];
|
||||||
|
assert_eq!(v.into_iter().safe_sum(), Err(ArithError::Overflow));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn signed_sum_almost_overflow() {
|
||||||
|
let v = vec![i64::MIN, 1, -1i64, i64::MAX, i64::MAX, 1];
|
||||||
|
assert_eq!(
|
||||||
|
v.iter().copied().safe_sum().unwrap(),
|
||||||
|
v.iter().copied().sum()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,7 @@
|
|||||||
//! Library for safe arithmetic on integers, avoiding overflow and division by zero.
|
//! Library for safe arithmetic on integers, avoiding overflow and division by zero.
|
||||||
|
mod iter;
|
||||||
|
|
||||||
|
pub use iter::SafeArithIter;
|
||||||
|
|
||||||
/// Error representing the failure of an arithmetic operation.
|
/// Error representing the failure of an arithmetic operation.
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
@ -7,7 +10,7 @@ pub enum ArithError {
|
|||||||
DivisionByZero,
|
DivisionByZero,
|
||||||
}
|
}
|
||||||
|
|
||||||
type Result<T> = std::result::Result<T, ArithError>;
|
pub type Result<T> = std::result::Result<T, ArithError>;
|
||||||
|
|
||||||
macro_rules! assign_method {
|
macro_rules! assign_method {
|
||||||
($name:ident, $op:ident, $doc_op:expr) => {
|
($name:ident, $op:ident, $doc_op:expr) => {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use safe_arith::SafeArith;
|
use safe_arith::{SafeArith, SafeArithIter};
|
||||||
use types::{BeaconStateError as Error, *};
|
use types::{BeaconStateError as Error, *};
|
||||||
|
|
||||||
/// Process slashings.
|
/// Process slashings.
|
||||||
@ -10,7 +10,7 @@ pub fn process_slashings<T: EthSpec>(
|
|||||||
spec: &ChainSpec,
|
spec: &ChainSpec,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let epoch = state.current_epoch();
|
let epoch = state.current_epoch();
|
||||||
let sum_slashings = state.get_all_slashings().iter().sum::<u64>();
|
let sum_slashings = state.get_all_slashings().iter().copied().safe_sum()?;
|
||||||
|
|
||||||
for (index, validator) in state.validators.iter().enumerate() {
|
for (index, validator) in state.validators.iter().enumerate() {
|
||||||
if validator.slashed
|
if validator.slashed
|
||||||
|
Loading…
Reference in New Issue
Block a user