use super::LENGTH_BYTES;

/// Errors returned when SSZ-encoded bytes cannot be decoded.
#[derive(Debug, PartialEq)]
pub enum DecodeError {
    /// The input ends before the value being decoded is complete, or the
    /// requested index lies outside the input.
    TooShort,
    /// The input encodes something longer or larger than can be represented.
    /// NOTE(review): meaning inferred from the variant name — confirm against
    /// the `Decodable` implementations that return it.
    TooLong,
    /// The bytes are structurally inconsistent — e.g. a list element's
    /// encoding runs past the end of its list, or an element decoder did not
    /// advance. `Decodable` implementations may also return it.
    Invalid,
}

/// A type that can be decoded from an SSZ byte stream.
pub trait Decodable: Sized {
    /// Decode one value from `bytes` starting at `index`.
    ///
    /// On success returns the value together with the index of the first byte
    /// after it, so callers can continue decoding from there.
    fn ssz_decode(bytes: &[u8], index: usize) -> Result<(Self, usize), DecodeError>;
}

/// Decode a single SSZ-encoded value of type `T` from `ssz_bytes`, starting at
/// `index`.
///
/// Returns whatever [`Decodable::ssz_decode`] returns for `T`: the value and
/// the index following it.
///
/// # Errors
///
/// Returns [`DecodeError::TooShort`] if `index` is not inside `ssz_bytes`;
/// otherwise propagates any error from `T::ssz_decode`.
pub fn decode_ssz<T>(ssz_bytes: &[u8], index: usize) -> Result<(T, usize), DecodeError>
where
    T: Decodable,
{
    if index >= ssz_bytes.len() {
        return Err(DecodeError::TooShort);
    }
    T::ssz_decode(ssz_bytes, index)
}

/// Decode an SSZ list of `T` from `ssz_bytes`, starting at `index`.
///
/// The list is a `LENGTH_BYTES`-byte little-endian length prefix giving the
/// size in bytes of the body, followed by the body: the elements' encodings
/// back to back. Each element is decoded with [`Decodable::ssz_decode`] and
/// collected in order.
///
/// Returns the decoded elements and the index of the first byte after the
/// list body.
///
/// # Errors
///
/// - [`DecodeError::TooShort`] if the length prefix, or the body it declares,
///   extends past the end of `ssz_bytes`.
/// - [`DecodeError::Invalid`] if an element decoder does not advance, or an
///   element's encoding runs past the end of the declared body.
/// - Any error returned by `T::ssz_decode`.
pub fn decode_ssz_list<T>(ssz_bytes: &[u8], index: usize) -> Result<(Vec<T>, usize), DecodeError>
where
    T: Decodable,
{
    // Checked arithmetic: an out-of-range `index` must yield an error rather
    // than an overflow panic (debug) or a wrapped bound (release).
    let body_start = match index.checked_add(LENGTH_BYTES) {
        Some(start) if start <= ssz_bytes.len() => start,
        _ => return Err(DecodeError::TooShort),
    };

    // Size in bytes of the list body that follows the prefix.
    let serialized_length = decode_length(ssz_bytes, index, LENGTH_BYTES)?;

    let final_len = match body_start.checked_add(serialized_length) {
        Some(end) if end <= ssz_bytes.len() => end,
        _ => return Err(DecodeError::TooShort),
    };

    let mut tmp_index = body_start;
    let mut res_vec: Vec<T> = Vec::new();

    while tmp_index < final_len {
        let (item, next_index) = T::ssz_decode(ssz_bytes, tmp_index)?;
        // A decoder that does not advance would loop forever; one that jumps
        // past `final_len` has consumed bytes belonging to whatever follows
        // the list, while we would still report `final_len` as the next index.
        if next_index <= tmp_index || next_index > final_len {
            return Err(DecodeError::Invalid);
        }
        res_vec.push(item);
        tmp_index = next_index;
    }

    Ok((res_vec, final_len))
}

