diff --git a/Cargo.lock b/Cargo.lock index e7156ecd9..5d34cfad5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4668,6 +4668,15 @@ dependencies = [ "url 2.1.1", ] +[[package]] +name = "serde_utils" +version = "0.1.0" +dependencies = [ + "serde", + "serde_derive", + "serde_json", +] + [[package]] name = "serde_yaml" version = "0.8.13" @@ -5886,6 +5895,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", + "serde_utils", "serde_yaml", "slog", "swap_or_not_shuffle", diff --git a/Cargo.toml b/Cargo.toml index 59bf507fa..fec48a134 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ members = [ "consensus/ssz_derive", "consensus/ssz_types", "consensus/serde_hex", + "consensus/serde_utils", "consensus/state_processing", "consensus/swap_or_not_shuffle", "consensus/tree_hash", diff --git a/consensus/serde_utils/Cargo.toml b/consensus/serde_utils/Cargo.toml new file mode 100644 index 000000000..1fb35736b --- /dev/null +++ b/consensus/serde_utils/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "serde_utils" +version = "0.1.0" +authors = ["Paul Hauner "] +edition = "2018" + +[dependencies] +serde = { version = "1.0.110", features = ["derive"] } +serde_derive = "1.0.110" + +[dev-dependencies] +serde_json = "1.0.52" diff --git a/consensus/serde_utils/src/lib.rs b/consensus/serde_utils/src/lib.rs new file mode 100644 index 000000000..df2b44b62 --- /dev/null +++ b/consensus/serde_utils/src/lib.rs @@ -0,0 +1,2 @@ +pub mod quoted_u64; +pub mod quoted_u64_vec; diff --git a/consensus/serde_utils/src/quoted_u64.rs b/consensus/serde_utils/src/quoted_u64.rs new file mode 100644 index 000000000..2e73a104f --- /dev/null +++ b/consensus/serde_utils/src/quoted_u64.rs @@ -0,0 +1,115 @@ +use serde::{Deserializer, Serializer}; +use serde_derive::{Deserialize, Serialize}; +use std::marker::PhantomData; + +/// Serde support for deserializing quoted integers. +/// +/// Configurable so that quotes are either required or optional. +pub struct QuotedIntVisitor { + require_quotes: bool, + _phantom: PhantomData, +} + +impl<'a, T> serde::de::Visitor<'a> for QuotedIntVisitor +where + T: From + Into + Copy, +{ + type Value = T; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.require_quotes { + write!(formatter, "a quoted integer") + } else { + write!(formatter, "a quoted or unquoted integer") + } + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + s.parse::() + .map(T::from) + .map_err(serde::de::Error::custom) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + if self.require_quotes { + Err(serde::de::Error::custom( + "received unquoted integer when quotes are required", + )) + } else { + Ok(T::from(v)) + } + } +} + +/// Wrapper type for requiring quotes on a `u64`-like type. +/// +/// Unlike using `serde(with = "quoted_u64::require_quotes")` this is composable, and can be nested +/// inside types like `Option`, `Result` and `Vec`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)] +#[serde(transparent)] +pub struct Quoted +where + T: From + Into + Copy, +{ + #[serde(with = "require_quotes")] + pub value: T, +} + +/// Serialize with quotes. +pub fn serialize(value: &T, serializer: S) -> Result +where + S: Serializer, + T: From + Into + Copy, +{ + let v: u64 = (*value).into(); + serializer.serialize_str(&format!("{}", v)) +} + +/// Deserialize with or without quotes. +pub fn deserialize<'de, D, T>(deserializer: D) -> Result +where + D: Deserializer<'de>, + T: From + Into + Copy, +{ + deserializer.deserialize_any(QuotedIntVisitor { + require_quotes: false, + _phantom: PhantomData, + }) +} + +/// Requires quotes when deserializing. +/// +/// Usage: `#[serde(with = "quoted_u64::require_quotes")]`. +pub mod require_quotes { + pub use super::serialize; + use super::*; + + pub fn deserialize<'de, D, T>(deserializer: D) -> Result + where + D: Deserializer<'de>, + T: From + Into + Copy, + { + deserializer.deserialize_any(QuotedIntVisitor { + require_quotes: true, + _phantom: PhantomData, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn require_quotes() { + let x = serde_json::from_str::>("\"8\"").unwrap(); + assert_eq!(x.value, 8); + serde_json::from_str::>("8").unwrap_err(); + } +} diff --git a/consensus/serde_utils/src/quoted_u64_vec.rs b/consensus/serde_utils/src/quoted_u64_vec.rs new file mode 100644 index 000000000..c5badee50 --- /dev/null +++ b/consensus/serde_utils/src/quoted_u64_vec.rs @@ -0,0 +1,91 @@ +use serde::ser::SerializeSeq; +use serde::{Deserializer, Serializer}; +use serde_derive::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +#[serde(transparent)] +pub struct QuotedIntWrapper { + #[serde(with = "crate::quoted_u64")] + int: u64, +} + +pub struct QuotedIntVecVisitor; +impl<'a> serde::de::Visitor<'a> for QuotedIntVecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a list of quoted or unquoted integers") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'a>, + { + let mut vec = vec![]; + + while let Some(val) = seq.next_element()? { + let val: QuotedIntWrapper = val; + vec.push(val.int); + } + + Ok(vec) + } +} + +pub fn serialize(value: &[u64], serializer: S) -> Result +where + S: Serializer, +{ + let mut seq = serializer.serialize_seq(Some(value.len()))?; + for &int in value { + seq.serialize_element(&QuotedIntWrapper { int })?; + } + seq.end() +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserializer.deserialize_any(QuotedIntVecVisitor) +} + +#[cfg(test)] +mod test { + use super::*; + + #[derive(Debug, Serialize, Deserialize)] + struct Obj { + #[serde(with = "crate::quoted_u64_vec")] + values: Vec, + } + + #[test] + fn quoted_list_success() { + let obj: Obj = serde_json::from_str(r#"{ "values": ["1", "2", "3", "4"] }"#).unwrap(); + assert_eq!(obj.values, vec![1, 2, 3, 4]); + } + + #[test] + fn unquoted_list_success() { + let obj: Obj = serde_json::from_str(r#"{ "values": [1, 2, 3, 4] }"#).unwrap(); + assert_eq!(obj.values, vec![1, 2, 3, 4]); + } + + #[test] + fn mixed_list_success() { + let obj: Obj = serde_json::from_str(r#"{ "values": ["1", 2, "3", "4"] }"#).unwrap(); + assert_eq!(obj.values, vec![1, 2, 3, 4]); + } + + #[test] + fn empty_list_success() { + let obj: Obj = serde_json::from_str(r#"{ "values": [] }"#).unwrap(); + assert!(obj.values.is_empty()); + } + + #[test] + fn whole_list_quoted_err() { + serde_json::from_str::(r#"{ "values": "[1, 2, 3, 4]" }"#).unwrap_err(); + } +} diff --git a/consensus/types/Cargo.toml b/consensus/types/Cargo.toml index d893ff3ad..6334eaf4a 100644 --- a/consensus/types/Cargo.toml +++ b/consensus/types/Cargo.toml @@ -25,6 +25,7 @@ rand = "0.7.3" safe_arith = { path = "../safe_arith" } serde = "1.0.110" serde_derive = "1.0.110" +serde_utils = { path = "../serde_utils" } slog = "2.5.2" eth2_ssz = "0.1.2" eth2_ssz_derive = "0.1.0" diff --git a/consensus/types/src/slot_epoch.rs b/consensus/types/src/slot_epoch.rs index 42b922cfb..cec17b09a 100644 --- a/consensus/types/src/slot_epoch.rs +++ b/consensus/types/src/slot_epoch.rs @@ -25,11 +25,12 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, Sub, SubAssi #[cfg_attr(feature = "arbitrary-fuzz", derive(arbitrary::Arbitrary))] #[derive(Eq, Clone, Copy, Default, Serialize, Deserialize)] #[serde(transparent)] -pub struct Slot(u64); +pub struct Slot(#[serde(with = "serde_utils::quoted_u64")] u64); #[cfg_attr(feature = "arbitrary-fuzz", derive(arbitrary::Arbitrary))] #[derive(Eq, Clone, Copy, Default, Serialize, Deserialize)] -pub struct Epoch(u64); +#[serde(transparent)] +pub struct Epoch(#[serde(with = "serde_utils::quoted_u64")] u64); impl_common!(Slot); impl_common!(Epoch); diff --git a/consensus/types/src/utils.rs b/consensus/types/src/utils.rs index 51af86692..a527fc18f 100644 --- a/consensus/types/src/utils.rs +++ b/consensus/types/src/utils.rs @@ -1,3 +1,3 @@ mod serde_utils; -pub use serde_utils::*; +pub use self::serde_utils::*;