diff --git a/Cargo.lock b/Cargo.lock index bdb822187..4d487ae70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5266,6 +5266,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_array_query" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89c6e82b1005b33d5b2bbc47096800e5ad6b67ef5636f9c13ad29a6935734a7" +dependencies = [ + "serde", + "serde_urlencoded", +] + [[package]] name = "serde_cbor" version = "0.11.2" @@ -6823,6 +6833,7 @@ dependencies = [ "lighthouse_metrics", "safe_arith", "serde", + "serde_array_query", "state_processing", "tokio", "types", diff --git a/beacon_node/http_api/src/lib.rs b/beacon_node/http_api/src/lib.rs index 85c464466..b0907a30c 100644 --- a/beacon_node/http_api/src/lib.rs +++ b/beacon_node/http_api/src/lib.rs @@ -55,7 +55,10 @@ use warp::http::StatusCode; use warp::sse::Event; use warp::Reply; use warp::{http::Response, Filter}; -use warp_utils::task::{blocking_json_task, blocking_task}; +use warp_utils::{ + query::multi_key_query, + task::{blocking_json_task, blocking_task}, +}; const API_PREFIX: &str = "eth"; @@ -505,12 +508,13 @@ pub fn serve( .clone() .and(warp::path("validator_balances")) .and(warp::path::end()) - .and(warp::query::()) + .and(multi_key_query::()) .and_then( |state_id: StateId, chain: Arc>, - query: api_types::ValidatorBalancesQuery| { + query_res: Result| { blocking_json_task(move || { + let query = query_res?; state_id .map_state(&chain, |state| { Ok(state @@ -521,7 +525,7 @@ pub fn serve( // filter by validator id(s) if provided .filter(|(index, (validator, _))| { query.id.as_ref().map_or(true, |ids| { - ids.0.iter().any(|id| match id { + ids.iter().any(|id| match id { ValidatorId::PublicKey(pubkey) => { &validator.pubkey == pubkey } @@ -548,11 +552,14 @@ pub fn serve( let get_beacon_state_validators = beacon_states_path .clone() .and(warp::path("validators")) - .and(warp::query::()) .and(warp::path::end()) + .and(multi_key_query::()) .and_then( - |state_id: StateId, chain: Arc>, query: api_types::ValidatorsQuery| { + |state_id: StateId, + chain: Arc>, + query_res: Result| { blocking_json_task(move || { + let query = query_res?; state_id .map_state(&chain, |state| { let epoch = state.current_epoch(); @@ -566,7 +573,7 @@ pub fn serve( // filter by validator id(s) if provided .filter(|(index, (validator, _))| { query.id.as_ref().map_or(true, |ids| { - ids.0.iter().any(|id| match id { + ids.iter().any(|id| match id { ValidatorId::PublicKey(pubkey) => { &validator.pubkey == pubkey } @@ -586,8 +593,8 @@ pub fn serve( let status_matches = query.status.as_ref().map_or(true, |statuses| { - statuses.0.contains(&status) - || statuses.0.contains(&status.superstatus()) + statuses.contains(&status) + || statuses.contains(&status.superstatus()) }); if status_matches { @@ -1721,11 +1728,13 @@ pub fn serve( .and(warp::path("node")) .and(warp::path("peers")) .and(warp::path::end()) - .and(warp::query::()) + .and(multi_key_query::()) .and(network_globals.clone()) .and_then( - |query: api_types::PeersQuery, network_globals: Arc>| { + |query_res: Result, + network_globals: Arc>| { blocking_json_task(move || { + let query = query_res?; let mut peers: Vec = Vec::new(); network_globals .peers @@ -1755,11 +1764,11 @@ pub fn serve( ); let state_matches = query.state.as_ref().map_or(true, |states| { - states.0.iter().any(|state_param| *state_param == state) + states.iter().any(|state_param| *state_param == state) }); let direction_matches = query.direction.as_ref().map_or(true, |directions| { - directions.0.iter().any(|dir_param| *dir_param == direction) + directions.iter().any(|dir_param| *dir_param == direction) }); if state_matches && direction_matches { @@ -2534,16 +2543,18 @@ pub fn serve( let get_events = eth1_v1 .and(warp::path("events")) .and(warp::path::end()) - .and(warp::query::()) + .and(multi_key_query::()) .and(chain_filter) .and_then( - |topics: api_types::EventQuery, chain: Arc>| { + |topics_res: Result, + chain: Arc>| { blocking_task(move || { + let topics = topics_res?; // for each topic subscribed spawn a new subscription - let mut receivers = Vec::with_capacity(topics.topics.0.len()); + let mut receivers = Vec::with_capacity(topics.topics.len()); if let Some(event_handler) = chain.event_handler.as_ref() { - for topic in topics.topics.0.clone() { + for topic in topics.topics { let receiver = match topic { api_types::EventTopic::Head => event_handler.subscribe_head(), api_types::EventTopic::Block => event_handler.subscribe_block(), @@ -2606,8 +2617,8 @@ pub fn serve( .or(get_beacon_state_fork.boxed()) .or(get_beacon_state_finality_checkpoints.boxed()) .or(get_beacon_state_validator_balances.boxed()) - .or(get_beacon_state_validators.boxed()) .or(get_beacon_state_validators_id.boxed()) + .or(get_beacon_state_validators.boxed()) .or(get_beacon_state_committees.boxed()) .or(get_beacon_state_sync_committees.boxed()) .or(get_beacon_headers.boxed()) diff --git a/common/eth2/src/types.rs b/common/eth2/src/types.rs index be65dd877..169a8de59 100644 --- a/common/eth2/src/types.rs +++ b/common/eth2/src/types.rs @@ -428,10 +428,13 @@ pub struct AttestationPoolQuery { pub committee_index: Option, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] pub struct ValidatorsQuery { - pub id: Option>, - pub status: Option>, + #[serde(default, deserialize_with = "option_query_vec")] + pub id: Option>, + #[serde(default, deserialize_with = "option_query_vec")] + pub status: Option>, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -520,27 +523,68 @@ pub struct SyncingData { #[derive(Clone, PartialEq, Debug, Deserialize)] #[serde(try_from = "String", bound = "T: FromStr")] -pub struct QueryVec(pub Vec); +pub struct QueryVec { + values: Vec, +} + +fn query_vec<'de, D, T>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, + T: FromStr, +{ + let vec: Vec> = Deserialize::deserialize(deserializer)?; + Ok(Vec::from(QueryVec::from(vec))) +} + +fn option_query_vec<'de, D, T>(deserializer: D) -> Result>, D::Error> +where + D: serde::Deserializer<'de>, + T: FromStr, +{ + let vec: Vec> = Deserialize::deserialize(deserializer)?; + if vec.is_empty() { + return Ok(None); + } + + Ok(Some(Vec::from(QueryVec::from(vec)))) +} + +impl From>> for QueryVec { + fn from(vecs: Vec>) -> Self { + Self { + values: vecs.into_iter().flat_map(|qv| qv.values).collect(), + } + } +} impl TryFrom for QueryVec { type Error = String; fn try_from(string: String) -> Result { if string.is_empty() { - return Ok(Self(vec![])); + return Ok(Self { values: vec![] }); } - string - .split(',') - .map(|s| s.parse().map_err(|_| "unable to parse".to_string())) - .collect::, String>>() - .map(Self) + Ok(Self { + values: string + .split(',') + .map(|s| s.parse().map_err(|_| "unable to parse query".to_string())) + .collect::, String>>()?, + }) + } +} + +impl From> for Vec { + fn from(vec: QueryVec) -> Vec { + vec.values } } #[derive(Clone, Deserialize)] +#[serde(deny_unknown_fields)] pub struct ValidatorBalancesQuery { - pub id: Option>, + #[serde(default, deserialize_with = "option_query_vec")] + pub id: Option>, } #[derive(Clone, Serialize, Deserialize)] @@ -602,9 +646,12 @@ pub struct BeaconCommitteeSubscription { } #[derive(Deserialize)] +#[serde(deny_unknown_fields)] pub struct PeersQuery { - pub state: Option>, - pub direction: Option>, + #[serde(default, deserialize_with = "option_query_vec")] + pub state: Option>, + #[serde(default, deserialize_with = "option_query_vec")] + pub direction: Option>, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -858,8 +905,10 @@ impl EventKind { } #[derive(Clone, Deserialize)] +#[serde(deny_unknown_fields)] pub struct EventQuery { - pub topics: QueryVec, + #[serde(deserialize_with = "query_vec")] + pub topics: Vec, } #[derive(Debug, Clone, Copy, PartialEq, Deserialize)] @@ -961,7 +1010,9 @@ mod tests { fn query_vec() { assert_eq!( QueryVec::try_from("0,1,2".to_string()).unwrap(), - QueryVec(vec![0_u64, 1, 2]) + QueryVec { + values: vec![0_u64, 1, 2] + } ); } } diff --git a/common/warp_utils/Cargo.toml b/common/warp_utils/Cargo.toml index f99d7773b..09b6f125f 100644 --- a/common/warp_utils/Cargo.toml +++ b/common/warp_utils/Cargo.toml @@ -18,3 +18,4 @@ tokio = { version = "1.14.0", features = ["sync"] } headers = "0.3.2" lighthouse_metrics = { path = "../lighthouse_metrics" } lazy_static = "1.4.0" +serde_array_query = "0.1.0" diff --git a/common/warp_utils/src/lib.rs b/common/warp_utils/src/lib.rs index 5f37dde87..346361b18 100644 --- a/common/warp_utils/src/lib.rs +++ b/common/warp_utils/src/lib.rs @@ -3,5 +3,6 @@ pub mod cors; pub mod metrics; +pub mod query; pub mod reject; pub mod task; diff --git a/common/warp_utils/src/query.rs b/common/warp_utils/src/query.rs new file mode 100644 index 000000000..c5ed5c5f1 --- /dev/null +++ b/common/warp_utils/src/query.rs @@ -0,0 +1,22 @@ +use crate::reject::custom_bad_request; +use serde::Deserialize; +use warp::Filter; + +// Custom query filter using `serde_array_query`. +// This allows duplicate keys inside query strings. +pub fn multi_key_query<'de, T: Deserialize<'de>>( +) -> impl warp::Filter,), Error = std::convert::Infallible> + Copy +{ + raw_query().then(|query_str: String| async move { + serde_array_query::from_str(&query_str).map_err(|e| custom_bad_request(e.to_string())) + }) +} + +// This ensures that empty query strings are still accepted. +// This is because warp::filters::query::raw() does not allow empty query strings +// but warp::query::() does. +fn raw_query() -> impl Filter + Copy { + warp::filters::query::raw() + .or(warp::any().map(String::default)) + .unify() +}