Add fast full-list shuffle for swap-or-not

- Passes test vectors
- Implemented in beacon state
- Added more docs
This commit is contained in:
Paul Hauner 2019-03-01 12:19:05 +11:00
parent 8aa7f25bbc
commit c3b2f802a7
No known key found for this signature in database
GPG Key ID: D362883A9218FCC6
5 changed files with 402 additions and 217 deletions

View File

@ -12,7 +12,7 @@ use rand::RngCore;
use serde_derive::Serialize; use serde_derive::Serialize;
use ssz::{hash, Decodable, DecodeError, Encodable, SszStream, TreeHash}; use ssz::{hash, Decodable, DecodeError, Encodable, SszStream, TreeHash};
use std::collections::HashMap; use std::collections::HashMap;
use swap_or_not_shuffle::get_permutated_list; use swap_or_not_shuffle::shuffle_list;
pub use builder::BeaconStateBuilder; pub use builder::BeaconStateBuilder;
@ -423,10 +423,11 @@ impl BeaconState {
let active_validator_indices: Vec<usize> = let active_validator_indices: Vec<usize> =
active_validator_indices.iter().cloned().collect(); active_validator_indices.iter().cloned().collect();
let shuffled_active_validator_indices = get_permutated_list( let shuffled_active_validator_indices = shuffle_list(
&active_validator_indices, active_validator_indices,
&seed[..],
spec.shuffle_round_count, spec.shuffle_round_count,
&seed[..],
true,
) )
.ok_or_else(|| Error::UnableToShuffle)?; .ok_or_else(|| Error::UnableToShuffle)?;

View File

@ -1,6 +1,6 @@
use criterion::Criterion; use criterion::Criterion;
use criterion::{black_box, criterion_group, criterion_main, Benchmark}; use criterion::{black_box, criterion_group, criterion_main, Benchmark};
use swap_or_not_shuffle::{get_permutated_index, get_permutated_list}; use swap_or_not_shuffle::{get_permutated_index, shuffle_list as fast_shuffle};
const SHUFFLE_ROUND_COUNT: u8 = 90; const SHUFFLE_ROUND_COUNT: u8 = 90;
@ -53,7 +53,7 @@ fn shuffles(c: &mut Criterion) {
Benchmark::new("512 elements", move |b| { Benchmark::new("512 elements", move |b| {
let seed = vec![42; 32]; let seed = vec![42; 32];
let list: Vec<usize> = (0..512).collect(); let list: Vec<usize> = (0..512).collect();
b.iter(|| black_box(get_permutated_list(&list, &seed, SHUFFLE_ROUND_COUNT))) b.iter(|| black_box(fast_shuffle(list.clone(), SHUFFLE_ROUND_COUNT, &seed, true)))
}) })
.sample_size(10), .sample_size(10),
); );
@ -72,7 +72,17 @@ fn shuffles(c: &mut Criterion) {
Benchmark::new("16384 elements", move |b| { Benchmark::new("16384 elements", move |b| {
let seed = vec![42; 32]; let seed = vec![42; 32];
let list: Vec<usize> = (0..16384).collect(); let list: Vec<usize> = (0..16384).collect();
b.iter(|| black_box(get_permutated_list(&list, &seed, SHUFFLE_ROUND_COUNT))) b.iter(|| black_box(fast_shuffle(list.clone(), SHUFFLE_ROUND_COUNT, &seed, true)))
})
.sample_size(10),
);
c.bench(
"_fast_ whole list shuffle",
Benchmark::new("4m elements", move |b| {
let seed = vec![42; 32];
let list: Vec<usize> = (0..4_000_000).collect();
b.iter(|| black_box(fast_shuffle(list.clone(), SHUFFLE_ROUND_COUNT, &seed, true)))
}) })
.sample_size(10), .sample_size(10),
); );

View File