/// Interpret the `length_bytes` bytes of `bytes` starting at `index` as a
/// little-endian unsigned integer and return it as a `usize`.
///
/// List length prefixes are read with `length_bytes == LENGTH_BYTES`.
pub fn decode_length( bytes: &[u8], index: usize, length_bytes: usize, ) -> Result { if bytes.len() < index + length_bytes { return Err(DecodeError::TooShort); }; let mut len: usize = 0; for (i, byte) in bytes .iter() .enumerate() .take(index + length_bytes) .skip(index) { let offset = (length_bytes - (length_bytes - (i - index))) * 8; len |= (*byte as usize) << offset; } Ok(len) } #[cfg(test)] mod tests { use super::super::encode::*; use super::*; #[test] fn test_ssz_decode_length() { let decoded = decode_length(&vec![1, 0, 0, 0], 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 1); let decoded = decode_length(&vec![0, 1, 0, 0], 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 256); let decoded = decode_length(&vec![255, 1, 0, 0], 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 511); let decoded = decode_length(&vec![255, 255, 255, 255], 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 4294967295); } #[test] fn test_encode_decode_length() { let params: Vec = vec![ 0, 1, 2, 3, 7, 8, 16, 2 ^ 8, 2 ^ 8 + 1, 2 ^ 16, 2 ^ 16 + 1, 2 ^ 24, 2 ^ 24 + 1, 2 ^ 32, ]; for i in params { let decoded = decode_length(&encode_length(i, LENGTH_BYTES), 0, LENGTH_BYTES).unwrap(); assert_eq!(i, decoded); } } #[test] fn test_encode_decode_ssz_list() { let test_vec: Vec = vec![256; 12]; let mut stream = SszStream::new(); stream.append_vec(&test_vec); let ssz = stream.drain(); // u16 let decoded: (Vec, usize) = decode_ssz_list(&ssz, 0).unwrap(); assert_eq!(decoded.0, test_vec); assert_eq!(decoded.1, LENGTH_BYTES + (12 * 2)); } #[test] fn test_decode_ssz_list() { // u16 let v: Vec = vec![10, 10, 10, 10]; let decoded: (Vec, usize) = decode_ssz_list(&vec![8, 0, 0, 0, 10, 0, 10, 0, 10, 0, 10, 0], 0).unwrap(); assert_eq!(decoded.0, v); assert_eq!(decoded.1, LENGTH_BYTES + (4 * 2)); // u32 let v: Vec = vec![10, 10, 10, 10]; let decoded: (Vec, usize) = decode_ssz_list( &vec![ 16, 0, 0, 0, 10, 0, 0, 0, 10, 0, 0, 0, 10, 0, 0, 0, 10, 0, 0, 00, ], 0, ) .unwrap(); assert_eq!(decoded.0, v); 
assert_eq!(decoded.1, 20); // u64 let v: Vec = vec![10, 10, 10, 10]; let decoded: (Vec, usize) = decode_ssz_list( &vec![ 32, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, ], 0, ) .unwrap(); assert_eq!(decoded.0, v); assert_eq!(decoded.1, LENGTH_BYTES + (8 * 4)); // Check that it can accept index let v: Vec = vec![15, 15, 15, 15]; let offset = 10; let decoded: (Vec, usize) = decode_ssz_list( &vec![ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 32, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, ], offset, ) .unwrap(); assert_eq!(decoded.0, v); assert_eq!(decoded.1, offset + LENGTH_BYTES + (8 * 4)); // Check that length > bytes throws error let decoded: Result<(Vec, usize), DecodeError> = decode_ssz_list(&vec![32, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0], 0); assert_eq!(decoded, Err(DecodeError::TooShort)); // Check that incorrect index throws error let decoded: Result<(Vec, usize), DecodeError> = decode_ssz_list(&vec![15, 0, 0, 0, 0, 0, 0, 0], 16); assert_eq!(decoded, Err(DecodeError::TooShort)); } }