diff --git a/eth2/utils/ssz_types/src/bitfield.rs b/eth2/utils/ssz_types/src/bitfield.rs index b7cdb2876..1cb53ab1e 100644 --- a/eth2/utils/ssz_types/src/bitfield.rs +++ b/eth2/utils/ssz_types/src/bitfield.rs @@ -71,8 +71,10 @@ pub struct Bitfield { impl Bitfield> { pub fn with_capacity(num_bits: usize) -> Option { if num_bits <= N::to_usize() { + let num_bytes = std::cmp::max(bytes_for_bit_len(num_bits), 1); + Some(Self { - bytes: vec![0; bytes_for_bit_len(num_bits)], + bytes: vec![0; num_bytes], len: num_bits, _phantom: PhantomData, }) @@ -81,7 +83,7 @@ impl Bitfield> { } } - pub fn capacity() -> usize { + pub fn max_len() -> usize { N::to_usize() } @@ -89,15 +91,13 @@ impl Bitfield> { let len = self.len(); let mut bytes = self.as_slice().to_vec(); - if bytes_for_bit_len(len + 1) == bytes.len() + 1 { + while bytes_for_bit_len(len + 1) > bytes.len() { bytes.insert(0, 0); } let mut bitfield: Bitfield> = Bitfield::from_raw_bytes(bytes, len + 1) .expect("Bitfield capacity has been confirmed earlier."); - bitfield - .set(len, true) - .expect("Bitfield capacity has been confirmed earlier."); + bitfield.set(len, true).expect("Bitfield index must exist."); bitfield.bytes } @@ -110,17 +110,22 @@ impl Bitfield> { }; let len = initial_bitfield.highest_set_bit()?; - initial_bitfield - .set(len, false) - .expect("Bit has been confirmed to exist"); - let mut bytes = initial_bitfield.to_raw_bytes(); + if len <= Self::max_len() { + initial_bitfield + .set(len, false) + .expect("Bit has been confirmed to exist"); - if bytes_for_bit_len(len) < bytes.len() { - bytes.remove(0); + let mut bytes = initial_bitfield.to_raw_bytes(); + + if bytes_for_bit_len(len) < bytes.len() && bytes != &[0] { + bytes.remove(0); + } + + Self::from_raw_bytes(bytes, len) + } else { + None } - - Self::from_raw_bytes(bytes, len) } } @@ -203,7 +208,7 @@ impl Bitfield { &self.bytes } - pub fn from_raw_bytes(bytes: Vec, bit_len: usize) -> Option { + fn from_raw_bytes(bytes: Vec, bit_len: usize) -> Option { if bytes.len() == 1 && bit_len == 0 && bytes == &[0] { // A bitfield with `bit_len` 0 can only be represented by a single zero byte. Some(Self { @@ -552,7 +557,137 @@ impl cached_tree_hash::CachedTreeHash for Bitfield>; + mod bitlist { + use super::*; + + pub type BitList = crate::Bitfield>; + pub type BitList0 = BitList; + pub type BitList1 = BitList; + pub type BitList8 = BitList; + pub type BitList16 = BitList; + + #[test] + fn ssz_encode() { + assert_eq!( + BitList0::with_capacity(0).unwrap().as_ssz_bytes(), + vec![0b0000_00001], + ); + + assert_eq!( + BitList1::with_capacity(0).unwrap().as_ssz_bytes(), + vec![0b0000_00001], + ); + + assert_eq!( + BitList1::with_capacity(1).unwrap().as_ssz_bytes(), + vec![0b0000_00010], + ); + + assert_eq!( + BitList8::with_capacity(8).unwrap().as_ssz_bytes(), + vec![0b0000_0001, 0b0000_0000], + ); + + assert_eq!( + BitList8::with_capacity(7).unwrap().as_ssz_bytes(), + vec![0b1000_0000] + ); + + let mut b = BitList8::with_capacity(8).unwrap(); + for i in 0..8 { + b.set(i, true).unwrap(); + } + assert_eq!(b.as_ssz_bytes(), vec![0b0000_0001, 255]); + + let mut b = BitList8::with_capacity(8).unwrap(); + for i in 0..4 { + b.set(i, true).unwrap(); + } + assert_eq!(b.as_ssz_bytes(), vec![0b0000_0001, 0b0000_1111]); + + assert_eq!( + BitList16::with_capacity(16).unwrap().as_ssz_bytes(), + vec![0b0000_0001, 0b0000_0000, 0b0000_0000] + ); + } + + #[test] + fn ssz_decode() { + assert!(BitList0::from_ssz_bytes(&[0b0000_0000]).is_err()); + assert!(BitList1::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_err()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0000]).is_err()); + assert!(BitList16::from_ssz_bytes(&[0b0000_0000]).is_err()); + + assert!(BitList0::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitList0::from_ssz_bytes(&[0b0000_0010]).is_err()); + + assert!(BitList1::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitList1::from_ssz_bytes(&[0b0000_0010]).is_ok()); + assert!(BitList1::from_ssz_bytes(&[0b0000_0100]).is_err()); + + assert!(BitList8::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0010]).is_ok()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0100]).is_ok()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0010, 0b0000_0100]).is_err()); + } + + #[test] + fn ssz_round_trip() { + assert_round_trip(BitList0::with_capacity(0).unwrap()); + + for i in 0..2 { + assert_round_trip(BitList1::with_capacity(i).unwrap()); + } + for i in 0..9 { + assert_round_trip(BitList8::with_capacity(i).unwrap()); + } + for i in 0..17 { + assert_round_trip(BitList16::with_capacity(i).unwrap()); + } + + let mut b = BitList1::with_capacity(1).unwrap(); + b.set(0, true); + assert_round_trip(b); + + for i in 0..8 { + let mut b = BitList8::with_capacity(i).unwrap(); + for j in 0..i { + if j % 2 == 0 { + b.set(j, true); + } + } + assert_round_trip(b); + + let mut b = BitList8::with_capacity(i).unwrap(); + for j in 0..i { + b.set(j, true); + } + assert_round_trip(b); + } + + for i in 0..16 { + let mut b = BitList16::with_capacity(i).unwrap(); + for j in 0..i { + if j % 2 == 0 { + b.set(j, true); + } + } + assert_round_trip(b); + + let mut b = BitList16::with_capacity(i).unwrap(); + for j in 0..i { + b.set(j, true); + } + assert_round_trip(b); + } + } + + fn assert_round_trip(t: T) { + assert_eq!(T::from_ssz_bytes(&t.as_ssz_bytes()).unwrap(), t); + } + } + + type Bitfield = crate::Bitfield>; #[test] fn from_raw_bytes() {