Remove temporary heap allocations during shuffling (#867)

* Remove temp allocs in compute shuffled index

* Update shuffle list
This commit is contained in:
Paul Hauner 2020-02-25 09:00:09 +11:00 committed by GitHub
parent 2a9c718a20
commit 123c63119d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 91 additions and 36 deletions

1
Cargo.lock generated
View File

@ -4025,7 +4025,6 @@ dependencies = [
"eth2_hashing 0.1.1",
"ethereum-types 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
"hex 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
"int_to_bytes 0.1.0",
"yaml-rust 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)",
]

View File

@ -5,7 +5,7 @@
//! defining it once in this crate makes it easy to replace.
#[cfg(not(target_arch = "wasm32"))]
use ring::digest::{digest, Context, SHA256};
pub use ring::digest::{digest, Context, SHA256};
#[cfg(target_arch = "wasm32")]
use sha2::{Digest, Sha256};

View File

@ -16,4 +16,4 @@ ethereum-types = "0.8.0"
[dependencies]
eth2_hashing = "0.1.0"
int_to_bytes = { path = "../int_to_bytes" }
ethereum-types = "0.8.0"

View File

@ -1,5 +1,5 @@
use eth2_hashing::hash;
use int_to_bytes::{int_to_bytes1, int_to_bytes4};
use crate::Hash256;
use eth2_hashing::{Context, SHA256};
use std::cmp::max;
/// Return `p(index)` in a pseudorandom permutation `p` of `0...list_size-1` with ``seed`` as entropy.
@ -43,27 +43,35 @@ pub fn compute_shuffled_index(
fn do_round(seed: &[u8], index: usize, pivot: usize, round: u8, list_size: usize) -> Option<usize> {
let flip = (pivot + (list_size - index)) % list_size;
let position = max(index, flip);
let source = hash_with_round_and_position(seed, round, position)?;
let source = hash_with_round_and_position(seed, round, position);
let byte = source[(position % 256) / 8];
let bit = (byte >> (position % 8)) % 2;
Some(if bit == 1 { flip } else { index })
}
fn hash_with_round_and_position(seed: &[u8], round: u8, position: usize) -> Option<Vec<u8>> {
let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round));
fn hash_with_round_and_position(seed: &[u8], round: u8, position: usize) -> Hash256 {
let mut context = Context::new(&SHA256);
context.update(seed);
context.update(&[round]);
/*
* Note: the specification has an implicit assertion in `int_to_bytes4` that `position / 256 <
* 2**24`. For efficiency, we do not check for that here as it is checked in `compute_shuffled_index`.
*/
seed.append(&mut int_to_bytes4((position / 256) as u32));
Some(hash(&seed[..]))
context.update(&(position / 256).to_le_bytes()[0..4]);
let digest = context.finish();
Hash256::from_slice(digest.as_ref())
}
fn hash_with_round(seed: &[u8], round: u8) -> Vec<u8> {
let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round));
hash(&seed[..])
fn hash_with_round(seed: &[u8], round: u8) -> Hash256 {
let mut context = Context::new(&SHA256);
context.update(seed);
context.update(&[round]);
let digest = context.finish();
Hash256::from_slice(digest.as_ref())
}
fn bytes_to_int64(slice: &[u8]) -> u64 {

View File

@ -19,3 +19,5 @@ mod shuffle_list;
pub use compute_shuffled_index::compute_shuffled_index;
pub use shuffle_list::shuffle_list;
type Hash256 = ethereum_types::H256;

View File

@ -1,5 +1,6 @@
use eth2_hashing::hash;
use int_to_bytes::int_to_bytes4;
use crate::Hash256;
use eth2_hashing::{Context, SHA256};
use std::mem;
const SEED_SIZE: usize = 32;
const ROUND_SIZE: usize = 1;
@ -7,6 +8,52 @@ const POSITION_WINDOW_SIZE: usize = 4;
const PIVOT_VIEW_SIZE: usize = SEED_SIZE + ROUND_SIZE;
const TOTAL_SIZE: usize = SEED_SIZE + ROUND_SIZE + POSITION_WINDOW_SIZE;
/// A helper struct to manage the buffer used during shuffling.
struct Buf([u8; TOTAL_SIZE]);
impl Buf {
/// Create a new buffer from the given `seed`.
///
/// ## Panics
///
/// Panics if `seed.len() != 32`.
fn new(seed: &[u8]) -> Self {
let mut buf = [0; TOTAL_SIZE];
buf[0..SEED_SIZE].copy_from_slice(seed);
Self(buf)
}
/// Set the shuffling round.
fn set_round(&mut self, round: u8) {
self.0[SEED_SIZE] = round;
}
/// Returns the new pivot. It is "raw" because it has not modulo the list size (this must be
/// done by the caller).
fn raw_pivot(&self) -> u64 {
let mut context = Context::new(&SHA256);
context.update(&self.0[0..PIVOT_VIEW_SIZE]);
let digest = context.finish();
let mut bytes = [0; mem::size_of::<u64>()];
bytes[..].copy_from_slice(&digest.as_ref()[0..mem::size_of::<u64>()]);
u64::from_le_bytes(bytes)
}
/// Add the current position into the buffer.
fn mix_in_position(&mut self, position: usize) {
self.0[PIVOT_VIEW_SIZE..].copy_from_slice(&position.to_le_bytes()[0..POSITION_WINDOW_SIZE]);
}
/// Hash the entire buffer.
fn hash(&self) -> Hash256 {
let mut context = Context::new(&SHA256);
context.update(&self.0[..]);
let digest = context.finish();
Hash256::from_slice(digest.as_ref())
}
}
/// Shuffles an entire list in-place.
///
/// Note: this is equivalent to the `compute_shuffled_index` function, except it shuffles an entire
@ -50,29 +97,27 @@ pub fn shuffle_list(
return None;
}
let mut buf: Vec<u8> = Vec::with_capacity(TOTAL_SIZE);
let mut buf = Buf::new(seed);
let mut r = if forwards { 0 } else { rounds - 1 };
buf.extend_from_slice(seed);
loop {
buf.splice(SEED_SIZE.., vec![r]);
buf.set_round(r);
let pivot = bytes_to_int64(&hash(&buf[0..PIVOT_VIEW_SIZE])[0..8]) as usize % list_size;
let pivot = buf.raw_pivot() as usize % list_size;
let mirror = (pivot + 1) >> 1;
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((pivot >> 8) as u32));
let mut source = hash(&buf[..]);
buf.mix_in_position(pivot >> 8);
let mut source = buf.hash();
let mut byte_v = source[(pivot & 0xff) >> 3];
for i in 0..mirror {
let j = pivot - i;
if j & 0xff == 0xff {
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((j >> 8) as u32));
source = hash(&buf[..]);
buf.mix_in_position(j >> 8);
source = buf.hash();
}
if j & 0x07 == 0x07 {
@ -88,16 +133,16 @@ pub fn shuffle_list(
let mirror = (pivot + list_size + 1) >> 1;
let end = list_size - 1;
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((end >> 8) as u32));
let mut source = hash(&buf[..]);
buf.mix_in_position(end >> 8);
let mut source = buf.hash();
let mut byte_v = source[(end & 0xff) >> 3];
for (loop_iter, i) in ((pivot + 1)..mirror).enumerate() {
let j = end - loop_iter;
if j & 0xff == 0xff {
buf.splice(PIVOT_VIEW_SIZE.., int_to_bytes4((j >> 8) as u32));
source = hash(&buf[..]);
buf.mix_in_position(j >> 8);
source = buf.hash();
}
if j & 0x07 == 0x07 {
@ -126,12 +171,6 @@ pub fn shuffle_list(
Some(input)
}
fn bytes_to_int64(slice: &[u8]) -> u64 {
let mut bytes = [0; 8];
bytes.copy_from_slice(&slice[0..8]);
u64::from_le_bytes(bytes)
}
#[cfg(test)]
mod tests {
use super::*;
@ -140,4 +179,11 @@ mod tests {
fn returns_none_for_zero_length_list() {
assert_eq!(None, shuffle_list(vec![], 90, &[42, 42], true));
}
#[test]
fn sanity_check_constants() {
assert!(TOTAL_SIZE > SEED_SIZE);
assert!(TOTAL_SIZE > PIVOT_VIEW_SIZE);
assert!(mem::size_of::<usize>() >= POSITION_WINDOW_SIZE);
}
}