diff --git a/beacon_chain/utils/boolean-bitfield/src/lib.rs b/beacon_chain/utils/boolean-bitfield/src/lib.rs index e0adc64dd..98518d70c 100644 --- a/beacon_chain/utils/boolean-bitfield/src/lib.rs +++ b/beacon_chain/utils/boolean-bitfield/src/lib.rs @@ -3,11 +3,12 @@ extern crate ssz; use bit_vec::BitVec; +use std::cmp; use std::default; /// A BooleanBitfield represents a set of booleans compactly stored as a vector of bits. /// The BooleanBitfield is given a fixed size during construction. Reads outside of the current size return an out-of-bounds error. Writes outside of the current size expand the size of the set. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct BooleanBitfield(BitVec); /// Error represents some reason a request against a bitfield was not satisfied @@ -23,6 +24,10 @@ impl BooleanBitfield { Default::default() } + pub fn with_capacity(initial_len: usize) -> Self { + Self::from_elem(initial_len, false) + } + /// Create a new bitfield with the given length `initial_len` and all values set to `bit`. pub fn from_elem(inital_len: usize, bit: bool) -> Self { Self { @@ -100,25 +105,18 @@ impl default::Default for BooleanBitfield { } } -// borrowed from bit_vec crate -fn reverse_bits(byte: u8) -> u8 { - let mut result = 0; - for i in 0..8 { - result = result | ((byte >> i) & 1) << (7 - i); +impl cmp::PartialEq for BooleanBitfield { + /// Determines equality by comparing the `ssz` encoding of the two candidates. + /// This method ensures that the presence of high-order (empty) bits in the highest byte do not exclude equality when they are in fact representing the same information. + fn eq(&self, other: &Self) -> bool { + ssz::ssz_encode(self) == ssz::ssz_encode(other) } - result } impl ssz::Encodable for BooleanBitfield { // ssz_append encodes Self according to the `ssz` spec. - // Note that we have to flip the endianness of the encoding with `reverse_bits` to account for an implementation detail of `bit-vec` crate. fn ssz_append(&self, s: &mut ssz::SszStream) { - let bytes: Vec = self - .to_bytes() - .iter() - .map(|&byte| reverse_bits(byte)) - .collect(); - s.append_vec(&bytes); + s.append_vec(&self.to_bytes()) } } @@ -134,10 +132,11 @@ impl ssz::Decodable for BooleanBitfield { } else { let bytes = &bytes[(index + 4)..(index + len + 4)]; - let mut field = BooleanBitfield::from_elem(0, false); + let count = len * 8; + let mut field = BooleanBitfield::with_capacity(count); for (byte_index, byte) in bytes.iter().enumerate() { for i in 0..8 { - let bit = byte & (1 << i); + let bit = byte & (128 >> i); if bit != 0 { field.set(8 * byte_index + i, true); } @@ -153,7 +152,7 @@ impl ssz::Decodable for BooleanBitfield { #[cfg(test)] mod tests { use super::*; - use ssz::SszStream; + use ssz::{ssz_encode, Decodable, SszStream}; #[test] fn test_new_bitfield() { @@ -317,28 +316,47 @@ mod tests { #[test] fn test_ssz_encode() { - let field = BooleanBitfield::from_elem(5, true); + let field = create_test_bitfield(); let mut stream = SszStream::new(); stream.append(&field); - assert_eq!(stream.drain(), vec![0, 0, 0, 1, 31]); + assert_eq!(stream.drain(), vec![0, 0, 0, 2, 225, 192]); let field = BooleanBitfield::from_elem(18, true); let mut stream = SszStream::new(); stream.append(&field); - assert_eq!(stream.drain(), vec![0, 0, 0, 3, 255, 255, 3]); + assert_eq!(stream.drain(), vec![0, 0, 0, 3, 255, 255, 192]); + } + + fn create_test_bitfield() -> BooleanBitfield { + let count = 2 * 8; + let mut field = BooleanBitfield::with_capacity(count); + + let indices = &[0, 1, 2, 7, 8, 9]; + for &i in indices { + field.set(i, true); + } + field } #[test] fn test_ssz_decode() { - let encoded = vec![0, 0, 0, 1, 31]; + let encoded = vec![0, 0, 0, 2, 225, 192]; let (field, _): (BooleanBitfield, usize) = ssz::decode_ssz(&encoded, 0).unwrap(); - let expected = BooleanBitfield::from_elem(5, true); + let expected = create_test_bitfield(); assert_eq!(field, expected); let encoded = vec![0, 0, 0, 3, 255, 255, 3]; let (field, _): (BooleanBitfield, usize) = ssz::decode_ssz(&encoded, 0).unwrap(); - let expected = BooleanBitfield::from_elem(18, true); + let expected = BooleanBitfield::from_bytes(&[255, 255, 3]); assert_eq!(field, expected); } + + #[test] + fn test_ssz_round_trip() { + let original = BooleanBitfield::from_bytes(&vec![18; 12][..]); + let ssz = ssz_encode(&original); + let (decoded, _) = BooleanBitfield::ssz_decode(&ssz, 0).unwrap(); + assert_eq!(original, decoded); + } }