diff --git a/tests/ef_tests/Cargo.toml b/tests/ef_tests/Cargo.toml index fdba1bcb3..b0a5e27c5 100644 --- a/tests/ef_tests/Cargo.toml +++ b/tests/ef_tests/Cargo.toml @@ -20,6 +20,7 @@ ssz = { path = "../../eth2/utils/ssz" } tree_hash = { path = "../../eth2/utils/tree_hash" } cached_tree_hash = { path = "../../eth2/utils/cached_tree_hash" } state_processing = { path = "../../eth2/state_processing" } +swap_or_not_shuffle = { path = "../../eth2/utils/swap_or_not_shuffle" } types = { path = "../../eth2/types" } walkdir = "2" yaml-rust = { git = "https://github.com/sigp/yaml-rust", branch = "escape_all_str"} diff --git a/tests/ef_tests/src/cases.rs b/tests/ef_tests/src/cases.rs index 511759875..44be18756 100644 --- a/tests/ef_tests/src/cases.rs +++ b/tests/ef_tests/src/cases.rs @@ -10,6 +10,7 @@ mod bls_sign_msg; mod operations_deposit; mod operations_exit; mod operations_transfer; +mod shuffling; mod ssz_generic; mod ssz_static; @@ -22,6 +23,7 @@ pub use bls_sign_msg::*; pub use operations_deposit::*; pub use operations_exit::*; pub use operations_transfer::*; +pub use shuffling::*; pub use ssz_generic::*; pub use ssz_static::*; diff --git a/tests/ef_tests/src/cases/shuffling.rs b/tests/ef_tests/src/cases/shuffling.rs new file mode 100644 index 000000000..ef8a1b934 --- /dev/null +++ b/tests/ef_tests/src/cases/shuffling.rs @@ -0,0 +1,48 @@ +use super::*; +use crate::case_result::compare_result; +use serde_derive::Deserialize; +use std::marker::PhantomData; +use swap_or_not_shuffle::{get_permutated_index, shuffle_list}; + +#[derive(Debug, Clone, Deserialize)] +pub struct Shuffling { + pub seed: String, + pub count: usize, + pub shuffled: Vec, + #[serde(skip)] + _phantom: PhantomData, +} + +impl YamlDecode for Shuffling { + fn yaml_decode(yaml: &String) -> Result { + Ok(serde_yaml::from_str(&yaml.as_str()).unwrap()) + } +} + +impl Case for Shuffling { + fn result(&self, _case_index: usize) -> Result<(), Error> { + if self.count == 0 { + compare_result::<_, Error>(&Ok(vec![]), &Some(self.shuffled.clone()))?; + } else { + let spec = T::spec(); + let seed = hex::decode(&self.seed[2..]) + .map_err(|e| Error::FailedToParseTest(format!("{:?}", e)))?; + + // Test get_permuted_index + let shuffling = (0..self.count) + .into_iter() + .map(|i| { + get_permutated_index(i, self.count, &seed, spec.shuffle_round_count).unwrap() + }) + .collect(); + compare_result::<_, Error>(&Ok(shuffling), &Some(self.shuffled.clone()))?; + + // Test "shuffle_list" + let input: Vec = (0..self.count).collect(); + let shuffling = shuffle_list(input, spec.shuffle_round_count, &seed, false).unwrap(); + compare_result::<_, Error>(&Ok(shuffling), &Some(self.shuffled.clone()))?; + } + + Ok(()) + } +} diff --git a/tests/ef_tests/src/doc.rs b/tests/ef_tests/src/doc.rs index 686173df3..b0854451a 100644 --- a/tests/ef_tests/src/doc.rs +++ b/tests/ef_tests/src/doc.rs @@ -41,6 +41,8 @@ impl Doc { ("ssz", "uint", _) => run_test::(self), ("ssz", "static", "minimal") => run_test::>(self), ("ssz", "static", "mainnet") => run_test::>(self), + ("shuffling", "core", "minimal") => run_test::>(self), + ("shuffling", "core", "mainnet") => run_test::>(self), ("bls", "aggregate_pubkeys", "mainnet") => run_test::(self), ("bls", "aggregate_sigs", "mainnet") => run_test::(self), ("bls", "msg_hash_compressed", "mainnet") => run_test::(self), diff --git a/tests/ef_tests/src/eth_specs.rs b/tests/ef_tests/src/eth_specs.rs index b2d46d8bc..cdf8b94e8 100644 --- a/tests/ef_tests/src/eth_specs.rs +++ b/tests/ef_tests/src/eth_specs.rs @@ -21,7 +21,9 @@ impl EthSpec for MinimalEthSpec { fn spec() -> ChainSpec { // TODO: this spec is likely incorrect! - FewValidatorsEthSpec::spec() + let mut spec = FewValidatorsEthSpec::spec(); + spec.shuffle_round_count = 10; + spec } } diff --git a/tests/ef_tests/tests/tests.rs b/tests/ef_tests/tests/tests.rs index ecfbca14a..86f188671 100644 --- a/tests/ef_tests/tests/tests.rs +++ b/tests/ef_tests/tests/tests.rs @@ -60,6 +60,15 @@ fn ssz_static() { }); } +#[test] +fn shuffling() { + yaml_files_in_test_dir(&Path::new("shuffling").join("core")) + .into_par_iter() + .for_each(|file| { + Doc::assert_tests_pass(file); + }); +} + #[test] #[cfg(not(feature = "fake_crypto"))] fn operations_deposit() {