diff --git a/Cargo.toml b/Cargo.toml index ef17b431e..22ec6fd98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "eth2/utils/slot_clock", "eth2/utils/ssz", "eth2/utils/ssz_derive", + "eth2/utils/ssz_types", "eth2/utils/swap_or_not_shuffle", "eth2/utils/tree_hash", "eth2/utils/tree_hash_derive", diff --git a/eth2/utils/ssz/src/lib.rs b/eth2/utils/ssz/src/lib.rs index bcb9f525c..886433f14 100644 --- a/eth2/utils/ssz/src/lib.rs +++ b/eth2/utils/ssz/src/lib.rs @@ -47,9 +47,9 @@ pub use encode::{Encode, SszEncoder}; pub const BYTES_PER_LENGTH_OFFSET: usize = 4; /// The maximum value that can be represented using `BYTES_PER_LENGTH_OFFSET`. #[cfg(target_pointer_width = "32")] -pub const MAX_LENGTH_VALUE: usize = (std::u32::MAX >> 8 * (4 - BYTES_PER_LENGTH_OFFSET)) as usize; +pub const MAX_LENGTH_VALUE: usize = (std::u32::MAX >> (8 * (4 - BYTES_PER_LENGTH_OFFSET))) as usize; #[cfg(target_pointer_width = "64")] -pub const MAX_LENGTH_VALUE: usize = (std::u64::MAX >> 8 * (8 - BYTES_PER_LENGTH_OFFSET)) as usize; +pub const MAX_LENGTH_VALUE: usize = (std::u64::MAX >> (8 * (8 - BYTES_PER_LENGTH_OFFSET))) as usize; /// Convenience function to SSZ encode an object supporting ssz::Encode. /// diff --git a/eth2/utils/ssz_types/Cargo.toml b/eth2/utils/ssz_types/Cargo.toml new file mode 100644 index 000000000..2e4cbc899 --- /dev/null +++ b/eth2/utils/ssz_types/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ssz_types" +version = "0.1.0" +authors = ["Paul Hauner "] +edition = "2018" + +[dependencies] +cached_tree_hash = { path = "../cached_tree_hash" } +tree_hash = { path = "../tree_hash" } +serde = "1.0" +serde_derive = "1.0" +serde_hex = { path = "../serde_hex" } +eth2_ssz = { path = "../ssz" } +typenum = "1.10" + +[dev-dependencies] +serde_yaml = "0.8" diff --git a/eth2/utils/ssz_types/src/bitfield.rs b/eth2/utils/ssz_types/src/bitfield.rs new file mode 100644 index 000000000..de9a198f3 --- /dev/null +++ b/eth2/utils/ssz_types/src/bitfield.rs @@ -0,0 +1,1140 @@ +use crate::Error; +use core::marker::PhantomData; +use serde::de::{Deserialize, Deserializer}; +use serde::ser::{Serialize, Serializer}; +use serde_hex::{encode as hex_encode, PrefixedHexVisitor}; +use ssz::{Decode, Encode}; +use typenum::Unsigned; + +/// A marker trait applied to `Variable` and `Fixed` that defines the behaviour of a `Bitfield`. +pub trait BitfieldBehaviour: Clone {} + +/// A marker struct used to declare SSZ `Variable` behaviour on a `Bitfield`. +/// +/// See the [`Bitfield`](struct.Bitfield.html) docs for usage. +#[derive(Clone, PartialEq, Debug)] +pub struct Variable { + _phantom: PhantomData, +} + +/// A marker struct used to declare SSZ `Fixed` behaviour on a `Bitfield`. +/// +/// See the [`Bitfield`](struct.Bitfield.html) docs for usage. +#[derive(Clone, PartialEq, Debug)] +pub struct Fixed { + _phantom: PhantomData, +} + +impl BitfieldBehaviour for Variable {} +impl BitfieldBehaviour for Fixed {} + +/// A heap-allocated, ordered, variable-length collection of `bool` values, limited to `N` bits. +pub type BitList = Bitfield>; + +/// A heap-allocated, ordered, fixed-length collection of `bool` values, with `N` bits. +/// +/// See [Bitfield](struct.Bitfield.html) documentation. +pub type BitVector = Bitfield>; + +/// A heap-allocated, ordered, fixed-length, collection of `bool` values. Use of +/// [`BitList`](type.BitList.html) or [`BitVector`](type.BitVector.html) type aliases is preferred +/// over direct use of this struct. +/// +/// The `T` type parameter is used to define length behaviour with the `Variable` or `Fixed` marker +/// structs. +/// +/// The length of the Bitfield is set at instantiation (i.e., runtime, not compile time). However, +/// use with a `Variable` sets a type-level (i.e., compile-time) maximum length and `Fixed` +/// provides a type-level fixed length. +/// +/// ## Example +/// +/// The example uses the following crate-level type aliases: +/// +/// - `BitList` is an alias for `Bitfield>` +/// - `BitVector` is an alias for `Bitfield>` +/// +/// ``` +/// use ssz_types::{BitVector, BitList, typenum}; +/// +/// // `BitList` has a type-level maximum length. The length of the list is specified at runtime +/// // and it must be less than or equal to `N`. After instantiation, `BitList` cannot grow or +/// // shrink. +/// type BitList8 = BitList; +/// +/// // Creating a `BitList` with a larger-than-`N` capacity returns `None`. +/// assert!(BitList8::with_capacity(9).is_err()); +/// +/// let mut bitlist = BitList8::with_capacity(4).unwrap(); // `BitList` permits a capacity of less than the maximum. +/// assert!(bitlist.set(3, true).is_ok()); // Setting inside the instantiation capacity is permitted. +/// assert!(bitlist.set(5, true).is_err()); // Setting outside that capacity is not. +/// +/// // `BitVector` has a type-level fixed length. Unlike `BitList`, it cannot be instantiated with a custom length +/// // or grow/shrink. +/// type BitVector8 = BitVector; +/// +/// let mut bitvector = BitVector8::new(); +/// assert_eq!(bitvector.len(), 8); // `BitVector` length is fixed at the type-level. +/// assert!(bitvector.set(7, true).is_ok()); // Setting inside the capacity is permitted. +/// assert!(bitvector.set(9, true).is_err()); // Setting outside the capacity is not. +/// +/// ``` +/// +/// ## Note +/// +/// The internal representation of the bitfield is the same as that required by SSZ. The highest +/// byte (by `Vec` index) stores the lowest bit-indices and the right-most bit stores the lowest +/// bit-index. E.g., `vec![0b0000_0010, 0b0000_0001]` has bits `0, 9` set. +#[derive(Clone, Debug, PartialEq)] +pub struct Bitfield { + bytes: Vec, + len: usize, + _phantom: PhantomData, +} + +impl Bitfield> { + /// Instantiate with capacity for `num_bits` boolean values. The length cannot be grown or + /// shrunk after instantiation. + /// + /// All bits are initialized to `false`. + /// + /// Returns `None` if `num_bits > N`. + pub fn with_capacity(num_bits: usize) -> Result { + if num_bits <= N::to_usize() { + Ok(Self { + bytes: vec![0; bytes_for_bit_len(num_bits)], + len: num_bits, + _phantom: PhantomData, + }) + } else { + Err(Error::OutOfBounds { + i: Self::max_len(), + len: Self::max_len(), + }) + } + } + + /// Equal to `N` regardless of the value supplied to `with_capacity`. + pub fn max_len() -> usize { + N::to_usize() + } + + /// Consumes `self`, returning a serialized representation. + /// + /// The output is faithful to the SSZ encoding of `self`, such that a leading `true` bit is + /// used to indicate the length of the bitfield. + /// + /// ## Example + /// ``` + /// use ssz_types::{BitList, typenum}; + /// + /// type BitList8 = BitList; + /// + /// let b = BitList8::with_capacity(4).unwrap(); + /// + /// assert_eq!(b.into_bytes(), vec![0b0001_0000]); + /// ``` + pub fn into_bytes(self) -> Vec { + let len = self.len(); + let mut bytes = self.as_slice().to_vec(); + + while bytes_for_bit_len(len + 1) > bytes.len() { + bytes.insert(0, 0); + } + + let mut bitfield: Bitfield> = Bitfield::from_raw_bytes(bytes, len + 1) + .expect("Bitfield capacity has been confirmed earlier."); + bitfield.set(len, true).expect("Bitfield index must exist."); + + bitfield.bytes + } + + /// Instantiates a new instance from `bytes`. Consumes the same format that `self.into_bytes()` + /// produces (SSZ). + /// + /// Returns `None` if `bytes` are not a valid encoding. + pub fn from_bytes(bytes: Vec) -> Result { + let mut initial_bitfield: Bitfield> = { + let num_bits = bytes.len() * 8; + Bitfield::from_raw_bytes(bytes, num_bits) + .expect("Must have adequate bytes for bit count.") + }; + + let len = initial_bitfield + .highest_set_bit() + .ok_or_else(|| Error::MissingLengthInformation)?; + + if len <= Self::max_len() { + initial_bitfield + .set(len, false) + .expect("Bit has been confirmed to exist"); + + let mut bytes = initial_bitfield.into_raw_bytes(); + + if bytes_for_bit_len(len) < bytes.len() && bytes != [0] { + bytes.remove(0); + } + + Self::from_raw_bytes(bytes, len) + } else { + Err(Error::OutOfBounds { + i: Self::max_len(), + len: Self::max_len(), + }) + } + } +} + +impl Bitfield> { + /// Instantiate a new `Bitfield` with a fixed-length of `N` bits. + /// + /// All bits are initialized to `false`. + pub fn new() -> Self { + Self { + bytes: vec![0; bytes_for_bit_len(Self::capacity())], + len: Self::capacity(), + _phantom: PhantomData, + } + } + + /// Returns `N`, the number of bits in `Self`. + pub fn capacity() -> usize { + N::to_usize() + } + + /// Consumes `self`, returning a serialized representation. + /// + /// The output is faithful to the SSZ encoding of `self`. + /// + /// ## Example + /// ``` + /// use ssz_types::{BitVector, typenum}; + /// + /// type BitVector4 = BitVector; + /// + /// assert_eq!(BitVector4::new().into_bytes(), vec![0b0000_0000]); + /// ``` + pub fn into_bytes(self) -> Vec { + self.into_raw_bytes() + } + + /// Instantiates a new instance from `bytes`. Consumes the same format that `self.into_bytes()` + /// produces (SSZ). + /// + /// Returns `None` if `bytes` are not a valid encoding. + pub fn from_bytes(bytes: Vec) -> Result { + Self::from_raw_bytes(bytes, Self::capacity()) + } +} + +impl Default for Bitfield> { + fn default() -> Self { + Self::new() + } +} + +impl Bitfield { + /// Sets the `i`'th bit to `value`. + /// + /// Returns `None` if `i` is out-of-bounds of `self`. + pub fn set(&mut self, i: usize, value: bool) -> Result<(), Error> { + if i < self.len { + let byte = { + let num_bytes = self.bytes.len(); + let offset = i / 8; + self.bytes + .get_mut(num_bytes - offset - 1) + .expect("Cannot be OOB if less than self.len") + }; + + if value { + *byte |= 1 << (i % 8) + } else { + *byte &= !(1 << (i % 8)) + } + + Ok(()) + } else { + Err(Error::OutOfBounds { i, len: self.len }) + } + } + + /// Returns the value of the `i`'th bit. + /// + /// Returns `None` if `i` is out-of-bounds of `self`. + pub fn get(&self, i: usize) -> Result { + if i < self.len { + let byte = { + let num_bytes = self.bytes.len(); + let offset = i / 8; + self.bytes + .get(num_bytes - offset - 1) + .expect("Cannot be OOB if less than self.len") + }; + + Ok(*byte & 1 << (i % 8) > 0) + } else { + Err(Error::OutOfBounds { i, len: self.len }) + } + } + + /// Returns the number of bits stored in `self`. + pub fn len(&self) -> usize { + self.len + } + + /// Returns `true` if `self.len() == 0`. + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the underlying bytes representation of the bitfield. + pub fn into_raw_bytes(self) -> Vec { + self.bytes + } + + /// Returns a view into the underlying bytes representation of the bitfield. + pub fn as_slice(&self) -> &[u8] { + &self.bytes + } + + /// Instantiates from the given `bytes`, which are the same format as output from + /// `self.into_raw_bytes()`. + /// + /// Returns `None` if: + /// + /// - `bytes` is not the minimal required bytes to represent a bitfield of `bit_len` bits. + /// - `bit_len` is not a multiple of 8 and `bytes` contains set bits that are higher than, or + /// equal to `bit_len`. + fn from_raw_bytes(bytes: Vec, bit_len: usize) -> Result { + if bit_len == 0 { + if bytes.len() == 1 && bytes == [0] { + // A bitfield with `bit_len` 0 can only be represented by a single zero byte. + Ok(Self { + bytes, + len: 0, + _phantom: PhantomData, + }) + } else { + Err(Error::ExcessBits) + } + } else if bytes.len() != bytes_for_bit_len(bit_len) { + // The number of bytes must be the minimum required to represent `bit_len`. + Err(Error::InvalidByteCount { + given: bytes.len(), + expected: bytes_for_bit_len(bit_len), + }) + } else { + // Ensure there are no bits higher than `bit_len` that are set to true. + let (mask, _) = u8::max_value().overflowing_shr(8 - (bit_len as u32 % 8)); + + if (bytes.first().expect("Guarded against empty bytes") & !mask) == 0 { + Ok(Self { + bytes, + len: bit_len, + _phantom: PhantomData, + }) + } else { + Err(Error::ExcessBits) + } + } + } + + /// Returns the `Some(i)` where `i` is the highest index with a set bit. Returns `None` if + /// there are no set bits. + pub fn highest_set_bit(&self) -> Option { + let byte_i = self.bytes.iter().position(|byte| *byte > 0)?; + let bit_i = 7 - self.bytes[byte_i].leading_zeros() as usize; + + Some((self.bytes.len().saturating_sub(1) - byte_i) * 8 + bit_i) + } + + /// Returns an iterator across bitfield `bool` values, starting at the lowest index. + pub fn iter(&self) -> BitIter<'_, T> { + BitIter { + bitfield: self, + i: 0, + } + } + + /// Returns true if no bits are set. + pub fn is_zero(&self) -> bool { + self.bytes.iter().all(|byte| *byte == 0) + } + + /// Compute the intersection (binary-and) of this bitfield with another. + /// + /// Returns `None` if `self.is_comparable(other) == false`. + pub fn intersection(&self, other: &Self) -> Option { + if self.is_comparable(other) { + let mut res = self.clone(); + res.intersection_inplace(other); + Some(res) + } else { + None + } + } + + /// Like `intersection` but in-place (updates `self`). + pub fn intersection_inplace(&mut self, other: &Self) -> Option<()> { + if self.is_comparable(other) { + for i in 0..self.bytes.len() { + self.bytes[i] &= other.bytes[i]; + } + Some(()) + } else { + None + } + } + + /// Compute the union (binary-or) of this bitfield with another. + /// + /// Returns `None` if `self.is_comparable(other) == false`. + pub fn union(&self, other: &Self) -> Option { + if self.is_comparable(other) { + let mut res = self.clone(); + res.union_inplace(other); + Some(res) + } else { + None + } + } + + /// Like `union` but in-place (updates `self`). + pub fn union_inplace(&mut self, other: &Self) -> Option<()> { + if self.is_comparable(other) { + for i in 0..self.bytes.len() { + self.bytes[i] |= other.bytes[i]; + } + Some(()) + } else { + None + } + } + + /// Compute the difference (binary-minus) of this bitfield with another. Lengths must match. + /// + /// Returns `None` if `self.is_comparable(other) == false`. + pub fn difference(&self, other: &Self) -> Option { + if self.is_comparable(other) { + let mut res = self.clone(); + res.difference_inplace(other); + Some(res) + } else { + None + } + } + + /// Like `difference` but in-place (updates `self`). + pub fn difference_inplace(&mut self, other: &Self) -> Option<()> { + if self.is_comparable(other) { + for i in 0..self.bytes.len() { + self.bytes[i] &= !other.bytes[i]; + } + Some(()) + } else { + None + } + } + + /// Returns true if `self` and `other` have the same lengths and can be used in binary + /// comparison operations. + pub fn is_comparable(&self, other: &Self) -> bool { + (self.len() == other.len()) && (self.bytes.len() == other.bytes.len()) + } +} + +/// Returns the minimum required bytes to represent a given number of bits. +/// +/// `bit_len == 0` requires a single byte. +fn bytes_for_bit_len(bit_len: usize) -> usize { + std::cmp::max(1, (bit_len + 7) / 8) +} + +/// An iterator over the bits in a `Bitfield`. +pub struct BitIter<'a, T> { + bitfield: &'a Bitfield, + i: usize, +} + +impl<'a, T: BitfieldBehaviour> Iterator for BitIter<'a, T> { + type Item = bool; + + fn next(&mut self) -> Option { + let res = self.bitfield.get(self.i).ok()?; + self.i += 1; + Some(res) + } +} + +impl Encode for Bitfield> { + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_append(&self, buf: &mut Vec) { + buf.append(&mut self.clone().into_bytes()) + } +} + +impl Decode for Bitfield> { + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + Self::from_bytes(bytes.to_vec()).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!("BitList failed to decode: {:?}", e)) + }) + } +} + +impl Encode for Bitfield> { + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + bytes_for_bit_len(N::to_usize()) + } + + fn ssz_append(&self, buf: &mut Vec) { + buf.append(&mut self.clone().into_bytes()) + } +} + +impl Decode for Bitfield> { + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + Self::from_bytes(bytes.to_vec()).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!("BitVector failed to decode: {:?}", e)) + }) + } +} + +impl Serialize for Bitfield> { + /// Serde serialization is compliant with the Ethereum YAML test format. + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&hex_encode(self.as_ssz_bytes())) + } +} + +impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield> { + /// Serde serialization is compliant with the Ethereum YAML test format. + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?; + Self::from_ssz_bytes(&bytes) + .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e))) + } +} + +impl Serialize for Bitfield> { + /// Serde serialization is compliant with the Ethereum YAML test format. + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&hex_encode(self.as_ssz_bytes())) + } +} + +impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield> { + /// Serde serialization is compliant with the Ethereum YAML test format. + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?; + Self::from_ssz_bytes(&bytes) + .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e))) + } +} + +impl tree_hash::TreeHash for Bitfield> { + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::List + } + + fn tree_hash_packed_encoding(&self) -> Vec { + unreachable!("List should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("List should never be packed.") + } + + fn tree_hash_root(&self) -> Vec { + // TODO: pad this out to max length. + self.as_ssz_bytes().tree_hash_root() + } +} + +impl tree_hash::TreeHash for Bitfield> { + fn tree_hash_type() -> tree_hash::TreeHashType { + // TODO: move this to be a vector. + tree_hash::TreeHashType::List + } + + fn tree_hash_packed_encoding(&self) -> Vec { + // TODO: move this to be a vector. + unreachable!("Vector should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + // TODO: move this to be a vector. + unreachable!("Vector should never be packed.") + } + + fn tree_hash_root(&self) -> Vec { + self.as_ssz_bytes().tree_hash_root() + } +} + +impl cached_tree_hash::CachedTreeHash for Bitfield> { + fn new_tree_hash_cache( + &self, + depth: usize, + ) -> Result { + let bytes = self.clone().into_bytes(); + + let (mut cache, schema) = cached_tree_hash::vec::new_tree_hash_cache(&bytes, depth)?; + + cache.add_length_nodes(schema.into_overlay(0).chunk_range(), bytes.len())?; + + Ok(cache) + } + + fn num_tree_hash_cache_chunks(&self) -> usize { + // Add two extra nodes to cater for the node before and after to allow mixing-in length. + cached_tree_hash::BTreeOverlay::new(self, 0, 0).num_chunks() + 2 + } + + fn tree_hash_cache_schema(&self, depth: usize) -> cached_tree_hash::BTreeSchema { + let bytes = self.clone().into_bytes(); + cached_tree_hash::vec::produce_schema(&bytes, depth) + } + + fn update_tree_hash_cache( + &self, + cache: &mut cached_tree_hash::TreeHashCache, + ) -> Result<(), cached_tree_hash::Error> { + let bytes = self.clone().into_bytes(); + + // Skip the length-mixed-in root node. + cache.chunk_index += 1; + + // Update the cache, returning the new overlay. + let new_overlay = cached_tree_hash::vec::update_tree_hash_cache(&bytes, cache)?; + + // Mix in length + cache.mix_in_length(new_overlay.chunk_range(), bytes.len())?; + + // Skip an extra node to clear the length node. + cache.chunk_index += 1; + + Ok(()) + } +} + +impl cached_tree_hash::CachedTreeHash for Bitfield> { + fn new_tree_hash_cache( + &self, + depth: usize, + ) -> Result { + let (cache, _schema) = + cached_tree_hash::vec::new_tree_hash_cache(&ssz::ssz_encode(self), depth)?; + + Ok(cache) + } + + fn tree_hash_cache_schema(&self, depth: usize) -> cached_tree_hash::BTreeSchema { + let lengths = vec![ + 1; + cached_tree_hash::merkleize::num_unsanitized_leaves(bytes_for_bit_len( + N::to_usize() + )) + ]; + cached_tree_hash::BTreeSchema::from_lengths(depth, lengths) + } + + fn update_tree_hash_cache( + &self, + cache: &mut cached_tree_hash::TreeHashCache, + ) -> Result<(), cached_tree_hash::Error> { + cached_tree_hash::vec::update_tree_hash_cache(&ssz::ssz_encode(self), cache)?; + + Ok(()) + } +} + +#[cfg(test)] +mod bitvector { + use super::*; + use crate::BitVector; + + pub type BitVector0 = BitVector; + pub type BitVector1 = BitVector; + pub type BitVector4 = BitVector; + pub type BitVector8 = BitVector; + pub type BitVector16 = BitVector; + + #[test] + fn ssz_encode() { + assert_eq!(BitVector0::new().as_ssz_bytes(), vec![0b0000_0000]); + assert_eq!(BitVector1::new().as_ssz_bytes(), vec![0b0000_0000]); + assert_eq!(BitVector4::new().as_ssz_bytes(), vec![0b0000_0000]); + assert_eq!(BitVector8::new().as_ssz_bytes(), vec![0b0000_0000]); + assert_eq!( + BitVector16::new().as_ssz_bytes(), + vec![0b0000_0000, 0b0000_0000] + ); + + let mut b = BitVector8::new(); + for i in 0..8 { + b.set(i, true).unwrap(); + } + assert_eq!(b.as_ssz_bytes(), vec![255]); + + let mut b = BitVector4::new(); + for i in 0..4 { + b.set(i, true).unwrap(); + } + assert_eq!(b.as_ssz_bytes(), vec![0b0000_1111]); + } + + #[test] + fn ssz_decode() { + assert!(BitVector0::from_ssz_bytes(&[0b0000_0000]).is_ok()); + assert!(BitVector0::from_ssz_bytes(&[0b0000_0001]).is_err()); + assert!(BitVector0::from_ssz_bytes(&[0b0000_0010]).is_err()); + + assert!(BitVector1::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitVector1::from_ssz_bytes(&[0b0000_0010]).is_err()); + assert!(BitVector1::from_ssz_bytes(&[0b0000_0100]).is_err()); + assert!(BitVector1::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_err()); + + assert!(BitVector8::from_ssz_bytes(&[0b0000_0000]).is_ok()); + assert!(BitVector8::from_ssz_bytes(&[1, 0b0000_0000]).is_err()); + assert!(BitVector8::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitVector8::from_ssz_bytes(&[0b0000_0010]).is_ok()); + assert!(BitVector8::from_ssz_bytes(&[0b0000_0001, 0b0000_0100]).is_err()); + assert!(BitVector8::from_ssz_bytes(&[0b0000_0010, 0b0000_0100]).is_err()); + + assert!(BitVector16::from_ssz_bytes(&[0b0000_0000]).is_err()); + assert!(BitVector16::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_ok()); + assert!(BitVector16::from_ssz_bytes(&[1, 0b0000_0000, 0b0000_0000]).is_err()); + } + + #[test] + fn ssz_round_trip() { + assert_round_trip(BitVector0::new()); + + let mut b = BitVector1::new(); + b.set(0, true).unwrap(); + assert_round_trip(b); + + let mut b = BitVector8::new(); + for j in 0..8 { + if j % 2 == 0 { + b.set(j, true).unwrap(); + } + } + assert_round_trip(b); + + let mut b = BitVector8::new(); + for j in 0..8 { + b.set(j, true).unwrap(); + } + assert_round_trip(b); + + let mut b = BitVector16::new(); + for j in 0..16 { + if j % 2 == 0 { + b.set(j, true).unwrap(); + } + } + assert_round_trip(b); + + let mut b = BitVector16::new(); + for j in 0..16 { + b.set(j, true).unwrap(); + } + assert_round_trip(b); + } + + fn assert_round_trip(t: T) { + assert_eq!(T::from_ssz_bytes(&t.as_ssz_bytes()).unwrap(), t); + } +} + +#[cfg(test)] +mod bitlist { + use super::*; + use crate::BitList; + + pub type BitList0 = BitList; + pub type BitList1 = BitList; + pub type BitList8 = BitList; + pub type BitList16 = BitList; + pub type BitList1024 = BitList; + + #[test] + fn ssz_encode() { + assert_eq!( + BitList0::with_capacity(0).unwrap().as_ssz_bytes(), + vec![0b0000_00001], + ); + + assert_eq!( + BitList1::with_capacity(0).unwrap().as_ssz_bytes(), + vec![0b0000_00001], + ); + + assert_eq!( + BitList1::with_capacity(1).unwrap().as_ssz_bytes(), + vec![0b0000_00010], + ); + + assert_eq!( + BitList8::with_capacity(8).unwrap().as_ssz_bytes(), + vec![0b0000_0001, 0b0000_0000], + ); + + assert_eq!( + BitList8::with_capacity(7).unwrap().as_ssz_bytes(), + vec![0b1000_0000] + ); + + let mut b = BitList8::with_capacity(8).unwrap(); + for i in 0..8 { + b.set(i, true).unwrap(); + } + assert_eq!(b.as_ssz_bytes(), vec![0b0000_0001, 255]); + + let mut b = BitList8::with_capacity(8).unwrap(); + for i in 0..4 { + b.set(i, true).unwrap(); + } + assert_eq!(b.as_ssz_bytes(), vec![0b0000_0001, 0b0000_1111]); + + assert_eq!( + BitList16::with_capacity(16).unwrap().as_ssz_bytes(), + vec![0b0000_0001, 0b0000_0000, 0b0000_0000] + ); + } + + #[test] + fn ssz_decode() { + assert!(BitList0::from_ssz_bytes(&[0b0000_0000]).is_err()); + assert!(BitList1::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_err()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0000]).is_err()); + assert!(BitList16::from_ssz_bytes(&[0b0000_0000]).is_err()); + + assert!(BitList0::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitList0::from_ssz_bytes(&[0b0000_0010]).is_err()); + + assert!(BitList1::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitList1::from_ssz_bytes(&[0b0000_0010]).is_ok()); + assert!(BitList1::from_ssz_bytes(&[0b0000_0100]).is_err()); + + assert!(BitList8::from_ssz_bytes(&[0b0000_0001]).is_ok()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0010]).is_ok()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0100]).is_ok()); + assert!(BitList8::from_ssz_bytes(&[0b0000_0010, 0b0000_0100]).is_err()); + } + + #[test] + fn ssz_round_trip() { + assert_round_trip(BitList0::with_capacity(0).unwrap()); + + for i in 0..2 { + assert_round_trip(BitList1::with_capacity(i).unwrap()); + } + for i in 0..9 { + assert_round_trip(BitList8::with_capacity(i).unwrap()); + } + for i in 0..17 { + assert_round_trip(BitList16::with_capacity(i).unwrap()); + } + + let mut b = BitList1::with_capacity(1).unwrap(); + b.set(0, true).unwrap(); + assert_round_trip(b); + + for i in 0..8 { + let mut b = BitList8::with_capacity(i).unwrap(); + for j in 0..i { + if j % 2 == 0 { + b.set(j, true).unwrap(); + } + } + assert_round_trip(b); + + let mut b = BitList8::with_capacity(i).unwrap(); + for j in 0..i { + b.set(j, true).unwrap(); + } + assert_round_trip(b); + } + + for i in 0..16 { + let mut b = BitList16::with_capacity(i).unwrap(); + for j in 0..i { + if j % 2 == 0 { + b.set(j, true).unwrap(); + } + } + assert_round_trip(b); + + let mut b = BitList16::with_capacity(i).unwrap(); + for j in 0..i { + b.set(j, true).unwrap(); + } + assert_round_trip(b); + } + } + + fn assert_round_trip(t: T) { + assert_eq!(T::from_ssz_bytes(&t.as_ssz_bytes()).unwrap(), t); + } + + #[test] + fn from_raw_bytes() { + assert!(BitList1024::from_raw_bytes(vec![0b0000_0000], 0).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0001], 1).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0011], 2).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0111], 3).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_1111], 4).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0001_1111], 5).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0011_1111], 6).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0111_1111], 7).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b1111_1111], 8).is_ok()); + + assert!(BitList1024::from_raw_bytes(vec![0b0000_0001, 0b1111_1111], 9).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0011, 0b1111_1111], 10).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0111, 0b1111_1111], 11).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_1111, 0b1111_1111], 12).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0001_1111, 0b1111_1111], 13).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0011_1111, 0b1111_1111], 14).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b0111_1111, 0b1111_1111], 15).is_ok()); + assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b1111_1111], 16).is_ok()); + + for i in 0..8 { + assert!(BitList1024::from_raw_bytes(vec![], i).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b1111_1111], i).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b1111_1110, 0b0000_0000], i).is_err()); + } + + assert!(BitList1024::from_raw_bytes(vec![0b0000_0001], 0).is_err()); + + assert!(BitList1024::from_raw_bytes(vec![0b0000_0001], 0).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0011], 1).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0111], 2).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_1111], 3).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0001_1111], 4).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0011_1111], 5).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0111_1111], 6).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b1111_1111], 7).is_err()); + + assert!(BitList1024::from_raw_bytes(vec![0b0000_0001, 0b1111_1111], 8).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0011, 0b1111_1111], 9).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_0111, 0b1111_1111], 10).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0000_1111, 0b1111_1111], 11).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0001_1111, 0b1111_1111], 12).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0011_1111, 0b1111_1111], 13).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b0111_1111, 0b1111_1111], 14).is_err()); + assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b1111_1111], 15).is_err()); + } + + fn test_set_unset(num_bits: usize) { + let mut bitfield = BitList1024::with_capacity(num_bits).unwrap(); + + for i in 0..num_bits + 1 { + if i < num_bits { + // Starts as false + assert_eq!(bitfield.get(i), Ok(false)); + // Can be set true. + assert!(bitfield.set(i, true).is_ok()); + assert_eq!(bitfield.get(i), Ok(true)); + // Can be set false + assert!(bitfield.set(i, false).is_ok()); + assert_eq!(bitfield.get(i), Ok(false)); + } else { + assert!(bitfield.get(i).is_err()); + assert!(bitfield.set(i, true).is_err()); + assert!(bitfield.get(i).is_err()); + } + } + } + + fn test_bytes_round_trip(num_bits: usize) { + for i in 0..num_bits { + let mut bitfield = BitList1024::with_capacity(num_bits).unwrap(); + bitfield.set(i, true).unwrap(); + + let bytes = bitfield.clone().into_raw_bytes(); + assert_eq!(bitfield, Bitfield::from_raw_bytes(bytes, num_bits).unwrap()); + } + } + + #[test] + fn set_unset() { + for i in 0..8 * 5 { + test_set_unset(i) + } + } + + #[test] + fn bytes_round_trip() { + for i in 0..8 * 5 { + test_bytes_round_trip(i) + } + } + + #[test] + fn into_raw_bytes() { + let mut bitfield = BitList1024::with_capacity(9).unwrap(); + bitfield.set(0, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b0000_0001] + ); + bitfield.set(1, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b0000_0011] + ); + bitfield.set(2, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b0000_0111] + ); + bitfield.set(3, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b0000_1111] + ); + bitfield.set(4, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b0001_1111] + ); + bitfield.set(5, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b0011_1111] + ); + bitfield.set(6, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b0111_1111] + ); + bitfield.set(7, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0000, 0b1111_1111] + ); + bitfield.set(8, true).unwrap(); + assert_eq!( + bitfield.clone().into_raw_bytes(), + vec![0b0000_0001, 0b1111_1111] + ); + } + + #[test] + fn highest_set_bit() { + assert_eq!( + BitList1024::with_capacity(16).unwrap().highest_set_bit(), + None + ); + + assert_eq!( + BitList1024::from_raw_bytes(vec![0b0000_000, 0b0000_0001], 16) + .unwrap() + .highest_set_bit(), + Some(0) + ); + + assert_eq!( + BitList1024::from_raw_bytes(vec![0b0000_000, 0b0000_0010], 16) + .unwrap() + .highest_set_bit(), + Some(1) + ); + + assert_eq!( + BitList1024::from_raw_bytes(vec![0b0000_1000], 8) + .unwrap() + .highest_set_bit(), + Some(3) + ); + + assert_eq!( + BitList1024::from_raw_bytes(vec![0b1000_0000, 0b0000_0000], 16) + .unwrap() + .highest_set_bit(), + Some(15) + ); + } + + #[test] + fn intersection() { + let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap(); + let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap(); + let c = BitList1024::from_raw_bytes(vec![0b1000, 0b0001], 16).unwrap(); + + assert_eq!(a.intersection(&b).unwrap(), c); + assert_eq!(b.intersection(&a).unwrap(), c); + assert_eq!(a.intersection(&c).unwrap(), c); + assert_eq!(b.intersection(&c).unwrap(), c); + assert_eq!(a.intersection(&a).unwrap(), a); + assert_eq!(b.intersection(&b).unwrap(), b); + assert_eq!(c.intersection(&c).unwrap(), c); + } + + #[test] + fn union() { + let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap(); + let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap(); + let c = BitList1024::from_raw_bytes(vec![0b1111, 0b1001], 16).unwrap(); + + assert_eq!(a.union(&b).unwrap(), c); + assert_eq!(b.union(&a).unwrap(), c); + assert_eq!(a.union(&a).unwrap(), a); + assert_eq!(b.union(&b).unwrap(), b); + assert_eq!(c.union(&c).unwrap(), c); + } + + #[test] + fn difference() { + let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap(); + let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap(); + let a_b = BitList1024::from_raw_bytes(vec![0b0100, 0b0000], 16).unwrap(); + let b_a = BitList1024::from_raw_bytes(vec![0b0011, 0b1000], 16).unwrap(); + + assert_eq!(a.difference(&b).unwrap(), a_b); + assert_eq!(b.difference(&a).unwrap(), b_a); + assert!(a.difference(&a).unwrap().is_zero()); + } + + #[test] + fn iter() { + let mut bitfield = BitList1024::with_capacity(9).unwrap(); + bitfield.set(2, true).unwrap(); + bitfield.set(8, true).unwrap(); + + assert_eq!( + bitfield.iter().collect::>(), + vec![false, false, true, false, false, false, false, false, true] + ); + } +} diff --git a/eth2/utils/ssz_types/src/fixed_vector.rs b/eth2/utils/ssz_types/src/fixed_vector.rs new file mode 100644 index 000000000..687d7d738 --- /dev/null +++ b/eth2/utils/ssz_types/src/fixed_vector.rs @@ -0,0 +1,335 @@ +use crate::Error; +use serde_derive::{Deserialize, Serialize}; +use std::marker::PhantomData; +use std::ops::{Deref, Index, IndexMut}; +use std::slice::SliceIndex; +use typenum::Unsigned; + +pub use typenum; + +/// Emulates a SSZ `Vector` (distinct from a Rust `Vec`). +/// +/// An ordered, heap-allocated, fixed-length, homogeneous collection of `T`, with `N` values. +/// +/// This struct is backed by a Rust `Vec` but constrained such that it must be instantiated with a +/// fixed number of elements and you may not add or remove elements, only modify. +/// +/// The length of this struct is fixed at the type-level using +/// [typenum](https://crates.io/crates/typenum). +/// +/// ## Note +/// +/// Whilst it is possible with this library, SSZ declares that a `FixedVector` with a length of `0` +/// is illegal. +/// +/// ## Example +/// +/// ``` +/// use ssz_types::{FixedVector, typenum}; +/// +/// let base: Vec = vec![1, 2, 3, 4]; +/// +/// // Create a `FixedVector` from a `Vec` that has the expected length. +/// let exact: FixedVector<_, typenum::U4> = FixedVector::from(base.clone()); +/// assert_eq!(&exact[..], &[1, 2, 3, 4]); +/// +/// // Create a `FixedVector` from a `Vec` that is too long and the `Vec` is truncated. +/// let short: FixedVector<_, typenum::U3> = FixedVector::from(base.clone()); +/// assert_eq!(&short[..], &[1, 2, 3]); +/// +/// // Create a `FixedVector` from a `Vec` that is too short and the missing values are created +/// // using `std::default::Default`. +/// let long: FixedVector<_, typenum::U5> = FixedVector::from(base); +/// assert_eq!(&long[..], &[1, 2, 3, 4, 0]); +/// ``` +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct FixedVector { + vec: Vec, + _phantom: PhantomData, +} + +impl FixedVector { + /// Returns `Ok` if the given `vec` equals the fixed length of `Self`. Otherwise returns + /// `Err`. + pub fn new(vec: Vec) -> Result { + if vec.len() == Self::capacity() { + Ok(Self { + vec, + _phantom: PhantomData, + }) + } else { + Err(Error::OutOfBounds { + i: vec.len(), + len: Self::capacity(), + }) + } + } + + /// Identical to `self.capacity`, returns the type-level constant length. + /// + /// Exists for compatibility with `Vec`. + pub fn len(&self) -> usize { + self.vec.len() + } + + /// True if the type-level constant length of `self` is zero. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the type-level constant length. + pub fn capacity() -> usize { + N::to_usize() + } +} + +impl From> for FixedVector { + fn from(mut vec: Vec) -> Self { + vec.resize_with(Self::capacity(), Default::default); + + Self { + vec, + _phantom: PhantomData, + } + } +} + +impl Into> for FixedVector { + fn into(self) -> Vec { + self.vec + } +} + +impl Default for FixedVector { + fn default() -> Self { + Self { + vec: Vec::default(), + _phantom: PhantomData, + } + } +} + +impl> Index for FixedVector { + type Output = I::Output; + + #[inline] + fn index(&self, index: I) -> &Self::Output { + Index::index(&self.vec, index) + } +} + +impl> IndexMut for FixedVector { + #[inline] + fn index_mut(&mut self, index: I) -> &mut Self::Output { + IndexMut::index_mut(&mut self.vec, index) + } +} + +impl Deref for FixedVector { + type Target = [T]; + + fn deref(&self) -> &[T] { + &self.vec[..] + } +} + +#[cfg(test)] +mod test { + use super::*; + use typenum::*; + + #[test] + fn new() { + let vec = vec![42; 5]; + let fixed: Result, _> = FixedVector::new(vec.clone()); + assert!(fixed.is_err()); + + let vec = vec![42; 3]; + let fixed: Result, _> = FixedVector::new(vec.clone()); + assert!(fixed.is_err()); + + let vec = vec![42; 4]; + let fixed: Result, _> = FixedVector::new(vec.clone()); + assert!(fixed.is_ok()); + } + + #[test] + fn indexing() { + let vec = vec![1, 2]; + + let mut fixed: FixedVector = vec.clone().into(); + + assert_eq!(fixed[0], 1); + assert_eq!(&fixed[0..1], &vec[0..1]); + assert_eq!((&fixed[..]).len(), 8192); + + fixed[1] = 3; + assert_eq!(fixed[1], 3); + } + + #[test] + fn length() { + let vec = vec![42; 5]; + let fixed: FixedVector = FixedVector::from(vec.clone()); + assert_eq!(&fixed[..], &vec[0..4]); + + let vec = vec![42; 3]; + let fixed: FixedVector = FixedVector::from(vec.clone()); + assert_eq!(&fixed[0..3], &vec[..]); + assert_eq!(&fixed[..], &vec![42, 42, 42, 0][..]); + + let vec = vec![]; + let fixed: FixedVector = FixedVector::from(vec.clone()); + assert_eq!(&fixed[..], &vec![0, 0, 0, 0][..]); + } + + #[test] + fn deref() { + let vec = vec![0, 2, 4, 6]; + let fixed: FixedVector = FixedVector::from(vec); + + assert_eq!(fixed.get(0), Some(&0)); + assert_eq!(fixed.get(3), Some(&6)); + assert_eq!(fixed.get(4), None); + } +} + +impl tree_hash::TreeHash for FixedVector +where + T: tree_hash::TreeHash, +{ + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::Vector + } + + fn tree_hash_packed_encoding(&self) -> Vec { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_root(&self) -> Vec { + tree_hash::impls::vec_tree_hash_root(&self.vec) + } +} + +impl cached_tree_hash::CachedTreeHash for FixedVector +where + T: cached_tree_hash::CachedTreeHash + tree_hash::TreeHash, +{ + fn new_tree_hash_cache( + &self, + depth: usize, + ) -> Result { + let (cache, _overlay) = cached_tree_hash::vec::new_tree_hash_cache(&self.vec, depth)?; + + Ok(cache) + } + + fn tree_hash_cache_schema(&self, depth: usize) -> cached_tree_hash::BTreeSchema { + cached_tree_hash::vec::produce_schema(&self.vec, depth) + } + + fn update_tree_hash_cache( + &self, + cache: &mut cached_tree_hash::TreeHashCache, + ) -> Result<(), cached_tree_hash::Error> { + cached_tree_hash::vec::update_tree_hash_cache(&self.vec, cache)?; + + Ok(()) + } +} + +impl ssz::Encode for FixedVector +where + T: ssz::Encode, +{ + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + if ::is_ssz_fixed_len() { + T::ssz_fixed_len() * N::to_usize() + } else { + ssz::BYTES_PER_LENGTH_OFFSET + } + } + + fn ssz_append(&self, buf: &mut Vec) { + if T::is_ssz_fixed_len() { + buf.reserve(T::ssz_fixed_len() * self.len()); + + for item in &self.vec { + item.ssz_append(buf); + } + } else { + let mut encoder = ssz::SszEncoder::list(buf, self.len() * ssz::BYTES_PER_LENGTH_OFFSET); + + for item in &self.vec { + encoder.append(item); + } + + encoder.finalize(); + } + } +} + +impl ssz::Decode for FixedVector +where + T: ssz::Decode + Default, +{ + fn is_ssz_fixed_len() -> bool { + T::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + if ::is_ssz_fixed_len() { + T::ssz_fixed_len() * N::to_usize() + } else { + ssz::BYTES_PER_LENGTH_OFFSET + } + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + if bytes.is_empty() { + Ok(FixedVector::from(vec![])) + } else if T::is_ssz_fixed_len() { + bytes + .chunks(T::ssz_fixed_len()) + .map(|chunk| T::from_ssz_bytes(chunk)) + .collect::, _>>() + .and_then(|vec| Ok(vec.into())) + } else { + ssz::decode_list_of_variable_length_items(bytes).and_then(|vec| Ok(vec.into())) + } + } +} + +#[cfg(test)] +mod ssz_tests { + use super::*; + use ssz::*; + use typenum::*; + + #[test] + fn encode() { + let vec: FixedVector = vec![0; 2].into(); + assert_eq!(vec.as_ssz_bytes(), vec![0, 0, 0, 0]); + assert_eq!( as Encode>::ssz_fixed_len(), 4); + } + + fn round_trip(item: T) { + let encoded = &item.as_ssz_bytes(); + assert_eq!(T::from_ssz_bytes(&encoded), Ok(item)); + } + + #[test] + fn u16_len_8() { + round_trip::>(vec![42; 8].into()); + round_trip::>(vec![0; 8].into()); + } +} diff --git a/eth2/utils/ssz_types/src/lib.rs b/eth2/utils/ssz_types/src/lib.rs new file mode 100644 index 000000000..59869b7c0 --- /dev/null +++ b/eth2/utils/ssz_types/src/lib.rs @@ -0,0 +1,66 @@ +//! Provides types with unique properties required for SSZ serialization and Merklization: +//! +//! - `FixedVector`: A heap-allocated list with a size that is fixed at compile time. +//! - `VariableList`: A heap-allocated list that cannot grow past a type-level maximum length. +//! - `BitList`: A heap-allocated bitfield that with a type-level _maximum_ length. +//! - `BitVector`: A heap-allocated bitfield that with a type-level _fixed__ length. +//! +//! These structs are required as SSZ serialization and Merklization rely upon type-level lengths +//! for padding and verification. +//! +//! ## Example +//! ``` +//! use ssz_types::*; +//! +//! pub struct Example { +//! bit_vector: BitVector, +//! bit_list: BitList, +//! variable_list: VariableList, +//! fixed_vector: FixedVector, +//! } +//! +//! let mut example = Example { +//! bit_vector: Bitfield::new(), +//! bit_list: Bitfield::with_capacity(4).unwrap(), +//! variable_list: <_>::from(vec![0, 1]), +//! fixed_vector: <_>::from(vec![2, 3]), +//! }; +//! +//! assert_eq!(example.bit_vector.len(), 8); +//! assert_eq!(example.bit_list.len(), 4); +//! assert_eq!(&example.variable_list[..], &[0, 1]); +//! assert_eq!(&example.fixed_vector[..], &[2, 3, 0, 0, 0, 0, 0, 0]); +//! +//! ``` + +#[macro_use] +mod bitfield; +mod fixed_vector; +mod variable_list; + +pub use bitfield::{BitList, BitVector, Bitfield}; +pub use fixed_vector::FixedVector; +pub use typenum; +pub use variable_list::VariableList; + +pub mod length { + pub use crate::bitfield::{Fixed, Variable}; +} + +/// Returned when an item encounters an error. +#[derive(PartialEq, Debug)] +pub enum Error { + OutOfBounds { + i: usize, + len: usize, + }, + /// A `BitList` does not have a set bit, therefore it's length is unknowable. + MissingLengthInformation, + /// A `BitList` has excess bits set to true. + ExcessBits, + /// A `BitList` has an invalid number of bytes for a given bit length. + InvalidByteCount { + given: usize, + expected: usize, + }, +} diff --git a/eth2/utils/ssz_types/src/variable_list.rs b/eth2/utils/ssz_types/src/variable_list.rs new file mode 100644 index 000000000..52872ada6 --- /dev/null +++ b/eth2/utils/ssz_types/src/variable_list.rs @@ -0,0 +1,320 @@ +use crate::Error; +use serde_derive::{Deserialize, Serialize}; +use std::marker::PhantomData; +use std::ops::{Deref, Index, IndexMut}; +use std::slice::SliceIndex; +use typenum::Unsigned; + +pub use typenum; + +/// Emulates a SSZ `List`. +/// +/// An ordered, heap-allocated, variable-length, homogeneous collection of `T`, with no more than +/// `N` values. +/// +/// This struct is backed by a Rust `Vec` but constrained such that it must be instantiated with a +/// fixed number of elements and you may not add or remove elements, only modify. +/// +/// The length of this struct is fixed at the type-level using +/// [typenum](https://crates.io/crates/typenum). +/// +/// ## Example +/// +/// ``` +/// use ssz_types::{VariableList, typenum}; +/// +/// let base: Vec = vec![1, 2, 3, 4]; +/// +/// // Create a `VariableList` from a `Vec` that has the expected length. +/// let exact: VariableList<_, typenum::U4> = VariableList::from(base.clone()); +/// assert_eq!(&exact[..], &[1, 2, 3, 4]); +/// +/// // Create a `VariableList` from a `Vec` that is too long and the `Vec` is truncated. +/// let short: VariableList<_, typenum::U3> = VariableList::from(base.clone()); +/// assert_eq!(&short[..], &[1, 2, 3]); +/// +/// // Create a `VariableList` from a `Vec` that is shorter than the maximum. +/// let mut long: VariableList<_, typenum::U5> = VariableList::from(base); +/// assert_eq!(&long[..], &[1, 2, 3, 4]); +/// +/// // Push a value to if it does not exceed the maximum +/// long.push(5).unwrap(); +/// assert_eq!(&long[..], &[1, 2, 3, 4, 5]); +/// +/// // Push a value to if it _does_ exceed the maximum. +/// assert!(long.push(6).is_err()); +/// ``` +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct VariableList { + vec: Vec, + _phantom: PhantomData, +} + +impl VariableList { + /// Returns `Some` if the given `vec` equals the fixed length of `Self`. Otherwise returns + /// `None`. + pub fn new(vec: Vec) -> Result { + if vec.len() <= N::to_usize() { + Ok(Self { + vec, + _phantom: PhantomData, + }) + } else { + Err(Error::OutOfBounds { + i: vec.len(), + len: Self::max_len(), + }) + } + } + + /// Returns the number of values presently in `self`. + pub fn len(&self) -> usize { + self.vec.len() + } + + /// True if `self` does not contain any values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the type-level maximum length. + pub fn max_len() -> usize { + N::to_usize() + } + + /// Appends `value` to the back of `self`. + /// + /// Returns `Err(())` when appending `value` would exceed the maximum length. + pub fn push(&mut self, value: T) -> Result<(), Error> { + if self.vec.len() < Self::max_len() { + self.vec.push(value); + Ok(()) + } else { + Err(Error::OutOfBounds { + i: self.vec.len() + 1, + len: Self::max_len(), + }) + } + } +} + +impl From> for VariableList { + fn from(mut vec: Vec) -> Self { + vec.truncate(N::to_usize()); + + Self { + vec, + _phantom: PhantomData, + } + } +} + +impl Into> for VariableList { + fn into(self) -> Vec { + self.vec + } +} + +impl Default for VariableList { + fn default() -> Self { + Self { + vec: Vec::default(), + _phantom: PhantomData, + } + } +} + +impl> Index for VariableList { + type Output = I::Output; + + #[inline] + fn index(&self, index: I) -> &Self::Output { + Index::index(&self.vec, index) + } +} + +impl> IndexMut for VariableList { + #[inline] + fn index_mut(&mut self, index: I) -> &mut Self::Output { + IndexMut::index_mut(&mut self.vec, index) + } +} + +impl Deref for VariableList { + type Target = [T]; + + fn deref(&self) -> &[T] { + &self.vec[..] + } +} + +#[cfg(test)] +mod test { + use super::*; + use typenum::*; + + #[test] + fn new() { + let vec = vec![42; 5]; + let fixed: Result, _> = VariableList::new(vec.clone()); + assert!(fixed.is_err()); + + let vec = vec![42; 3]; + let fixed: Result, _> = VariableList::new(vec.clone()); + assert!(fixed.is_ok()); + + let vec = vec![42; 4]; + let fixed: Result, _> = VariableList::new(vec.clone()); + assert!(fixed.is_ok()); + } + + #[test] + fn indexing() { + let vec = vec![1, 2]; + + let mut fixed: VariableList = vec.clone().into(); + + assert_eq!(fixed[0], 1); + assert_eq!(&fixed[0..1], &vec[0..1]); + assert_eq!((&fixed[..]).len(), 2); + + fixed[1] = 3; + assert_eq!(fixed[1], 3); + } + + #[test] + fn length() { + let vec = vec![42; 5]; + let fixed: VariableList = VariableList::from(vec.clone()); + assert_eq!(&fixed[..], &vec[0..4]); + + let vec = vec![42; 3]; + let fixed: VariableList = VariableList::from(vec.clone()); + assert_eq!(&fixed[0..3], &vec[..]); + assert_eq!(&fixed[..], &vec![42, 42, 42][..]); + + let vec = vec![]; + let fixed: VariableList = VariableList::from(vec.clone()); + assert_eq!(&fixed[..], &vec![][..]); + } + + #[test] + fn deref() { + let vec = vec![0, 2, 4, 6]; + let fixed: VariableList = VariableList::from(vec); + + assert_eq!(fixed.get(0), Some(&0)); + assert_eq!(fixed.get(3), Some(&6)); + assert_eq!(fixed.get(4), None); + } +} + +impl tree_hash::TreeHash for VariableList +where + T: tree_hash::TreeHash, +{ + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::Vector + } + + fn tree_hash_packed_encoding(&self) -> Vec { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_root(&self) -> Vec { + tree_hash::impls::vec_tree_hash_root(&self.vec) + } +} + +impl cached_tree_hash::CachedTreeHash for VariableList +where + T: cached_tree_hash::CachedTreeHash + tree_hash::TreeHash, +{ + fn new_tree_hash_cache( + &self, + depth: usize, + ) -> Result { + let (cache, _overlay) = cached_tree_hash::vec::new_tree_hash_cache(&self.vec, depth)?; + + Ok(cache) + } + + fn tree_hash_cache_schema(&self, depth: usize) -> cached_tree_hash::BTreeSchema { + cached_tree_hash::vec::produce_schema(&self.vec, depth) + } + + fn update_tree_hash_cache( + &self, + cache: &mut cached_tree_hash::TreeHashCache, + ) -> Result<(), cached_tree_hash::Error> { + cached_tree_hash::vec::update_tree_hash_cache(&self.vec, cache)?; + + Ok(()) + } +} + +impl ssz::Encode for VariableList +where + T: ssz::Encode, +{ + fn is_ssz_fixed_len() -> bool { + >::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + >::ssz_fixed_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + self.vec.ssz_append(buf) + } +} + +impl ssz::Decode for VariableList +where + T: ssz::Decode + Default, +{ + fn is_ssz_fixed_len() -> bool { + >::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + >::ssz_fixed_len() + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + let vec = >::from_ssz_bytes(bytes)?; + + Self::new(vec).map_err(|e| ssz::DecodeError::BytesInvalid(format!("VariableList {:?}", e))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ssz::*; + use typenum::*; + + #[test] + fn encode() { + let vec: VariableList = vec![0; 2].into(); + assert_eq!(vec.as_ssz_bytes(), vec![0, 0, 0, 0]); + assert_eq!( as Encode>::ssz_fixed_len(), 4); + } + + fn round_trip(item: T) { + let encoded = &item.as_ssz_bytes(); + assert_eq!(T::from_ssz_bytes(&encoded), Ok(item)); + } + + #[test] + fn u16_len_8() { + round_trip::>(vec![42; 8].into()); + round_trip::>(vec![0; 8].into()); + } +}