From 5943e176cf497442beb3fa3bea97edf95a2cda65 Mon Sep 17 00:00:00 2001 From: Paul Hauner Date: Fri, 5 Jul 2019 17:33:20 +1000 Subject: [PATCH] Add ssz_types crate --- Cargo.toml | 1 + eth2/utils/ssz_types/Cargo.toml | 19 + eth2/utils/ssz_types/src/bit_vector.rs | 229 +++++++++ eth2/utils/ssz_types/src/bitfield.rs | 570 ++++++++++++++++++++++ eth2/utils/ssz_types/src/fixed_vector.rs | 335 +++++++++++++ eth2/utils/ssz_types/src/lib.rs | 23 + eth2/utils/ssz_types/src/variable_list.rs | 321 ++++++++++++ 7 files changed, 1498 insertions(+) create mode 100644 eth2/utils/ssz_types/Cargo.toml create mode 100644 eth2/utils/ssz_types/src/bit_vector.rs create mode 100644 eth2/utils/ssz_types/src/bitfield.rs create mode 100644 eth2/utils/ssz_types/src/fixed_vector.rs create mode 100644 eth2/utils/ssz_types/src/lib.rs create mode 100644 eth2/utils/ssz_types/src/variable_list.rs diff --git a/Cargo.toml b/Cargo.toml index ef17b431e..22ec6fd98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "eth2/utils/slot_clock", "eth2/utils/ssz", "eth2/utils/ssz_derive", + "eth2/utils/ssz_types", "eth2/utils/swap_or_not_shuffle", "eth2/utils/tree_hash", "eth2/utils/tree_hash_derive", diff --git a/eth2/utils/ssz_types/Cargo.toml b/eth2/utils/ssz_types/Cargo.toml new file mode 100644 index 000000000..31d567d49 --- /dev/null +++ b/eth2/utils/ssz_types/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "ssz_types" +version = "0.1.0" +authors = ["Paul Hauner "] +edition = "2018" + +[dependencies] +bit_reverse = "0.1" +bit-vec = "0.5.0" +cached_tree_hash = { path = "../cached_tree_hash" } +tree_hash = { path = "../tree_hash" } +serde = "1.0" +serde_derive = "1.0" +serde_hex = { path = "../serde_hex" } +ssz = { path = "../ssz" } +typenum = "1.10" + +[dev-dependencies] +serde_yaml = "0.8" diff --git a/eth2/utils/ssz_types/src/bit_vector.rs b/eth2/utils/ssz_types/src/bit_vector.rs new file mode 100644 index 000000000..48c91f94d --- /dev/null +++ 
b/eth2/utils/ssz_types/src/bit_vector.rs @@ -0,0 +1,229 @@ +use crate::bitfield::{Bitfield, Error}; +use crate::{FixedSizedError, VariableSizedError}; +use std::marker::PhantomData; +use typenum::Unsigned; + +/// Provides a common `impl` for structs that wrap a `Bitfield`. +macro_rules! common_impl { + ($name: ident, $error: ident) => { + impl $name { + /// Create a new bitfield with the given length `initial_len` and all values set to `bit`. + /// + /// Note: if `initial_len` is not a multiple of 8, the remaining bits will be set to `false` + /// regardless of `bit`. + pub fn from_elem(initial_len: usize, bit: bool) -> Result { + let bitfield = Bitfield::from_elem(initial_len, bit); + Self::from_bitfield(bitfield) + } + + /// Create a new instance using the supplied `bytes` as input + pub fn from_bytes(bytes: &[u8]) -> Result { + let bitfield = Bitfield::from_bytes(bytes); + Self::from_bitfield(bitfield) + } + + /// Returns a vector of bytes representing the bitfield + pub fn to_bytes(&self) -> Vec { + self.bitfield.to_bytes() + } + + /// Read the value of a bit. + /// + /// If the index is in bounds, then result is Ok(value) where value is `true` if the bit is 1 and `false` if the bit is 0. + /// If the index is out of bounds, we return an error to that extent. + pub fn get(&self, i: usize) -> Result { + self.bitfield.get(i) + } + + fn capacity() -> usize { + N::to_usize() + } + + /// Set the value of a bit. + /// + /// Returns an `Err` if `i` is outside of the maximum permitted length; the error's `len` is the length needed to hold bit `i`. + pub fn set(&mut self, i: usize, value: bool) -> Result<(), VariableSizedError> { + if i < Self::capacity() { + self.bitfield.set(i, value); + Ok(()) + } else { + Err(VariableSizedError::ExceedsMaxLength { + len: i.saturating_add(1), + max_len: Self::capacity(), + }) + } + } + + /// Returns the number of bits in this bitfield. 
+ pub fn len(&self) -> usize { + self.bitfield.len() + } + + /// Returns true if `self.len() == 0` + pub fn is_empty(&self) -> bool { + self.bitfield.is_empty() + } + + /// Returns true if all bits are set to 0. + pub fn is_zero(&self) -> bool { + self.bitfield.is_zero() + } + + /// Returns the number of bytes required to represent this bitfield. + pub fn num_bytes(&self) -> usize { + self.bitfield.num_bytes() + } + + /// Returns the number of `1` bits in the bitfield + pub fn num_set_bits(&self) -> usize { + self.bitfield.num_set_bits() + } + } + }; +} + +/// Emulates a SSZ `Bitvector`. +/// +/// An ordered, heap-allocated, fixed-length, collection of `bool` values, with `N` values. +pub struct BitVector { + bitfield: Bitfield, + _phantom: PhantomData, +} + +common_impl!(BitVector, FixedSizedError); + +impl BitVector { + /// Create a new bitfield. + pub fn new() -> Self { + Self { + bitfield: Bitfield::with_capacity(N::to_usize()), + _phantom: PhantomData, + } + } + + fn from_bitfield(bitfield: Bitfield) -> Result { + if bitfield.len() != Self::capacity() { + Err(FixedSizedError::InvalidLength { + len: bitfield.len(), + fixed_len: Self::capacity(), + }) + } else { + Ok(Self { + bitfield, + _phantom: PhantomData, + }) + } + } +} + +/// Emulates a SSZ `Bitlist`. +/// +/// An ordered, heap-allocated, variable-length, collection of `bool` values, limited to `N` +/// values. +pub struct BitList { + bitfield: Bitfield, + _phantom: PhantomData, +} + +common_impl!(BitList, VariableSizedError); + +impl BitList { + /// Create a new, empty BitList. + pub fn new() -> Self { + Self { + bitfield: Bitfield::default(), + _phantom: PhantomData, + } + } + + /// Create a new BitList list with `initial_len` bits all set to `false`. + pub fn with_capacity(initial_len: usize) -> Result { + Self::from_elem(initial_len, false) + } + + /// The maximum possible number of bits. 
+ pub fn max_len() -> usize { + N::to_usize() + } + + fn from_bitfield(bitfield: Bitfield) -> Result { + if bitfield.len() > Self::max_len() { + Err(VariableSizedError::ExceedsMaxLength { + len: bitfield.len(), + max_len: Self::max_len(), + }) + } else { + Ok(Self { + bitfield, + _phantom: PhantomData, + }) + } + } + + /// Compute the intersection (binary-and) of this bitfield with another + /// + /// ## Panics + /// + /// If `self` and `other` have different lengths. + pub fn intersection(&self, other: &Self) -> Self { + assert_eq!(self.len(), other.len()); + let bitfield = self.bitfield.intersection(&other.bitfield); + Self::from_bitfield(bitfield).expect( + "An intersection of two same-sized sets cannot be larger than one of the initial sets", + ) + } + + /// Like `intersection` but in-place (updates `self`). + /// + /// ## Panics + /// + /// If `self` and `other` have different lengths. + pub fn intersection_inplace(&mut self, other: &Self) { + self.bitfield.intersection_inplace(&other.bitfield); + } + + /// Compute the union (binary-or) of this bitfield with another. Lengths must match. + /// + /// ## Panics + /// + /// If `self` and `other` have different lengths. + pub fn union(&self, other: &Self) -> Self { + assert_eq!(self.len(), other.len()); + let bitfield = self.bitfield.union(&other.bitfield); + Self::from_bitfield(bitfield) + .expect("A union of two same-sized sets cannot be larger than one of the initial sets") + } + + /// Like `union` but in-place (updates `self`). + /// + /// ## Panics + /// + /// If `self` and `other` have different lengths. + pub fn union_inplace(&mut self, other: &Self) { + self.bitfield.union_inplace(&other.bitfield) + } + + /// Compute the difference (binary-minus) of this bitfield with another. Lengths must match. + /// + /// Computes `self - other`. + /// + /// ## Panics + /// + /// If `self` and `other` have different lengths. 
+ pub fn difference(&self, other: &Self) -> Self { + assert_eq!(self.len(), other.len()); + let bitfield = self.bitfield.difference(&other.bitfield); + Self::from_bitfield(bitfield).expect( + "A difference of two same-sized sets cannot be larger than one of the initial sets", + ) + } + + /// Like `difference` but in-place (updates `self`). + /// + /// ## Panics + /// + /// If `self` and `other` have different lengths. + pub fn difference_inplace(&mut self, other: &Self) { + self.bitfield.difference_inplace(&other.bitfield) + } +} diff --git a/eth2/utils/ssz_types/src/bitfield.rs b/eth2/utils/ssz_types/src/bitfield.rs new file mode 100644 index 000000000..d77f63cf7 --- /dev/null +++ b/eth2/utils/ssz_types/src/bitfield.rs @@ -0,0 +1,570 @@ +use bit_reverse::LookupReverse; +use bit_vec::BitVec; +use cached_tree_hash::cached_tree_hash_bytes_as_list; +use serde::de::{Deserialize, Deserializer}; +use serde::ser::{Serialize, Serializer}; +use serde_hex::{encode, PrefixedHexVisitor}; +use ssz::{Decode, Encode}; +use std::cmp; +use std::default; + +/// A Bitfield represents a set of booleans compactly stored as a vector of bits. +/// The Bitfield is given a fixed size during construction. Reads outside of the current size return an out-of-bounds error. Writes outside of the current size expand the size of the set. +#[derive(Debug, Clone)] +pub struct Bitfield(BitVec); + +/// Error represents some reason a request against a bitfield was not satisfied +#[derive(Debug, PartialEq)] +pub enum Error { + /// OutOfBounds refers to indexing into a bitfield where no bits exist; returns the illegal index and the current size of the bitfield, respectively + OutOfBounds(usize, usize), +} + +impl Bitfield { + pub fn with_capacity(initial_len: usize) -> Self { + Self::from_elem(initial_len, false) + } + + /// Create a new bitfield with the given length `initial_len` and all values set to `bit`. 
+ /// + /// Note: if `initial_len` is not a multiple of 8, the remaining bits will be set to `false` + /// regardless of `bit`. + pub fn from_elem(initial_len: usize, bit: bool) -> Self { + // BitVec can panic if we don't set the len to be a multiple of 8. + let full_len = ((initial_len + 7) / 8) * 8; + let mut bitfield = BitVec::from_elem(full_len, false); + + if bit { + for i in 0..initial_len { + bitfield.set(i, true); + } + } + + Self { 0: bitfield } + } + + /// Create a new bitfield using the supplied `bytes` as input + pub fn from_bytes(bytes: &[u8]) -> Self { + Self { + 0: BitVec::from_bytes(&reverse_bit_order(bytes.to_vec())), + } + } + + /// Returns a vector of bytes representing the bitfield + pub fn to_bytes(&self) -> Vec { + reverse_bit_order(self.0.to_bytes().to_vec()) + } + + /// Read the value of a bit. + /// + /// If the index is in bounds, then result is Ok(value) where value is `true` if the bit is 1 and `false` if the bit is 0. + /// If the index is out of bounds, we return an error to that extent. + pub fn get(&self, i: usize) -> Result { + match self.0.get(i) { + Some(value) => Ok(value), + None => Err(Error::OutOfBounds(i, self.0.len())), + } + } + + /// Set the value of a bit. + /// + /// If the index is out of bounds, we expand the size of the underlying set to include the new index. + /// Returns the previous value if there was one. + pub fn set(&mut self, i: usize, value: bool) -> Option { + let previous = match self.get(i) { + Ok(previous) => Some(previous), + Err(Error::OutOfBounds(_, len)) => { + let new_len = i - len + 1; + self.0.grow(new_len, false); + None + } + }; + self.0.set(i, value); + previous + } + + /// Returns the number of bits in this bitfield. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if `self.len() == 0` + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns true if all bits are set to 0. 
+ pub fn is_zero(&self) -> bool { + self.0.none() + } + + /// Returns the number of bytes required to represent this bitfield. + pub fn num_bytes(&self) -> usize { + self.to_bytes().len() + } + + /// Returns the number of `1` bits in the bitfield + pub fn num_set_bits(&self) -> usize { + self.0.iter().filter(|&bit| bit).count() + } + + /// Compute the intersection (binary-and) of this bitfield with another. Lengths must match. + pub fn intersection(&self, other: &Self) -> Self { + let mut res = self.clone(); + res.intersection_inplace(other); + res + } + + /// Like `intersection` but in-place (updates `self`). + pub fn intersection_inplace(&mut self, other: &Self) { + self.0.intersect(&other.0); + } + + /// Compute the union (binary-or) of this bitfield with another. Lengths must match. + pub fn union(&self, other: &Self) -> Self { + let mut res = self.clone(); + res.union_inplace(other); + res + } + + /// Like `union` but in-place (updates `self`). + pub fn union_inplace(&mut self, other: &Self) { + self.0.union(&other.0); + } + + /// Compute the difference (binary-minus) of this bitfield with another. Lengths must match. + /// + /// Computes `self - other`. + pub fn difference(&self, other: &Self) -> Self { + let mut res = self.clone(); + res.difference_inplace(other); + res + } + + /// Like `difference` but in-place (updates `self`). + pub fn difference_inplace(&mut self, other: &Self) { + self.0.difference(&other.0); + } +} + +impl default::Default for Bitfield { + /// default provides the "empty" bitfield + /// Note: the empty bitfield is set to the `0` byte. + fn default() -> Self { + Self::from_elem(8, false) + } +} + +impl cmp::PartialEq for Bitfield { + /// Determines equality by comparing the `ssz` encoding of the two candidates. + /// This method ensures that the presence of high-order (empty) bits in the highest byte do not exclude equality when they are in fact representing the same information. 
+ fn eq(&self, other: &Self) -> bool { + ssz::ssz_encode(self) == ssz::ssz_encode(other) + } +} + +/// Create a new bitfield that is a union of two other bitfields. +/// +/// For example `union(0101, 1000) == 1101` +// TODO: length-independent intersection for BitAnd +impl std::ops::BitOr for Bitfield { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + let (biggest, smallest) = if self.len() > other.len() { + (&self, &other) + } else { + (&other, &self) + }; + let mut new = biggest.clone(); + for i in 0..smallest.len() { + if let Ok(true) = smallest.get(i) { + new.set(i, true); + } + } + new + } +} + +impl Encode for Bitfield { + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_append(&self, buf: &mut Vec) { + buf.append(&mut self.to_bytes()) + } +} + +impl Decode for Bitfield { + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + Ok(Bitfield::from_bytes(bytes)) + } +} + +// Reverse the bit order of a whole byte vec, so that the ith bit +// of the input vec is placed in the (N - i)th bit of the output vec. +// This function is necessary for converting bitfields to and from YAML, +// as the BitVec library and the hex-parser use opposing bit orders. +fn reverse_bit_order(mut bytes: Vec) -> Vec { + bytes.reverse(); + bytes.into_iter().map(LookupReverse::swap_bits).collect() +} + +impl Serialize for Bitfield { + /// Serde serialization is compliant with the Ethereum YAML test format. + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&encode(self.to_bytes())) + } +} + +impl<'de> Deserialize<'de> for Bitfield { + /// Serde serialization is compliant with the Ethereum YAML test format. + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // We reverse the bit-order so that the BitVec library can read its 0th + // bit from the end of the hex string, e.g. 
+ // "0xef01" => [0xef, 0x01] => [0b1000_0000, 0b1111_1110] + let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?; + Ok(Bitfield::from_bytes(&bytes)) + } +} + +impl tree_hash::TreeHash for Bitfield { + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::List + } + + fn tree_hash_packed_encoding(&self) -> Vec { + unreachable!("List should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("List should never be packed.") + } + + fn tree_hash_root(&self) -> Vec { + self.to_bytes().tree_hash_root() + } +} + +cached_tree_hash_bytes_as_list!(Bitfield); + +#[cfg(test)] +mod tests { + use super::*; + use serde_yaml; + use ssz::ssz_encode; + use tree_hash::TreeHash; + + impl Bitfield { + /// Create a new bitfield. + pub fn new() -> Self { + Default::default() + } + } + + #[test] + pub fn test_cached_tree_hash() { + let original = Bitfield::from_bytes(&vec![18; 12][..]); + + let mut cache = cached_tree_hash::TreeHashCache::new(&original).unwrap(); + + assert_eq!( + cache.tree_hash_root().unwrap().to_vec(), + original.tree_hash_root() + ); + + let modified = Bitfield::from_bytes(&vec![2; 1][..]); + + cache.update(&modified).unwrap(); + + assert_eq!( + cache.tree_hash_root().unwrap().to_vec(), + modified.tree_hash_root() + ); + } + + #[test] + fn test_new_bitfield() { + let mut field = Bitfield::new(); + let original_len = field.len(); + + for i in 0..100 { + if i < original_len { + assert!(!field.get(i).unwrap()); + } else { + assert!(field.get(i).is_err()); + } + let previous = field.set(i, true); + if i < original_len { + assert!(!previous.unwrap()); + } else { + assert!(previous.is_none()); + } + } + } + + #[test] + fn test_empty_bitfield() { + let mut field = Bitfield::from_elem(0, false); + let original_len = field.len(); + + assert_eq!(original_len, 0); + + for i in 0..100 { + if i < original_len { + assert!(!field.get(i).unwrap()); + } else { + assert!(field.get(i).is_err()); + } + let previous = 
field.set(i, true); + if i < original_len { + assert!(!previous.unwrap()); + } else { + assert!(previous.is_none()); + } + } + + assert_eq!(field.len(), 100); + assert_eq!(field.num_set_bits(), 100); + } + + const INPUT: &[u8] = &[0b0100_0000, 0b0100_0000]; + + #[test] + fn test_get_from_bitfield() { + let field = Bitfield::from_bytes(INPUT); + let unset = field.get(0).unwrap(); + assert!(!unset); + let set = field.get(6).unwrap(); + assert!(set); + let set = field.get(14).unwrap(); + assert!(set); + } + + #[test] + fn test_set_for_bitfield() { + let mut field = Bitfield::from_bytes(INPUT); + let previous = field.set(10, true).unwrap(); + assert!(!previous); + let previous = field.get(10).unwrap(); + assert!(previous); + let previous = field.set(6, false).unwrap(); + assert!(previous); + let previous = field.get(6).unwrap(); + assert!(!previous); + } + + #[test] + fn test_len() { + let field = Bitfield::from_bytes(INPUT); + assert_eq!(field.len(), 16); + + let field = Bitfield::new(); + assert_eq!(field.len(), 8); + } + + #[test] + fn test_num_set_bits() { + let field = Bitfield::from_bytes(INPUT); + assert_eq!(field.num_set_bits(), 2); + + let field = Bitfield::new(); + assert_eq!(field.num_set_bits(), 0); + } + + #[test] + fn test_to_bytes() { + let field = Bitfield::from_bytes(INPUT); + assert_eq!(field.to_bytes(), INPUT); + + let field = Bitfield::new(); + assert_eq!(field.to_bytes(), vec![0]); + } + + #[test] + fn test_out_of_bounds() { + let mut field = Bitfield::from_bytes(INPUT); + + let out_of_bounds_index = field.len(); + assert!(field.set(out_of_bounds_index, true).is_none()); + assert!(field.len() == out_of_bounds_index + 1); + assert!(field.get(out_of_bounds_index).unwrap()); + + for i in 0..100 { + if i <= out_of_bounds_index { + assert!(field.set(i, true).is_some()); + } else { + assert!(field.set(i, true).is_none()); + } + } + } + + #[test] + fn test_grows_with_false() { + let input_all_set: &[u8] = &[0b1111_1111, 0b1111_1111]; + let mut field = 
Bitfield::from_bytes(input_all_set); + + // Define `a` and `b`, where both are out of bounds and `b` is greater than `a`. + let a = field.len(); + let b = a + 1; + + // Ensure `a` is out-of-bounds for test integrity. + assert!(field.get(a).is_err()); + + // Set `b` to `true`. Also, for test integrity, ensure it was previously out-of-bounds. + assert!(field.set(b, true).is_none()); + + // Ensure that `a` wasn't also set to `true` during the grow. + assert_eq!(field.get(a), Ok(false)); + assert_eq!(field.get(b), Ok(true)); + } + + #[test] + fn test_num_bytes() { + let field = Bitfield::from_bytes(INPUT); + assert_eq!(field.num_bytes(), 2); + + let field = Bitfield::from_elem(2, true); + assert_eq!(field.num_bytes(), 1); + + let field = Bitfield::from_elem(13, true); + assert_eq!(field.num_bytes(), 2); + } + + #[test] + fn test_ssz_encode() { + let field = create_test_bitfield(); + assert_eq!(field.as_ssz_bytes(), vec![0b0000_0011, 0b1000_0111]); + + let field = Bitfield::from_elem(18, true); + assert_eq!( + field.as_ssz_bytes(), + vec![0b0000_0011, 0b1111_1111, 0b1111_1111] + ); + + let mut b = Bitfield::new(); + b.set(1, true); + assert_eq!(ssz_encode(&b), vec![0b0000_0010]); + } + + fn create_test_bitfield() -> Bitfield { + let count = 2 * 8; + let mut field = Bitfield::with_capacity(count); + + let indices = &[0, 1, 2, 7, 8, 9]; + for &i in indices { + field.set(i, true); + } + field + } + + #[test] + fn test_ssz_decode() { + let encoded = vec![0b0000_0011, 0b1000_0111]; + let field = Bitfield::from_ssz_bytes(&encoded).unwrap(); + let expected = create_test_bitfield(); + assert_eq!(field, expected); + + let encoded = vec![255, 255, 3]; + let field = Bitfield::from_ssz_bytes(&encoded).unwrap(); + let expected = Bitfield::from_bytes(&[255, 255, 3]); + assert_eq!(field, expected); + } + + #[test] + fn test_serialize_deserialize() { + use serde_yaml::Value; + + let data: &[(_, &[_])] = &[ + ("0x01", &[0b00000001]), + ("0xf301", &[0b11110011, 0b00000001]), + ]; + for 
(hex_data, bytes) in data { + let bitfield = Bitfield::from_bytes(bytes); + assert_eq!( + serde_yaml::from_str::(hex_data).unwrap(), + bitfield + ); + assert_eq!( + serde_yaml::to_value(&bitfield).unwrap(), + Value::String(hex_data.to_string()) + ); + } + } + + #[test] + fn test_ssz_round_trip() { + let original = Bitfield::from_bytes(&vec![18; 12][..]); + let ssz = ssz_encode(&original); + let decoded = Bitfield::from_ssz_bytes(&ssz).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn test_bitor() { + let a = Bitfield::from_bytes(&vec![2, 8, 1][..]); + let b = Bitfield::from_bytes(&vec![4, 8, 16][..]); + let c = Bitfield::from_bytes(&vec![6, 8, 17][..]); + assert_eq!(c, a | b); + } + + #[test] + fn test_is_zero() { + let yes_data: &[&[u8]] = &[&[], &[0], &[0, 0], &[0, 0, 0]]; + for bytes in yes_data { + assert!(Bitfield::from_bytes(bytes).is_zero()); + } + let no_data: &[&[u8]] = &[&[1], &[6], &[0, 1], &[0, 0, 1], &[0, 0, 255]]; + for bytes in no_data { + assert!(!Bitfield::from_bytes(bytes).is_zero()); + } + } + + #[test] + fn test_intersection() { + let a = Bitfield::from_bytes(&[0b1100, 0b0001]); + let b = Bitfield::from_bytes(&[0b1011, 0b1001]); + let c = Bitfield::from_bytes(&[0b1000, 0b0001]); + assert_eq!(a.intersection(&b), c); + assert_eq!(b.intersection(&a), c); + assert_eq!(a.intersection(&c), c); + assert_eq!(b.intersection(&c), c); + assert_eq!(a.intersection(&a), a); + assert_eq!(b.intersection(&b), b); + assert_eq!(c.intersection(&c), c); + } + + #[test] + fn test_union() { + let a = Bitfield::from_bytes(&[0b1100, 0b0001]); + let b = Bitfield::from_bytes(&[0b1011, 0b1001]); + let c = Bitfield::from_bytes(&[0b1111, 0b1001]); + assert_eq!(a.union(&b), c); + assert_eq!(b.union(&a), c); + assert_eq!(a.union(&a), a); + assert_eq!(b.union(&b), b); + assert_eq!(c.union(&c), c); + } + + #[test] + fn test_difference() { + let a = Bitfield::from_bytes(&[0b1100, 0b0001]); + let b = Bitfield::from_bytes(&[0b1011, 0b1001]); + let a_b = 
Bitfield::from_bytes(&[0b0100, 0b0000]); + let b_a = Bitfield::from_bytes(&[0b0011, 0b1000]); + assert_eq!(a.difference(&b), a_b); + assert_eq!(b.difference(&a), b_a); + assert!(a.difference(&a).is_zero()); + } +} diff --git a/eth2/utils/ssz_types/src/fixed_vector.rs b/eth2/utils/ssz_types/src/fixed_vector.rs new file mode 100644 index 000000000..071532e22 --- /dev/null +++ b/eth2/utils/ssz_types/src/fixed_vector.rs @@ -0,0 +1,335 @@ +use crate::FixedSizedError as Error; +use serde_derive::{Deserialize, Serialize}; +use std::marker::PhantomData; +use std::ops::{Deref, Index, IndexMut}; +use std::slice::SliceIndex; +use typenum::Unsigned; + +pub use typenum; + +/// Emulates a SSZ `Vector` (distinct from a Rust `Vec`). +/// +/// An ordered, heap-allocated, fixed-length, homogeneous collection of `T`, with `N` values. +/// +/// This struct is backed by a Rust `Vec` but constrained such that it must be instantiated with a +/// fixed number of elements and you may not add or remove elements, only modify. +/// +/// The length of this struct is fixed at the type-level using +/// [typenum](https://crates.io/crates/typenum). +/// +/// ## Note +/// +/// Whilst it is possible with this library, SSZ declares that a `FixedVector` with a length of `0` +/// is illegal. +/// +/// ## Example +/// +/// ``` +/// use ssz_types::{FixedVector, typenum}; +/// +/// let base: Vec = vec![1, 2, 3, 4]; +/// +/// // Create a `FixedVector` from a `Vec` that has the expected length. +/// let exact: FixedVector<_, typenum::U4> = FixedVector::from(base.clone()); +/// assert_eq!(&exact[..], &[1, 2, 3, 4]); +/// +/// // Create a `FixedVector` from a `Vec` that is too long and the `Vec` is truncated. +/// let short: FixedVector<_, typenum::U3> = FixedVector::from(base.clone()); +/// assert_eq!(&short[..], &[1, 2, 3]); +/// +/// // Create a `FixedVector` from a `Vec` that is too short and the missing values are created +/// // using `std::default::Default`. 
+/// let long: FixedVector<_, typenum::U5> = FixedVector::from(base); +/// assert_eq!(&long[..], &[1, 2, 3, 4, 0]); +/// ``` +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct FixedVector { + vec: Vec, + _phantom: PhantomData, +} + +impl FixedVector { + /// Returns `Ok` if the length of the given `vec` equals the fixed length of `Self`. Otherwise + /// returns `Err(Error::InvalidLength)`. + pub fn new(vec: Vec) -> Result { + if vec.len() == Self::capacity() { + Ok(Self { + vec, + _phantom: PhantomData, + }) + } else { + Err(Error::InvalidLength { + len: vec.len(), + fixed_len: Self::capacity(), + }) + } + } + + /// Returns the length of the backing `Vec`, which `new` and `from` make equal to `capacity()`. + /// + /// Exists for compatibility with `Vec`. NOTE(review): `default()` stores an empty `Vec` (length 0). + pub fn len(&self) -> usize { + self.vec.len() + } + + /// True if `self.len()` is zero. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the type-level constant length. + pub fn capacity() -> usize { + N::to_usize() + } +} + +impl From> for FixedVector { + fn from(mut vec: Vec) -> Self { + vec.resize_with(Self::capacity(), Default::default); + + Self { + vec, + _phantom: PhantomData, + } + } +} + +impl Into> for FixedVector { + fn into(self) -> Vec { + self.vec + } +} + +impl Default for FixedVector { + fn default() -> Self { + Self { + vec: Vec::default(), + _phantom: PhantomData, + } + } +} + +impl> Index for FixedVector { + type Output = I::Output; + + #[inline] + fn index(&self, index: I) -> &Self::Output { + Index::index(&self.vec, index) + } +} + +impl> IndexMut for FixedVector { + #[inline] + fn index_mut(&mut self, index: I) -> &mut Self::Output { + IndexMut::index_mut(&mut self.vec, index) + } +} + +impl Deref for FixedVector { + type Target = [T]; + + fn deref(&self) -> &[T] { + &self.vec[..] 
+ } +} + +#[cfg(test)] +mod test { + use super::*; + use typenum::*; + + #[test] + fn new() { + let vec = vec![42; 5]; + let fixed: Result, _> = FixedVector::new(vec.clone()); + assert!(fixed.is_err()); + + let vec = vec![42; 3]; + let fixed: Result, _> = FixedVector::new(vec.clone()); + assert!(fixed.is_err()); + + let vec = vec![42; 4]; + let fixed: Result, _> = FixedVector::new(vec.clone()); + assert!(fixed.is_ok()); + } + + #[test] + fn indexing() { + let vec = vec![1, 2]; + + let mut fixed: FixedVector = vec.clone().into(); + + assert_eq!(fixed[0], 1); + assert_eq!(&fixed[0..1], &vec[0..1]); + assert_eq!((&fixed[..]).len(), 8192); + + fixed[1] = 3; + assert_eq!(fixed[1], 3); + } + + #[test] + fn length() { + let vec = vec![42; 5]; + let fixed: FixedVector = FixedVector::from(vec.clone()); + assert_eq!(&fixed[..], &vec[0..4]); + + let vec = vec![42; 3]; + let fixed: FixedVector = FixedVector::from(vec.clone()); + assert_eq!(&fixed[0..3], &vec[..]); + assert_eq!(&fixed[..], &vec![42, 42, 42, 0][..]); + + let vec = vec![]; + let fixed: FixedVector = FixedVector::from(vec.clone()); + assert_eq!(&fixed[..], &vec![0, 0, 0, 0][..]); + } + + #[test] + fn deref() { + let vec = vec![0, 2, 4, 6]; + let fixed: FixedVector = FixedVector::from(vec); + + assert_eq!(fixed.get(0), Some(&0)); + assert_eq!(fixed.get(3), Some(&6)); + assert_eq!(fixed.get(4), None); + } +} + +impl tree_hash::TreeHash for FixedVector +where + T: tree_hash::TreeHash, +{ + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::Vector + } + + fn tree_hash_packed_encoding(&self) -> Vec { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_root(&self) -> Vec { + tree_hash::impls::vec_tree_hash_root(&self.vec) + } +} + +impl cached_tree_hash::CachedTreeHash for FixedVector +where + T: cached_tree_hash::CachedTreeHash + tree_hash::TreeHash, +{ + fn new_tree_hash_cache( + 
&self, + depth: usize, + ) -> Result { + let (cache, _overlay) = cached_tree_hash::vec::new_tree_hash_cache(&self.vec, depth)?; + + Ok(cache) + } + + fn tree_hash_cache_schema(&self, depth: usize) -> cached_tree_hash::BTreeSchema { + cached_tree_hash::vec::produce_schema(&self.vec, depth) + } + + fn update_tree_hash_cache( + &self, + cache: &mut cached_tree_hash::TreeHashCache, + ) -> Result<(), cached_tree_hash::Error> { + cached_tree_hash::vec::update_tree_hash_cache(&self.vec, cache)?; + + Ok(()) + } +} + +impl ssz::Encode for FixedVector +where + T: ssz::Encode, +{ + fn is_ssz_fixed_len() -> bool { + T::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + if ::is_ssz_fixed_len() { + T::ssz_fixed_len() * N::to_usize() + } else { + ssz::BYTES_PER_LENGTH_OFFSET + } + } + + fn ssz_append(&self, buf: &mut Vec) { + if T::is_ssz_fixed_len() { + buf.reserve(T::ssz_fixed_len() * self.len()); + + for item in &self.vec { + item.ssz_append(buf); + } + } else { + let mut encoder = ssz::SszEncoder::list(buf, self.len() * ssz::BYTES_PER_LENGTH_OFFSET); + + for item in &self.vec { + encoder.append(item); + } + + encoder.finalize(); + } + } +} + +impl ssz::Decode for FixedVector +where + T: ssz::Decode + Default, +{ + fn is_ssz_fixed_len() -> bool { + T::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + if ::is_ssz_fixed_len() { + T::ssz_fixed_len() * N::to_usize() + } else { + ssz::BYTES_PER_LENGTH_OFFSET + } + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + if bytes.is_empty() { + Ok(FixedVector::from(vec![])) + } else if T::is_ssz_fixed_len() { + bytes + .chunks(T::ssz_fixed_len()) + .map(|chunk| T::from_ssz_bytes(chunk)) + .collect::, _>>() + .and_then(|vec| Ok(vec.into())) + } else { + ssz::decode_list_of_variable_length_items(bytes).and_then(|vec| Ok(vec.into())) + } + } +} + +#[cfg(test)] +mod ssz_tests { + use super::*; + use ssz::*; + use typenum::*; + + #[test] + fn encode() { + let vec: FixedVector = vec![0; 2].into(); + assert_eq!(vec.as_ssz_bytes(), vec![0, 0, 0, 
0]); + assert_eq!( as Encode>::ssz_fixed_len(), 4); + } + + fn round_trip(item: T) { + let encoded = &item.as_ssz_bytes(); + assert_eq!(T::from_ssz_bytes(&encoded), Ok(item)); + } + + #[test] + fn u16_len_8() { + round_trip::>(vec![42; 8].into()); + round_trip::>(vec![0; 8].into()); + } +} diff --git a/eth2/utils/ssz_types/src/lib.rs b/eth2/utils/ssz_types/src/lib.rs new file mode 100644 index 000000000..37aa24875 --- /dev/null +++ b/eth2/utils/ssz_types/src/lib.rs @@ -0,0 +1,23 @@ +mod bit_vector; +mod bitfield; +mod fixed_vector; +mod variable_list; + +pub use bit_vector::{BitList, BitVector}; +pub use fixed_vector::FixedVector; +pub use typenum; +pub use variable_list::VariableList; + +/// Returned when a variable-length item encounters an error. +#[derive(PartialEq, Debug)] +pub enum VariableSizedError { + /// The operation would cause the maximum length to be exceeded. + ExceedsMaxLength { len: usize, max_len: usize }, +} + +/// Returned when a fixed-length item encounters an error. +#[derive(PartialEq, Debug)] +pub enum FixedSizedError { + /// The operation would create an item of an invalid size. + InvalidLength { len: usize, fixed_len: usize }, +} diff --git a/eth2/utils/ssz_types/src/variable_list.rs b/eth2/utils/ssz_types/src/variable_list.rs new file mode 100644 index 000000000..3d0bf31c9 --- /dev/null +++ b/eth2/utils/ssz_types/src/variable_list.rs @@ -0,0 +1,321 @@ +use crate::VariableSizedError as Error; +use serde_derive::{Deserialize, Serialize}; +use std::marker::PhantomData; +use std::ops::{Deref, Index, IndexMut}; +use std::slice::SliceIndex; +use typenum::Unsigned; + +pub use typenum; + +/// Emulates a SSZ `List`. +/// +/// An ordered, heap-allocated, variable-length, homogeneous collection of `T`, with no more than +/// `N` values. +/// +/// This struct is backed by a Rust `Vec` but its constructors (`new`, `from`) and `push` never +/// allow it to hold more than `N` elements; elements may be appended until that maximum. 
+/// +/// The length of this struct is fixed at the type-level using +/// [typenum](https://crates.io/crates/typenum). +/// +/// ## Example +/// +/// ``` +/// use ssz_types::{VariableList, typenum}; +/// +/// let base: Vec = vec![1, 2, 3, 4]; +/// +/// // Create a `VariableList` from a `Vec` that has the expected length. +/// let exact: VariableList<_, typenum::U4> = VariableList::from(base.clone()); +/// assert_eq!(&exact[..], &[1, 2, 3, 4]); +/// +/// // Create a `VariableList` from a `Vec` that is too long and the `Vec` is truncated. +/// let short: VariableList<_, typenum::U3> = VariableList::from(base.clone()); +/// assert_eq!(&short[..], &[1, 2, 3]); +/// +/// // Create a `VariableList` from a `Vec` that is shorter than the maximum. +/// let mut long: VariableList<_, typenum::U5> = VariableList::from(base); +/// assert_eq!(&long[..], &[1, 2, 3, 4]); +/// +/// // Push a value to if it does not exceed the maximum +/// long.push(5).unwrap(); +/// assert_eq!(&long[..], &[1, 2, 3, 4, 5]); +/// +/// // Push a value to if it _does_ exceed the maximum. +/// assert!(long.push(6).is_err()); +/// ``` +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct VariableList { + vec: Vec, + _phantom: PhantomData, +} + +impl VariableList { + /// Returns `Some` if the given `vec` equals the fixed length of `Self`. Otherwise returns + /// `None`. + pub fn new(vec: Vec) -> Result { + if vec.len() <= N::to_usize() { + Ok(Self { + vec, + _phantom: PhantomData, + }) + } else { + Err(Error::ExceedsMaxLength { + len: vec.len(), + max_len: Self::max_len(), + }) + } + } + + /// Returns the number of values presently in `self`. + pub fn len(&self) -> usize { + self.vec.len() + } + + /// True if `self` does not contain any values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the type-level maximum length. + pub fn max_len() -> usize { + N::to_usize() + } + + /// Appends `value` to the back of `self`. 
+ /// + /// Returns `Err(())` when appending `value` would exceed the maximum length. + pub fn push(&mut self, value: T) -> Result<(), Error> { + if self.vec.len() < Self::max_len() { + Ok(self.vec.push(value)) + } else { + Err(Error::ExceedsMaxLength { + len: self.vec.len() + 1, + max_len: Self::max_len(), + }) + } + } +} + +impl From> for VariableList { + fn from(mut vec: Vec) -> Self { + vec.truncate(N::to_usize()); + + Self { + vec, + _phantom: PhantomData, + } + } +} + +impl Into> for VariableList { + fn into(self) -> Vec { + self.vec + } +} + +impl Default for VariableList { + fn default() -> Self { + Self { + vec: Vec::default(), + _phantom: PhantomData, + } + } +} + +impl> Index for VariableList { + type Output = I::Output; + + #[inline] + fn index(&self, index: I) -> &Self::Output { + Index::index(&self.vec, index) + } +} + +impl> IndexMut for VariableList { + #[inline] + fn index_mut(&mut self, index: I) -> &mut Self::Output { + IndexMut::index_mut(&mut self.vec, index) + } +} + +impl Deref for VariableList { + type Target = [T]; + + fn deref(&self) -> &[T] { + &self.vec[..] 
+ } +} + +#[cfg(test)] +mod test { + use super::*; + use typenum::*; + + #[test] + fn new() { + let vec = vec![42; 5]; + let fixed: Result, _> = VariableList::new(vec.clone()); + assert!(fixed.is_err()); + + let vec = vec![42; 3]; + let fixed: Result, _> = VariableList::new(vec.clone()); + assert!(fixed.is_ok()); + + let vec = vec![42; 4]; + let fixed: Result, _> = VariableList::new(vec.clone()); + assert!(fixed.is_ok()); + } + + #[test] + fn indexing() { + let vec = vec![1, 2]; + + let mut fixed: VariableList = vec.clone().into(); + + assert_eq!(fixed[0], 1); + assert_eq!(&fixed[0..1], &vec[0..1]); + assert_eq!((&fixed[..]).len(), 2); + + fixed[1] = 3; + assert_eq!(fixed[1], 3); + } + + #[test] + fn length() { + let vec = vec![42; 5]; + let fixed: VariableList = VariableList::from(vec.clone()); + assert_eq!(&fixed[..], &vec[0..4]); + + let vec = vec![42; 3]; + let fixed: VariableList = VariableList::from(vec.clone()); + assert_eq!(&fixed[0..3], &vec[..]); + assert_eq!(&fixed[..], &vec![42, 42, 42][..]); + + let vec = vec![]; + let fixed: VariableList = VariableList::from(vec.clone()); + assert_eq!(&fixed[..], &vec![][..]); + } + + #[test] + fn deref() { + let vec = vec![0, 2, 4, 6]; + let fixed: VariableList = VariableList::from(vec); + + assert_eq!(fixed.get(0), Some(&0)); + assert_eq!(fixed.get(3), Some(&6)); + assert_eq!(fixed.get(4), None); + } +} + +/* +impl tree_hash::TreeHash for VariableList +where + T: tree_hash::TreeHash, +{ + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::Vector + } + + fn tree_hash_packed_encoding(&self) -> Vec { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("Vector should never be packed.") + } + + fn tree_hash_root(&self) -> Vec { + tree_hash::impls::vec_tree_hash_root(&self.vec) + } +} + +impl cached_tree_hash::CachedTreeHash for VariableList +where + T: cached_tree_hash::CachedTreeHash + tree_hash::TreeHash, +{ + fn new_tree_hash_cache( 
+ &self, + depth: usize, + ) -> Result { + let (cache, _overlay) = cached_tree_hash::vec::new_tree_hash_cache(&self.vec, depth)?; + + Ok(cache) + } + + fn tree_hash_cache_schema(&self, depth: usize) -> cached_tree_hash::BTreeSchema { + cached_tree_hash::vec::produce_schema(&self.vec, depth) + } + + fn update_tree_hash_cache( + &self, + cache: &mut cached_tree_hash::TreeHashCache, + ) -> Result<(), cached_tree_hash::Error> { + cached_tree_hash::vec::update_tree_hash_cache(&self.vec, cache)?; + + Ok(()) + } +} +*/ + +impl ssz::Encode for VariableList +where + T: ssz::Encode, +{ + fn is_ssz_fixed_len() -> bool { + >::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + >::ssz_fixed_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + self.vec.ssz_append(buf) + } +} + +impl ssz::Decode for VariableList +where + T: ssz::Decode + Default, +{ + fn is_ssz_fixed_len() -> bool { + >::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + >::ssz_fixed_len() + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + let vec = >::from_ssz_bytes(bytes)?; + + Self::new(vec).map_err(|e| ssz::DecodeError::BytesInvalid(format!("VariableList {:?}", e))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ssz::*; + use typenum::*; + + #[test] + fn encode() { + let vec: VariableList = vec![0; 2].into(); + assert_eq!(vec.as_ssz_bytes(), vec![0, 0, 0, 0]); + assert_eq!( as Encode>::ssz_fixed_len(), 4); + } + + fn round_trip(item: T) { + let encoded = &item.as_ssz_bytes(); + assert_eq!(T::from_ssz_bytes(&encoded), Ok(item)); + } + + #[test] + fn u16_len_8() { + round_trip::>(vec![42; 8].into()); + round_trip::>(vec![0; 8].into()); + } +}