diff --git a/beacon_chain/validator_shuffling/src/lib.rs b/beacon_chain/validator_shuffling/src/lib.rs index a11483e56..ee6f5e8fb 100644 --- a/beacon_chain/validator_shuffling/src/lib.rs +++ b/beacon_chain/validator_shuffling/src/lib.rs @@ -5,4 +5,7 @@ extern crate types; mod active_validator_indices; mod shuffle; -pub use shuffle::shard_and_committees_for_cycle; +pub use shuffle::{ + shard_and_committees_for_cycle, + ValidatorAssignmentError, +}; diff --git a/beacon_chain/validator_shuffling/src/shuffle.rs b/beacon_chain/validator_shuffling/src/shuffle.rs index 1ce778a6b..4768c1cb1 100644 --- a/beacon_chain/validator_shuffling/src/shuffle.rs +++ b/beacon_chain/validator_shuffling/src/shuffle.rs @@ -1,7 +1,10 @@ use std::cmp::min; use honey_badger_split::SplitExt; -use vec_shuffle::shuffle; +use vec_shuffle::{ + shuffle, + ShuffleErr, +}; use types::{ ShardAndCommittee, ValidatorRecord, @@ -12,12 +15,12 @@ use super::active_validator_indices::active_validator_indices; type DelegatedCycle = Vec>; -#[derive(Debug)] -pub enum TransitionError { - InvalidInput(String), +#[derive(Debug, PartialEq)] +pub enum ValidatorAssignmentError { + TooManyValidators, + TooFewShards, } - /// Delegates active validators into slots for a given cycle, given a random seed. /// Returns a vector or ShardAndComitte vectors representing the shards and committiees for /// each slot. @@ -27,15 +30,11 @@ pub fn shard_and_committees_for_cycle( validators: &[ValidatorRecord], crosslinking_shard_start: u16, config: &ChainConfig) - -> Result + -> Result { let shuffled_validator_indices = { let mut validator_indices = active_validator_indices(validators); - match shuffle(seed, validator_indices) { - Ok(shuffled) => shuffled, - _ => return Err(TransitionError::InvalidInput( - String::from("Shuffle list length exceed."))) - } + shuffle(seed, validator_indices)? }; let shard_indices: Vec = (0_usize..config.shard_count as usize).into_iter().collect(); let crosslinking_shard_start = crosslinking_shard_start as usize; @@ -56,17 +55,14 @@ fn generate_cycle( crosslinking_shard_start: usize, cycle_length: usize, min_committee_size: usize) - -> Result + -> Result { let validator_count = validator_indices.len(); let shard_count = shard_indices.len(); if shard_count / cycle_length == 0 { - return Err(TransitionError::InvalidInput(String::from("Number of - shards needs to be greater than - cycle length"))); - + return Err(ValidatorAssignmentError::TooFewShards) } let (committees_per_slot, slots_per_committee) = { @@ -105,6 +101,14 @@ fn generate_cycle( Ok(cycle) } +impl From for ValidatorAssignmentError { + fn from(e: ShuffleErr) -> ValidatorAssignmentError { + match e { + ShuffleErr::ExceedsListLength => ValidatorAssignmentError::TooManyValidators, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -115,7 +119,7 @@ mod tests { crosslinking_shard_start: usize, cycle_length: usize, min_committee_size: usize) - -> (Vec, Vec, Result) + -> (Vec, Vec, Result) { let validator_indices: Vec = (0_usize..*validator_count).into_iter().collect(); let shard_indices: Vec = (0_usize..*shard_count).into_iter().collect();