diff --git a/consensus/serde_utils/src/u256_hex_be_opt.rs b/consensus/serde_utils/src/u256_hex_be_opt.rs index 8eadbf024..e53cb5c69 100644 --- a/consensus/serde_utils/src/u256_hex_be_opt.rs +++ b/consensus/serde_utils/src/u256_hex_be_opt.rs @@ -1,6 +1,6 @@ use ethereum_types::U256; -use serde::de::Visitor; +use serde::de::{Error, Visitor}; use serde::{de, Deserializer, Serialize, Serializer}; use std::fmt; use std::str::FromStr; @@ -15,12 +15,26 @@ where pub struct U256Visitor; impl<'de> Visitor<'de> for U256Visitor { - type Value = String; + type Value = Option; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a well formatted hex string") } + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_string(U256Visitor) + } + + fn visit_none(self) -> Result + where + E: Error, + { + Ok(None) + } + fn visit_str(self, value: &str) -> Result where E: de::Error, @@ -35,11 +49,11 @@ impl<'de> Visitor<'de> for U256Visitor { stripped ))) } else if stripped == "0" { - Ok(value.to_string()) + Ok(Some(value.to_string())) } else if stripped.starts_with('0') { Err(de::Error::custom("cannot have leading zero")) } else { - Ok(value.to_string()) + Ok(Some(value.to_string())) } } } @@ -48,13 +62,14 @@ pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - let decoded = deserializer.deserialize_string(U256Visitor)?; + let decoded = deserializer.deserialize_option(U256Visitor)?; - Some( - U256::from_str(&decoded) - .map_err(|e| de::Error::custom(format!("Invalid U256 string: {}", e))), - ) - .transpose() + decoded + .map(|decoded| { + U256::from_str(&decoded) + .map_err(|e| de::Error::custom(format!("Invalid U256 string: {}", e))) + }) + .transpose() } #[cfg(test)] @@ -161,6 +176,10 @@ mod test { val: Some(U256::max_value()) }, ); + assert_eq!( + serde_json::from_str::("null").unwrap(), + Wrapper { val: None }, + ); serde_json::from_str::("\"0x\"").unwrap_err(); serde_json::from_str::("\"0x0400\"").unwrap_err(); serde_json::from_str::("\"400\"").unwrap_err();