a3704b971e
## Issue Addressed

- Resolves #1766

## Proposed Changes

- Use the `warp::filters::cors` filter instead of our workaround.

## Additional Info

It's not trivial to enable/disable CORS using `warp`, since using `routes.with(cors)` changes the type of `routes`. This makes it difficult to decide at runtime whether or not to apply CORS. My solution is to *always* use the `warp::filters::cors` wrapper, but when CORS should be disabled, pass the HTTP server's listen address as the only permissible origin.
77 lines
2.4 KiB
Rust
77 lines
2.4 KiB
Rust
use std::net::Ipv4Addr;
|
|
use warp::filters::cors::Builder;
|
|
|
|
/// Configure a `cors::Builder`.
|
|
///
|
|
/// If `allow_origin.is_none()` the `default_origin` is used.
|
|
pub fn set_builder_origins(
|
|
builder: Builder,
|
|
allow_origin: Option<&str>,
|
|
default_origin: (Ipv4Addr, u16),
|
|
) -> Result<Builder, String> {
|
|
if let Some(allow_origin) = allow_origin {
|
|
let origins = allow_origin
|
|
.split(',')
|
|
.map(|s| verify_cors_origin_str(s).map(|_| s))
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
Ok(builder.allow_origins(origins))
|
|
} else {
|
|
let origin = format!("http://{}:{}", default_origin.0, default_origin.1);
|
|
verify_cors_origin_str(&origin)?;
|
|
|
|
Ok(builder.allow_origin(origin.as_str()))
|
|
}
|
|
}
|
|
|
|
/// Verify that `s` can be used as a CORS origin.
|
|
///
|
|
/// ## Notes
|
|
///
|
|
/// We need this function since `warp` will panic if provided an invalid origin. The verification
|
|
/// code is taken from here:
|
|
///
|
|
/// https://github.com/seanmonstar/warp/blob/3d1760c6ca35ce2d03dee0562259d0320e9face3/src/filters/cors.rs#L616
|
|
///
|
|
/// Ideally we should make a PR to `warp` to expose this behaviour, however we defer this for a
|
|
/// later time. The impact of a false-positive on this function is fairly limited, since only
|
|
/// trusted users should be setting CORS origins.
|
|
fn verify_cors_origin_str(s: &str) -> Result<(), String> {
|
|
// Always the wildcard origin.
|
|
if s == "*" {
|
|
return Ok(());
|
|
}
|
|
|
|
let mut parts = s.splitn(2, "://");
|
|
let scheme = parts
|
|
.next()
|
|
.ok_or_else(|| format!("{} is missing a scheme", s))?;
|
|
let rest = parts
|
|
.next()
|
|
.ok_or_else(|| format!("{} is missing the part following the scheme", s))?;
|
|
|
|
headers::Origin::try_from_parts(scheme, rest, None)
|
|
.map_err(|e| format!("Unable to parse {}: {}", s, e))
|
|
.map(|_| ())
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;

    /// Origins that must pass verification: the wildcard plus IP/hostname origins,
    /// with and without an explicit port.
    #[test]
    fn valid_origins() {
        let accepted = [
            "*",
            "http://127.0.0.1",
            "http://localhost",
            "http://127.0.0.1:8000",
            "http://localhost:8000",
        ];
        for origin in accepted.iter() {
            verify_cors_origin_str(origin).unwrap();
        }
    }

    /// Origins that must be rejected: a non-wildcard pattern and bare hosts lacking a scheme.
    #[test]
    fn invalid_origins() {
        let rejected = [".*", "127.0.0.1", "localhost"];
        for origin in rejected.iter() {
            verify_cors_origin_str(origin).unwrap_err();
        }
    }
}
|