From 832d1bd295cbded813af1ff5b3eaa2e7ec9b4b08 Mon Sep 17 00:00:00 2001 From: Alex Stokes Date: Thu, 15 Nov 2018 09:19:59 -0800 Subject: [PATCH] Update bitfield to expand size when writing out-of-bounds --- .../utils/boolean-bitfield/Cargo.toml | 5 +- .../utils/boolean-bitfield/src/lib.rs | 96 ++++++++++++++----- 2 files changed, 73 insertions(+), 28 deletions(-) diff --git a/beacon_chain/utils/boolean-bitfield/Cargo.toml b/beacon_chain/utils/boolean-bitfield/Cargo.toml index b93e88f23..1633401e2 100644 --- a/beacon_chain/utils/boolean-bitfield/Cargo.toml +++ b/beacon_chain/utils/boolean-bitfield/Cargo.toml @@ -5,7 +5,4 @@ authors = ["Paul Hauner "] [dependencies] ssz = { path = "../ssz" } -bit-vec = "0.5.0" - -[dev-dependencies] -rand = "0.5.5" \ No newline at end of file +bit-vec = "0.5.0" \ No newline at end of file diff --git a/beacon_chain/utils/boolean-bitfield/src/lib.rs b/beacon_chain/utils/boolean-bitfield/src/lib.rs index 1f96a9afd..ceff3bbcf 100644 --- a/beacon_chain/utils/boolean-bitfield/src/lib.rs +++ b/beacon_chain/utils/boolean-bitfield/src/lib.rs @@ -1,12 +1,12 @@ extern crate bit_vec; extern crate ssz; -#[cfg(test)] -extern crate rand; - use bit_vec::BitVec; +use std::default; + /// A BooleanBitfield represents a set of booleans compactly stored as a vector of bits. +/// The BooleanBitfield is given a fixed size during construction. Reads outside of the current size return an out-of-bounds error. Writes outside of the current size expand the size of the set. #[derive(Debug, Clone, PartialEq)] pub struct BooleanBitfield(BitVec); @@ -18,13 +18,20 @@ pub enum Error { } impl BooleanBitfield { - /// Create a new bitfield with a length of zero. + /// Create a new bitfield. pub fn new() -> Self { - Self { 0: BitVec::new() } + Default::default() + } + + /// Create a new bitfield with the given length `initial_len` and all values set to `bit`. + pub fn from_elem(inital_len: usize, bit: bool) -> Self { + Self { + 0: BitVec::from_elem(inital_len, bit), + } } /// Create a new bitfield using the supplied `bytes` as input - pub fn from(bytes: &[u8]) -> Self { + pub fn from_bytes(bytes: &[u8]) -> Self { Self { 0: BitVec::from_bytes(bytes), } @@ -43,12 +50,19 @@ impl BooleanBitfield { /// Set the value of a bit. /// - /// Returns the previous value if successful. - /// If the index is out of bounds, we return an error to that extent. - pub fn set(&mut self, i: usize, value: bool) -> Result { - let previous = self.get(i)?; + /// If the index is out of bounds, we expand the size of the underlying set to include the new index. + /// Returns the previous value if there was one. + pub fn set(&mut self, i: usize, value: bool) -> Option { + let previous = match self.get(i) { + Ok(previous) => Some(previous), + Err(Error::OutOfBounds(_, len)) => { + let new_len = i - len + 1; + self.0.grow(new_len, false); + None + } + }; self.0.set(i, value); - Ok(previous) + previous } /// Returns the index of the highest set bit. Some(n) if some bit is set, None otherwise. @@ -72,6 +86,14 @@ impl BooleanBitfield { } } +impl default::Default for BooleanBitfield { + /// default provides the "empty" bitfield + /// Note: the empty bitfield is set to the `0` byte. + fn default() -> Self { + Self::from_elem(8, false) + } +} + impl ssz::Decodable for BooleanBitfield { fn ssz_decode(bytes: &[u8], index: usize) -> Result<(Self, usize), ssz::DecodeError> { let len = ssz::decode::decode_length(bytes, index, ssz::LENGTH_BYTES)?; @@ -82,7 +104,7 @@ impl ssz::Decodable for BooleanBitfield { if len == 0 { Ok((BooleanBitfield::new(), index + ssz::LENGTH_BYTES)) } else { - let field = BooleanBitfield::from(&bytes[(index + 4)..(index + len + 4)]); + let field = BooleanBitfield::from_bytes(&bytes[(index + 4)..(index + len + 4)]); let index = index + ssz::LENGTH_BYTES + len; Ok((field, index)) } @@ -96,11 +118,20 @@ mod tests { #[test] fn test_empty_bitfield() { let mut field = BooleanBitfield::new(); + let original_len = field.len(); - for _ in 0..100 { - let index: usize = rand::random(); - assert!(field.get(index).is_err()); - assert!(field.set(index, rand::random()).is_err()) + for i in 0..100 { + if i < original_len { + assert!(!field.get(i).unwrap()); + } else { + assert!(field.get(i).is_err()); + } + let previous = field.set(i, true); + if i < original_len { + assert!(!previous.unwrap()); + } else { + assert!(previous.is_none()); + } } } @@ -108,7 +139,7 @@ mod tests { #[test] fn test_get_from_bitfield() { - let field = BooleanBitfield::from(INPUT); + let field = BooleanBitfield::from_bytes(INPUT); let unset = field.get(0).unwrap(); assert!(!unset); let set = field.get(6).unwrap(); @@ -119,7 +150,7 @@ mod tests { #[test] fn test_set_for_bitfield() { - let mut field = BooleanBitfield::from(INPUT); + let mut field = BooleanBitfield::from_bytes(INPUT); let previous = field.set(10, true).unwrap(); assert!(!previous); let previous = field.get(10).unwrap(); @@ -132,7 +163,7 @@ mod tests { #[test] fn test_highest_set_bit() { - let field = BooleanBitfield::from(INPUT); + let field = BooleanBitfield::from_bytes(INPUT); assert_eq!(field.highest_set_bit().unwrap(), 14); let field = BooleanBitfield::new(); @@ -141,16 +172,16 @@ mod tests { #[test] fn test_len() { - let field = BooleanBitfield::from(INPUT); + let field = BooleanBitfield::from_bytes(INPUT); assert_eq!(field.len(), 16); let field = BooleanBitfield::new(); - assert_eq!(field.len(), 0); + assert_eq!(field.len(), 8); } #[test] fn test_num_set_bits() { - let field = BooleanBitfield::from(INPUT); + let field = BooleanBitfield::from_bytes(INPUT); assert_eq!(field.num_set_bits(), 2); let field = BooleanBitfield::new(); @@ -159,10 +190,27 @@ mod tests { #[test] fn test_to_bytes() { - let field = BooleanBitfield::from(INPUT); + let field = BooleanBitfield::from_bytes(INPUT); assert_eq!(field.to_bytes(), INPUT); let field = BooleanBitfield::new(); - assert_eq!(field.to_bytes(), vec![]); + assert_eq!(field.to_bytes(), vec![0]); + } + + #[test] + fn test_out_of_bounds() { + let mut field = BooleanBitfield::from_bytes(INPUT); + + let out_of_bounds_index = field.len(); + assert!(field.set(out_of_bounds_index, true).is_none()); + assert!(field.get(out_of_bounds_index).unwrap()); + + for i in 0..100 { + if i <= out_of_bounds_index { + assert!(field.set(i, true).is_some()); + } else { + assert!(field.set(i, true).is_none()); + } + } } }