@ -0,0 +1,187 @@
use bytes::Buf;
use hashing::hash;
use int_to_bytes::{int_to_bytes1, int_to_bytes4};
use std::cmp::max;
use std::io::Cursor;
/// Return `p(index)` in a pseudorandom permutation `p` of `0...list_size-1` with ``seed`` as entropy.
///
/// Utilizes 'swap or not' shuffling found in
/// https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf
/// See the 'generalized domain' algorithm on page 3.
///
/// Note: this function is significantly slower than the `shuffle_list` function in this crate.
/// Using `get_permutated_list` to shuffle an entire list, index by index, has been observed to be
/// 250x slower than `shuffle_list`. Therefore, this function is only useful when shuffling a small
/// portion of a much larger list.
///
/// Returns `None` under any of the following conditions:
/// - `list_size == 0`
/// - `index >= list_size`
/// - `list_size > 2**24`
/// - `list_size > usize::max_value() / 2`
pub fn get_permutated_index(
index: usize,
list_size: usize,
seed: &[u8],
shuffle_round_count: u8,
) -> Option<usize> {
if list_size == 0
|| index >= list_size
|| list_size > usize::max_value() / 2
|| list_size > 2_usize.pow(24)
{
return None;
}
let mut index = index;
for round in 0..shuffle_round_count {
let pivot = bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size;
index = do_round(seed, index, pivot, round, list_size)?;
}
Some(index)
}
fn do_round(seed: &[u8], index: usize, pivot: usize, round: u8, list_size: usize) -> Option<usize> {
let flip = (pivot + list_size - index) % list_size;
let position = max(index, flip);
let source = hash_with_round_and_position(seed, round, position)?;
let byte = source[(position % 256) / 8];
let bit = (byte >> (position % 8)) % 2;
Some(if bit == 1 { flip } else { index })
}
fn hash_with_round_and_position(seed: &[u8], round: u8, position: usize) -> Option<Vec<u8>> {
let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round));
/*
* Note: the specification has an implicit assertion in `int_to_bytes4` that `position / 256 <
* 2**24`. For efficiency, we do not check for that here as it is checked in `get_permutated_index`.
*/
seed.append(&mut int_to_bytes4((position / 256) as u32));
Some(hash(&seed[..]))
}
fn hash_with_round(seed: &[u8], round: u8) -> Vec<u8> {
let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round));
hash(&seed[..])
}
fn bytes_to_int64(bytes: &[u8]) -> u64 {
let mut cursor = Cursor::new(bytes);
cursor.get_u64_le()
}
#[cfg(test)]
mod tests {
use super::*;
use ethereum_types::H256 as Hash256;
use hex;
use std::{fs::File, io::prelude::*, path::PathBuf};
use yaml_rust::yaml;
#[test]
#[ignore]
fn fuzz_test() {
let max_list_size = 2_usize.pow(24);
let test_runs = 1000;
// Test at max list_size with the end index.
for _ in 0..test_runs {
let index = max_list_size - 1;
let list_size = max_list_size;
let seed = Hash256::random();
let shuffle_rounds = 90;
assert!(get_permutated_index(index, list_size, &seed[..], shuffle_rounds).is_some());
}
// Test at max list_size low indices.
for i in 0..test_runs {
let index = i;
let list_size = max_list_size;
let seed = Hash256::random();
let shuffle_rounds = 90;
assert!(get_permutated_index(index, list_size, &seed[..], shuffle_rounds).is_some());
}
// Test at max list_size high indices.
for i in 0..test_runs {
let index = max_list_size - 1 - i;
let list_size = max_list_size;
let seed = Hash256::random();
let shuffle_rounds = 90;
assert!(get_permutated_index(index, list_size, &seed[..], shuffle_rounds).is_some());
}
}
#[test]
fn returns_none_for_zero_length_list() {
assert_eq!(None, get_permutated_index(100, 0, &[42, 42], 90));
}
#[test]
fn returns_none_for_out_of_bounds_index() {
assert_eq!(None, get_permutated_index(100, 100, &[42, 42], 90));
}
#[test]
fn returns_none_for_too_large_list() {
assert_eq!(
None,
get_permutated_index(100, usize::max_value() / 2, &[42, 42], 90)
);
}
#[test]
fn test_vectors() {
/*
* Test vectors are generated here:
*
* https://github.com/ethereum/eth2.0-test-generators
*/
let mut file = {
let mut file_path_buf = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
file_path_buf.push("src/specs/test_vector_permutated_index.yml");
File::open(file_path_buf).unwrap()
};
let mut yaml_str = String::new();
file.read_to_string(&mut yaml_str).unwrap();
let docs = yaml::YamlLoader::load_from_str(&yaml_str).unwrap();
let doc = &docs[0];
let test_cases = doc["test_cases"].as_vec().unwrap();
for (i, test_case) in test_cases.iter().enumerate() {
let index = test_case["index"].as_i64().unwrap() as usize;
let list_size = test_case["list_size"].as_i64().unwrap() as usize;
let permutated_index = test_case["permutated_index"].as_i64().unwrap() as usize;
let shuffle_round_count = test_case["shuffle_round_count"].as_i64().unwrap();
let seed_string = test_case["seed"].clone().into_string().unwrap();
let seed = hex::decode(seed_string.replace("0x", "")).unwrap();
let shuffle_round_count = if shuffle_round_count < (u8::max_value() as i64) {
shuffle_round_count as u8
} else {
panic!("shuffle_round_count must be a u8")
};
assert_eq!(
Some(permutated_index),
get_permutated_index(index, list_size, &seed[..], shuffle_round_count),
"Failure on case #{} index: {}, list_size: {}, round_count: {}, seed: {}",
i,
index,
list_size,
shuffle_round_count,
seed_string,
);
}
}
}

