diff --git a/ssz/src/decode.rs b/ssz/src/decode.rs new file mode 100644 index 000000000..591d8c863 --- /dev/null +++ b/ssz/src/decode.rs @@ -0,0 +1,135 @@ +use super::{ + LENGTH_BYTES, +}; + +#[derive(Debug)] +pub enum DecodeError { + OutOfBounds, + TooShort, + TooLong, +} + +pub trait Decodable: Sized { + fn ssz_decode(bytes: &[u8]) -> Result; +} + + +pub fn decode_ssz_list_element(ssz_bytes: &[u8], n: usize) + -> Result + where T: Decodable +{ + T::ssz_decode(nth_value(ssz_bytes, n)?) +} + +fn nth_value(ssz_bytes: &[u8], n: usize) + -> Result<&[u8], DecodeError> +{ + let mut c: usize = 0; + for i in 0..(n + 1) { + let length = decode_length(&ssz_bytes[c..], LENGTH_BYTES)?; + let next = c + LENGTH_BYTES + length; + + if i == n { + return Ok(&ssz_bytes[c + LENGTH_BYTES..next]); + } else { + if next >= ssz_bytes.len() { + return Err(DecodeError::OutOfBounds); + } else { + c = next; + } + } + } + Err(DecodeError::OutOfBounds) +} + +fn decode_length(bytes: &[u8], length_bytes: usize) + -> Result +{ + if bytes.len() < length_bytes { + return Err(DecodeError::TooShort); + }; + let mut len: usize = 0; + for i in 0..length_bytes { + let offset = (length_bytes - i - 1) * 8; + len = ((bytes[i] as usize) << offset) | len; + }; + Ok(len) +} + +#[cfg(test)] +mod tests { + use super::*; + use super::super::encode_length; + + #[test] + fn test_ssz_decode_length() { + let decoded = decode_length( + &vec![0, 0, 0, 1], + LENGTH_BYTES); + assert_eq!(decoded.unwrap(), 1); + + let decoded = decode_length( + &vec![0, 0, 1, 0], + LENGTH_BYTES); + assert_eq!(decoded.unwrap(), 256); + + let decoded = decode_length( + &vec![0, 0, 1, 255], + LENGTH_BYTES); + assert_eq!(decoded.unwrap(), 511); + + let decoded = decode_length( + &vec![255, 255, 255, 255], + LENGTH_BYTES); + assert_eq!(decoded.unwrap(), 4294967295); + } + + #[test] + fn test_encode_decode_length() { + let params: Vec = vec![ + 0, 1, 2, 3, 7, 8, 16, + 2^8, 2^8 + 1, + 2^16, 2^16 + 1, + 2^24, 2^24 + 1, + 2^32, + ]; + for i in params { + let decoded = decode_length( + &encode_length(i, LENGTH_BYTES), + LENGTH_BYTES).unwrap(); + assert_eq!(i, decoded); + } + } + + #[test] + fn test_ssz_nth_value() { + let ssz = vec![0, 0, 0, 1, 0]; + let result = nth_value(&ssz, 0).unwrap(); + assert_eq!(result, vec![0].as_slice()); + + let ssz = vec![0, 0, 0, 4, 1, 2, 3, 4]; + let result = nth_value(&ssz, 0).unwrap(); + assert_eq!(result, vec![1, 2, 3, 4].as_slice()); + + let ssz = vec![0, 0, 0, 1, 0, 0, 0, 0, 1, 1]; + let result = nth_value(&ssz, 1).unwrap(); + assert_eq!(result, vec![1].as_slice()); + + let mut ssz = vec![0, 0, 1, 255]; + ssz.append(&mut vec![42; 511]); + let result = nth_value(&ssz, 0).unwrap(); + assert_eq!(result, vec![42; 511].as_slice()); + } + + /* + #[test] + fn test_ssz_decode_u16() { + let x: u16 = 100; + let mut s = SszStream::new(); + s.append(&x); + let y: u16 = u16::ssz_decode(s.nth_value(0).unwrap()) + .unwrap(); + assert_eq!(x, y); + } + */ +} diff --git a/ssz/src/impls.rs b/ssz/src/impls.rs index 82f339351..eae6355d6 100644 --- a/ssz/src/impls.rs +++ b/ssz/src/impls.rs @@ -13,18 +13,14 @@ use super::ethereum_types::{ H256, U256 }; macro_rules! impl_decodable_for_uint { ($type: ident, $bit_size: expr) => { impl Decodable for $type { - type Decoded = $type; - - fn ssz_decode<$type>(bytes: &[u8]) - -> Result + fn ssz_decode(bytes: &[u8]) + -> Result { - // TOOD: figure out if than can be done at compile time - // instead of runtime (where I assume it happens). assert!(0 < $bit_size && $bit_size <= 64 && $bit_size % 8 == 0); let bytes_required = $bit_size / 8; - if bytes_required == bytes.len() { + if bytes_required <= bytes.len() { let mut result = 0; for i in 0..bytes.len() { let offset = (bytes.len() - i - 1) * 8; @@ -32,10 +28,7 @@ macro_rules! impl_decodable_for_uint { }; Ok(result.into()) } else { - match bytes_required > bytes.len() { - true => Err(DecodeError::TooLong), - false => Err(DecodeError::TooShort), - } + Err(DecodeError::TooLong) } } } diff --git a/ssz/src/lib.rs b/ssz/src/lib.rs index a3fa9d944..1d8f36bb9 100644 --- a/ssz/src/lib.rs +++ b/ssz/src/lib.rs @@ -11,6 +11,13 @@ extern crate bytes; extern crate ethereum_types; mod impls; +mod decode; + +pub use decode::{ + decode_ssz_list_element, + Decodable, + DecodeError +}; pub const LENGTH_BYTES: usize = 4; @@ -18,23 +25,10 @@ pub trait Encodable { fn ssz_append(&self, s: &mut SszStream); } -pub trait Decodable { - type Decoded; - - fn ssz_decode(bytes: &[u8]) -> Result; -} - pub struct SszStream { buffer: Vec } -#[derive(Debug)] -pub enum DecodeError { - OutOfBounds, - TooShort, - TooLong, -} - impl SszStream { /// Create a new, empty steam for writing ssz values. pub fn new() -> Self { @@ -92,20 +86,6 @@ fn encode_length(len: usize, length_bytes: usize) -> Vec { header } -fn decode_length(bytes: &Vec, length_bytes: usize) - -> Result -{ - if bytes.len() < length_bytes { - return Err(DecodeError::TooShort); - }; - let mut len: usize = 0; - for i in 0..length_bytes { - let offset = (length_bytes - i - 1) * 8; - len = ((bytes[i] as usize) << offset) | len; - }; - Ok(len) -} - #[cfg(test)] mod tests { @@ -148,45 +128,6 @@ mod tests { encode_length(4294967296, 4); // 2^(4*8) } - #[test] - fn test_decode_length() { - let decoded = decode_length( - &vec![0, 0, 0, 1], - LENGTH_BYTES); - assert_eq!(decoded.unwrap(), 1); - - let decoded = decode_length( - &vec![0, 0, 1, 0], - LENGTH_BYTES); - assert_eq!(decoded.unwrap(), 256); - } - - #[test] - fn test_encode_decode_length() { - let params: Vec = vec![ - 0, - 1, - 2, - 3, - 7, - 8, - 16, - 2^8, - 2^8 + 1, - 2^16, - 2^16 + 1, - 2^24, - 2^24 + 1, - 2^32, - ]; - for i in params { - let decoded = decode_length( - &encode_length(i, LENGTH_BYTES), - LENGTH_BYTES).unwrap(); - assert_eq!(i, decoded); - } - } - #[test] fn test_serialization() { pub struct TestStruct {