View File

@ -1,212 +1,21 @@
use bytes::Buf; //! Provides list-shuffling functions matching the Ethereum 2.0 specification.
use hashing::hash; //!
use int_to_bytes::{int_to_bytes1, int_to_bytes4}; //! See
use std::cmp::max; //! [get_permutated_index](https://github.com/ethereum/eth2.0-specs/blob/0.4.0/specs/core/0_beacon-chain.md#get_permuted_index)
use std::io::Cursor; //! for specifications.
//!
//! There are two functions exported by this crate:
//!
//! - `get_permutated_index`: given a single index, computes the index resulting from a shuffle.
//! Runs in less time than it takes to run `shuffle_list`.
//! - `shuffle_list`: shuffles an entire list in-place. Runs in less time than it takes to run
//! `get_permutated_index` on each index.
//!
//! In general, use `get_permutated_list` to calculate the shuffling of a small subset of a much
//! larger list (~250x larger is a good guide, but solid figures yet to be calculated).
pub fn get_permutated_list( mod get_permutated_index;
list: &[usize], mod shuffle_list;
seed: &[u8],
shuffle_round_count: u8,
) -> Option<Vec<usize>> {
let list_size = list.len();
if list_size == 0 || list_size > usize::max_value() / 2 || list_size > 2_usize.pow(24) { pub use get_permutated_index::get_permutated_index;
return None; pub use shuffle_list::shuffle_list;
}
let mut pivots = Vec::with_capacity(shuffle_round_count as usize);
for round in 0..shuffle_round_count {
pivots.push(bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size);
}
let mut output = Vec::with_capacity(list_size);
for i in 0..list_size {
let mut index = i;
for round in 0..shuffle_round_count {
let pivot = pivots[round as usize];
index = do_round(seed, index, pivot, round, list_size)?;
}
output.push(list[index])
}
Some(output)
}
/// Return `p(index)` in a pseudorandom permutation `p` of `0...list_size-1` with ``seed`` as entropy.
///
/// Utilizes 'swap or not' shuffling found in
/// https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf
/// See the 'generalized domain' algorithm on page 3.
///
/// Returns `None` under any of the following conditions:
/// - `list_size == 0`
/// - `index >= list_size`
/// - `list_size > 2**24`
/// - `list_size > usize::max_value() / 2`
pub fn get_permutated_index(
index: usize,
list_size: usize,
seed: &[u8],
shuffle_round_count: u8,
) -> Option<usize> {
if list_size == 0
|| index >= list_size
|| list_size > usize::max_value() / 2
|| list_size > 2_usize.pow(24)
{
return None;
}
let mut index = index;
for round in 0..shuffle_round_count {
let pivot = bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size;
index = do_round(seed, index, pivot, round, list_size)?;
}
Some(index)
}
fn do_round(seed: &[u8], index: usize, pivot: usize, round: u8, list_size: usize) -> Option<usize> {
let flip = (pivot + list_size - index) % list_size;
let position = max(index, flip);
let source = hash_with_round_and_position(seed, round, position)?;
let byte = source[(position % 256) / 8];
let bit = (byte >> (position % 8)) % 2;
Some(if bit == 1 { flip } else { index })
}
fn hash_with_round_and_position(seed: &[u8], round: u8, position: usize) -> Option<Vec<u8>> {
let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round));
/*
* Note: the specification has an implicit assertion in `int_to_bytes4` that `position / 256 <
* 2**24`. For efficiency, we do not check for that here as it is checked in `get_permutated_index`.
*/
seed.append(&mut int_to_bytes4((position / 256) as u32));
Some(hash(&seed[..]))
}
fn hash_with_round(seed: &[u8], round: u8) -> Vec<u8> {
let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round));
hash(&seed[..])
}
fn bytes_to_int64(bytes: &[u8]) -> u64 {
let mut cursor = Cursor::new(bytes);
cursor.get_u64_le()
}
#[cfg(test)]
mod tests {
use super::*;
use ethereum_types::H256 as Hash256;
use hex;
use std::{fs::File, io::prelude::*, path::PathBuf};
use yaml_rust::yaml;
#[test]
#[ignore]
fn fuzz_test() {
let max_list_size = 2_usize.pow(24);
let test_runs = 1000;
// Test at max list_size with the end index.
for _ in 0..test_runs {
let index = max_list_size - 1;
let list_size = max_list_size;
let seed = Hash256::random();
let shuffle_rounds = 90;
assert!(get_permutated_index(index, list_size, &seed[..], shuffle_rounds).is_some());
}
// Test at max list_size low indices.
for i in 0..test_runs {
let index = i;
let list_size = max_list_size;
let seed = Hash256::random();
let shuffle_rounds = 90;
assert!(get_permutated_index(index, list_size, &seed[..], shuffle_rounds).is_some());
}
// Test at max list_size high indices.
for i in 0..test_runs {
let index = max_list_size - 1 - i;
let list_size = max_list_size;
let seed = Hash256::random();
let shuffle_rounds = 90;
assert!(get_permutated_index(index, list_size, &seed[..], shuffle_rounds).is_some());
}
}
#[test]
fn returns_none_for_zero_length_list() {
assert_eq!(None, get_permutated_index(100, 0, &[42, 42], 90));
}
#[test]
fn returns_none_for_out_of_bounds_index() {
assert_eq!(None, get_permutated_index(100, 100, &[42, 42], 90));
}
#[test]
fn returns_none_for_too_large_list() {
assert_eq!(
None,
get_permutated_index(100, usize::max_value() / 2, &[42, 42], 90)
);
}
#[test]
fn test_vectors() {
/*
* Test vectors are generated here:
*
* https://github.com/ethereum/eth2.0-test-generators
*/
let mut file = {
let mut file_path_buf = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
file_path_buf.push("src/specs/test_vector_permutated_index.yml");
File::open(file_path_buf).unwrap()
};
let mut yaml_str = String::new();
file.read_to_string(&mut yaml_str).unwrap();
let docs = yaml::YamlLoader::load_from_str(&yaml_str).unwrap();
let doc = &docs[0];
let test_cases = doc["test_cases"].as_vec().unwrap();
for (i, test_case) in test_cases.iter().enumerate() {
let index = test_case["index"].as_i64().unwrap() as usize;
let list_size = test_case["list_size"].as_i64().unwrap() as usize;
let permutated_index = test_case["permutated_index"].as_i64().unwrap() as usize;
let shuffle_round_count = test_case["shuffle_round_count"].as_i64().unwrap();
let seed_string = test_case["seed"].clone().into_string().unwrap();
let seed = hex::decode(seed_string.replace("0x", "")).unwrap();
let shuffle_round_count = if shuffle_round_count < (u8::max_value() as i64) {
shuffle_round_count as u8
} else {
panic!("shuffle_round_count must be a u8")
};
assert_eq!(
Some(permutated_index),
get_permutated_index(index, list_size, &seed[..], shuffle_round_count),
"Failure on case #{} index: {}, list_size: {}, round_count: {}, seed: {}",
i,
index,
list_size,
shuffle_round_count,
seed_string,
);
}
}
}

View File

@ -0,0 +1,178 @@
use bytes::Buf;
use hashing::hash;
use int_to_bytes::int_to_bytes4;
use std::io::Cursor;
const SEED_SIZE: usize = 32;
const ROUND_SIZE: usize = 1;
const POSITION_WINDOW_SIZE: usize = 4;
const PIVOT_VIEW_SIZE: usize = SEED_SIZE + ROUND_SIZE;
const TOTAL_SIZE: usize = SEED_SIZE + ROUND_SIZE + POSITION_WINDOW_SIZE;
/// Shuffles an entire list in-place.
///
/// Note: this is equivalent to the `get_permutated_index` function, except it shuffles an entire
/// list not just a single index. With large lists this function has been observed to be 250x
/// faster than running `get_permutated_index` across an entire list.
///
/// Credits to [@protolambda](https://github.com/protolambda) for defining this algorithm.
///
/// Shuffles if `forwards == true`, otherwise un-shuffles.
///
/// Returns `None` under any of the following conditions:
/// - `list_size == 0`
/// - `list_size > 2**24`
/// - `list_size > usize::max_value() / 2`
pub fn shuffle_list(
mut input: Vec<usize>,
rounds: u8,
seed: &[u8],
forwards: bool,
) -> Option<Vec<usize>> {
let list_size = input.len();
if input.is_empty()
|| list_size > usize::max_value() / 2
|| list_size > 2_usize.pow(24)
|| rounds == 0
{
return None;
}
let mut buf: Vec<u8> = Vec::with_capacity(TOTAL_SIZE);
let mut r = if forwards { 0 } else { rounds - 1 };
buf.extend_from_slice(seed);
loop {
buf.splice(SEED_SIZE.., vec![r]);
let pivot = bytes_to_int64(&hash(&buf[0..PIVOT_VIEW_SIZE])[0..8]) as usize % list_size;
let mirror = (pivot + 1) >> 1;
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((pivot >> 8) as u32));
let mut source = hash(&buf[..]);
let mut byte_v = source[(pivot & 0xff) >> 3];
for i in 0..mirror {
let j = pivot - i;
if j & 0xff == 0xff {
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((j >> 8) as u32));
source = hash(&buf[..]);
}
if j & 0x07 == 0x07 {
byte_v = source[(j & 0xff) >> 3];
}
let bit_v = (byte_v >> (j & 0x07)) & 0x01;
if bit_v == 1 {
input.swap(i, j);
}
}
let mirror = (pivot + list_size + 1) >> 1;
let end = list_size - 1;
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((end >> 8) as u32));
let mut source = hash(&buf[..]);
let mut byte_v = source[(end & 0xff) >> 3];
for (loop_iter, i) in ((pivot + 1)..mirror).enumerate() {
let j = end - loop_iter;
if j & 0xff == 0xff {
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((j >> 8) as u32));
source = hash(&buf[..]);
}
if j & 0x07 == 0x07 {
byte_v = source[(j & 0xff) >> 3];
}
let bit_v = (byte_v >> (j & 0x07)) & 0x01;
if bit_v == 1 {
input.swap(i, j);
}
}
if forwards {
r += 1;
if r == rounds {
break;
}
} else {
if r == 0 {
break;
}
r -= 1;
}
}
Some(input)
}
fn bytes_to_int64(bytes: &[u8]) -> u64 {
let mut cursor = Cursor::new(bytes);
cursor.get_u64_le()
}
#[cfg(test)]
mod tests {
use super::*;
use hex;
use std::{fs::File, io::prelude::*, path::PathBuf};
use yaml_rust::yaml;
#[test]
fn returns_none_for_zero_length_list() {
assert_eq!(None, shuffle_list(vec![], 90, &[42, 42], true));
}
#[test]
fn test_vectors() {
let mut file = {
let mut file_path_buf = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
file_path_buf.push("src/specs/test_vector_permutated_index.yml");
File::open(file_path_buf).unwrap()
};
let mut yaml_str = String::new();
file.read_to_string(&mut yaml_str).unwrap();
let docs = yaml::YamlLoader::load_from_str(&yaml_str).unwrap();
let doc = &docs[0];
let test_cases = doc["test_cases"].as_vec().unwrap();
for (i, test_case) in test_cases.iter().enumerate() {
let index = test_case["index"].as_i64().unwrap() as usize;
let list_size = test_case["list_size"].as_i64().unwrap() as usize;
let permutated_index = test_case["permutated_index"].as_i64().unwrap() as usize;
let shuffle_round_count = test_case["shuffle_round_count"].as_i64().unwrap();
let seed_string = test_case["seed"].clone().into_string().unwrap();
let seed = hex::decode(seed_string.replace("0x", "")).unwrap();
let shuffle_round_count = if shuffle_round_count < (u8::max_value() as i64) {
shuffle_round_count as u8
} else {
panic!("shuffle_round_count must be a u8")
};
let list: Vec<usize> = (0..list_size).collect();
let shuffled =
shuffle_list(list.clone(), shuffle_round_count, &seed[..], true).unwrap();
assert_eq!(
list[index], shuffled[permutated_index],
"Failure on case #{} index: {}, list_size: {}, round_count: {}, seed: {}",
i, index, list_size, shuffle_round_count, seed_string
);
}
}
}