From 117081cddd4cc56c6e70c38b59d718c168ec8116 Mon Sep 17 00:00:00 2001 From: Isvane <277444536+Isvane@users.noreply.github.com> Date: Fri, 29 May 2026 21:26:25 +0700 Subject: [PATCH 1/3] docs: fix typos in ServeDir, ServeFile, and conditional headers (#700) --- tower-http/src/services/fs/serve_dir/headers.rs | 4 ++-- tower-http/src/services/fs/serve_dir/mod.rs | 4 ++-- tower-http/src/services/fs/serve_dir/tests.rs | 2 +- tower-http/src/services/fs/serve_file.rs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tower-http/src/services/fs/serve_dir/headers.rs b/tower-http/src/services/fs/serve_dir/headers.rs index e9e80907..e3a87a2c 100644 --- a/tower-http/src/services/fs/serve_dir/headers.rs +++ b/tower-http/src/services/fs/serve_dir/headers.rs @@ -18,7 +18,7 @@ impl IfModifiedSince { self.0 < last_modified.0 } - /// convert a header value into a IfModifiedSince, invalid values are silentely ignored + /// Convert a header value into a IfModifiedSince. Invalid values are silently ignored pub(super) fn from_header_value(value: &HeaderValue) -> Option { std::str::from_utf8(value.as_bytes()) .ok() @@ -35,7 +35,7 @@ impl IfUnmodifiedSince { self.0 >= last_modified.0 } - /// Convert a header value into a IfModifiedSince, invalid values are silentely ignored + /// Convert a header value into a IfUnmodifiedSince. Invalid values are silently ignored pub(super) fn from_header_value(value: &HeaderValue) -> Option { std::str::from_utf8(value.as_bytes()) .ok() diff --git a/tower-http/src/services/fs/serve_dir/mod.rs b/tower-http/src/services/fs/serve_dir/mod.rs index b243202c..bab51014 100644 --- a/tower-http/src/services/fs/serve_dir/mod.rs +++ b/tower-http/src/services/fs/serve_dir/mod.rs @@ -54,7 +54,7 @@ pub struct ServeDir { base: PathBuf, buf_chunk_size: usize, precompressed_variants: Option, - // This is used to specialise implementation for + // This is used to specialize implementation for // single files variant: ServeVariant, fallback: Option, @@ -296,7 +296,7 @@ impl ServeDir { /// let mut service = ServeDir::new("assets"); /// /// // You only need to worry about backpressure, and thus call `ServiceExt::ready`, if - /// // your adding a fallback to `ServeDir` that cares about backpressure. + /// // you are adding a fallback to `ServeDir` that cares about backpressure. /// // /// // Its shown here for demonstration but you can do `service.try_call(request)` /// // otherwise diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index 3dd9a086..0b9a6c78 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -75,7 +75,7 @@ async fn head_request() { } #[tokio::test] -async fn precompresed_head_request() { +async fn precompressed_head_request() { let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let req = Request::builder() diff --git a/tower-http/src/services/fs/serve_file.rs b/tower-http/src/services/fs/serve_file.rs index d3b5e2f0..efd216c5 100644 --- a/tower-http/src/services/fs/serve_file.rs +++ b/tower-http/src/services/fs/serve_file.rs @@ -208,7 +208,7 @@ mod tests { } #[tokio::test] - async fn precompresed_head_request() { + async fn precompressed_head_request() { let svc = ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")).precompressed_gzip(); From dd5b9289762ee368776cdbf7561fa0e57afe9379 Mon Sep 17 00:00:00 2001 From: Jess Izen <44884346+jlizen@users.noreply.github.com> Date: Fri, 29 May 2026 15:34:10 -0700 Subject: [PATCH 2/3] =?UTF-8?q?Fix:=20Handle=20wildcard=20*=20in=20Accept-?= =?UTF-8?q?Encoding=20per=20RFC=209110=20=C2=A712.5.3=20(#693)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix!: handle wildcard * in Accept-Encoding per RFC 9110 §12.5.3 * fix: replace let...else with match for MSRV compatibility --- tower-http/CHANGELOG.md | 7 + tower-http/src/compression/future.rs | 33 ++- tower-http/src/compression/mod.rs | 89 +++++++ tower-http/src/content_encoding.rs | 343 ++++++++++++++++++++++++--- 4 files changed, 433 insertions(+), 39 deletions(-) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index d59ab922..2f31ac24 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -19,6 +19,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `DefaultOnEos` now explicitly parent their tracing events to the request span rather than relying on the ambient span context. This fixes intermittent cases where events could appear without their request span attached ([#655]) +- **breaking:** `compression`: the middleware now handles the `*` wildcard and + `identity;q=0` in Accept-Encoding per RFC 9110 §12.5.3. Requests that + previously fell back to identity (e.g. `*;q=0` or `identity;q=0` with no + other acceptable encoding) now receive a 406 Not Acceptable response. Clients + that explicitly reject all encodings without listing an alternative will see + different behavior. ([#215]) - The implicit `tokio` and `async-compression` features are removed (BREAKING). These were kept as no-op features in 0.6.x for backwards compatibility after the switch to `dep:` syntax in [#642]. Downstream crates that activate @@ -27,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 by the features that need them (e.g. `compression-gzip`, `fs`, `timeout`). ([#628]) +[#215]: https://github.com/tower-rs/tower-http/issues/215 [#628]: https://github.com/tower-rs/tower-http/pull/628 [#642]: https://github.com/tower-rs/tower-http/pull/642 diff --git a/tower-http/src/compression/future.rs b/tower-http/src/compression/future.rs index 3e899a73..e19ecd6f 100644 --- a/tower-http/src/compression/future.rs +++ b/tower-http/src/compression/future.rs @@ -22,7 +22,7 @@ pin_project! { pub struct ResponseFuture { #[pin] pub(crate) inner: F, - pub(crate) encoding: Encoding, + pub(crate) encoding: Option, pub(crate) predicate: P, pub(crate) quality: CompressionLevel, } @@ -39,6 +39,33 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = ready!(self.as_mut().project().inner.poll(cx)?); + let encoding = match self.encoding { + Some(enc) => enc, + None => { + // RFC 9110 §12.5.3: the server SHOULD respond with 406 Not Acceptable + // when no encoding is satisfiable. This middleware chooses to enforce it. + // + // Note: the inner service has already been called, so its response body and + // headers are passed through. Only the status code is overwritten. + let mut res = res; + *res.status_mut() = http::StatusCode::NOT_ACCEPTABLE; + if !res.headers().get_all(header::VARY).iter().any(|value| { + contains_ignore_ascii_case( + value.as_bytes(), + header::ACCEPT_ENCODING.as_str().as_bytes(), + ) + }) { + res.headers_mut() + .append(header::VARY, header::ACCEPT_ENCODING.into()); + } + let (parts, body) = res.into_parts(); + return Poll::Ready(Ok(Response::from_parts( + parts, + CompressionBody::new(BodyInner::identity(body)), + ))); + } + }; + // never recompress responses that are already compressed let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING) // never compress responses that are ranges @@ -60,7 +87,7 @@ where .append(header::VARY, header::ACCEPT_ENCODING.into()); } - let body = match (should_compress, self.encoding) { + let body = match (should_compress, encoding) { // if compression is _not_ supported or the client doesn't accept it (false, _) | (_, Encoding::Identity) => { return Poll::Ready(Ok(Response::from_parts( @@ -114,7 +141,7 @@ where parts .headers - .insert(header::CONTENT_ENCODING, self.encoding.into_header_value()); + .insert(header::CONTENT_ENCODING, encoding.into_header_value()); let res = Response::from_parts(parts, body); Poll::Ready(Ok(res)) diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index c32fc579..044fca51 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -616,4 +616,93 @@ mod tests { let body = res.into_body(); assert_eq!(body.size_hint().exact().unwrap(), msg.len() as u64); } + + #[tokio::test] + async fn wildcard_q_zero_returns_406() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "*;q=0") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + assert_eq!(res.status(), http::StatusCode::NOT_ACCEPTABLE); + assert!(res + .headers() + .get_all(http::header::VARY) + .iter() + .any(|v| v.to_str().unwrap().contains("accept-encoding"))); + } + + #[tokio::test] + async fn wildcard_q_zero_with_gzip_picks_gzip() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "*;q=0,gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + assert_eq!(res.status(), http::StatusCode::OK); + assert_eq!( + res.headers() + .get("content-encoding") + .and_then(|v| v.to_str().ok()), + Some("gzip") + ); + } + + #[tokio::test] + async fn wildcard_alone_compresses() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "*") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + assert_eq!(res.status(), http::StatusCode::OK); + // Should pick the best supported encoding (not identity) + assert!(res.headers().contains_key(CONTENT_ENCODING)); + } + + #[tokio::test] + async fn identity_q_zero_alone_returns_406() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "identity;q=0") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + assert_eq!(res.status(), http::StatusCode::NOT_ACCEPTABLE); + } + + #[tokio::test] + async fn identity_q_zero_with_gzip_picks_gzip() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "identity;q=0,gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + assert_eq!(res.status(), http::StatusCode::OK); + assert_eq!( + res.headers() + .get("content-encoding") + .and_then(|v| v.to_str().ok()), + Some("gzip") + ); + } } diff --git a/tower-http/src/content_encoding.rs b/tower-http/src/content_encoding.rs index 91c21d45..18ccacb1 100644 --- a/tower-http/src/content_encoding.rs +++ b/tower-http/src/content_encoding.rs @@ -96,12 +96,14 @@ impl Encoding { feature = "compression-deflate", ))] // based on https://github.com/http-rs/accept-encoding + // + // Returns `Some(encoding)` for the best acceptable encoding, or `None` if the client's + // preferences cannot be satisfied (406 Not Acceptable per RFC 9110 §12.5.3). pub(crate) fn from_headers( headers: &http::HeaderMap, supported_encoding: impl SupportedEncodings, - ) -> Self { - Encoding::preferred_encoding(encodings(headers, supported_encoding)) - .unwrap_or(Encoding::Identity) + ) -> Option { + preferred_encoding_with_wildcard(headers, supported_encoding) } #[cfg(any( @@ -240,6 +242,135 @@ pub(crate) fn encodings<'a>( }) } +/// Extracts the q-value for the `*` wildcard from Accept-Encoding headers. +/// Returns `None` if no wildcard is present. +#[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", +))] +fn wildcard_qvalue(headers: &http::HeaderMap) -> Option { + headers + .get_all(http::header::ACCEPT_ENCODING) + .iter() + .filter_map(|hval| hval.to_str().ok()) + .flat_map(|s| s.split(',')) + .find_map(|v| { + let mut v = v.splitn(2, ';'); + let coding = v.next().unwrap().trim(); + if coding != "*" { + return None; + } + let qval = if let Some(qval) = v.next() { + QValue::parse(qval.trim())? + } else { + QValue::one() + }; + Some(qval) + }) +} + +/// Selects the preferred encoding considering the `*` wildcard per RFC 9110 §12.5.3. +/// +/// The wildcard applies its q-value to any encoding not explicitly listed. If all acceptable +/// encodings (including identity) are excluded, returns `None` to signal 406 Not Acceptable. +#[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", +))] +fn preferred_encoding_with_wildcard( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, +) -> Option { + let explicit: Vec<(Encoding, QValue)> = encodings(headers, supported_encoding).collect(); + let wildcard_q = wildcard_qvalue(headers); + + // If there is no wildcard, use only the explicitly listed encodings. + // Per RFC 9110 §12.5.3, if identity is excluded (q=0) and no other encoding is + // acceptable, the server SHOULD respond with 406. + let wildcard_q = match wildcard_q { + Some(q) => q, + None => { + let identity_rejected = explicit + .iter() + .any(|(enc, q)| *enc == Encoding::Identity && q.0 == 0); + return match Encoding::preferred_encoding(explicit.into_iter()) { + Some(enc) => Some(enc), + None => { + if identity_rejected { + None + } else { + Some(Encoding::Identity) + } + } + }; + } + }; + + // Build the effective set of (encoding, qvalue) for all supported encodings. + // For each supported encoding, use its explicit q-value if listed, otherwise the wildcard + // q-value. + let all_supported = all_supported_encodings(supported_encoding); + + let effective = all_supported.iter().filter_map(|e| *e).map(|enc| { + let q = explicit + .iter() + .find(|(e, _)| *e == enc) + .map(|(_, q)| *q) + .unwrap_or(wildcard_q); + (enc, q) + }); + + Encoding::preferred_encoding(effective) +} + +/// Returns all encodings the server supports (including Identity) in a fixed-capacity array. +#[cfg(any( + feature = "compression-gzip", + feature = "compression-br", + feature = "compression-zstd", + feature = "compression-deflate", +))] +fn all_supported_encodings(supported_encoding: impl SupportedEncodings) -> [Option; 5] { + let mut out: [Option; 5] = [None; 5]; + let mut n = 0; + + macro_rules! push { + ($enc:expr) => { + out[n] = Some($enc); + n += 1; + }; + } + + push!(Encoding::Identity); + + #[cfg(any(feature = "fs", feature = "compression-gzip"))] + if supported_encoding.gzip() { + push!(Encoding::Gzip); + } + + #[cfg(any(feature = "fs", feature = "compression-deflate"))] + if supported_encoding.deflate() { + push!(Encoding::Deflate); + } + + #[cfg(any(feature = "fs", feature = "compression-br"))] + if supported_encoding.br() { + push!(Encoding::Brotli); + } + + #[cfg(any(feature = "fs", feature = "compression-zstd"))] + if supported_encoding.zstd() { + push!(Encoding::Zstd); + } + + let _ = n; + out +} + #[cfg(all( test, feature = "compression-gzip", @@ -274,7 +405,7 @@ mod tests { #[test] fn no_accept_encoding_header() { let encoding = Encoding::from_headers(&http::HeaderMap::new(), SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); } #[test] @@ -285,7 +416,7 @@ mod tests { http::HeaderValue::from_static("gzip"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Gzip, encoding); + assert_eq!(Some(Encoding::Gzip), encoding); } #[test] @@ -296,7 +427,7 @@ mod tests { http::HeaderValue::from_static("gzip,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -307,7 +438,7 @@ mod tests { http::HeaderValue::from_static("gzip,x-gzip"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Gzip, encoding); + assert_eq!(Some(Encoding::Gzip), encoding); } #[test] @@ -318,7 +449,7 @@ mod tests { http::HeaderValue::from_static("deflate,x-gzip"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Gzip, encoding); + assert_eq!(Some(Encoding::Gzip), encoding); } #[test] @@ -329,7 +460,7 @@ mod tests { http::HeaderValue::from_static("gzip,deflate,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -340,7 +471,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.5,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -351,7 +482,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.5,deflate,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -366,7 +497,7 @@ mod tests { http::HeaderValue::from_static("br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -381,7 +512,7 @@ mod tests { http::HeaderValue::from_static("br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -400,7 +531,7 @@ mod tests { http::HeaderValue::from_static("br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -411,7 +542,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.5,br;q=0.8"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -419,7 +550,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.8,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Gzip, encoding); + assert_eq!(Some(Encoding::Gzip), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -427,7 +558,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.995,br;q=0.999"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -438,7 +569,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.5,deflate;q=0.6,br;q=0.8"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -446,7 +577,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.8,deflate;q=0.6,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Gzip, encoding); + assert_eq!(Some(Encoding::Gzip), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -454,7 +585,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.6,deflate;q=0.8,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Deflate, encoding); + assert_eq!(Some(Encoding::Deflate), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -462,7 +593,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.995,deflate;q=0.997,br;q=0.999"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -473,7 +604,7 @@ mod tests { http::HeaderValue::from_static("invalid,gzip"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Gzip, encoding); + assert_eq!(Some(Encoding::Gzip), encoding); } #[test] @@ -484,7 +615,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -492,7 +623,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0."), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -500,7 +631,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -511,7 +642,7 @@ mod tests { http::HeaderValue::from_static("gZiP"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Gzip, encoding); + assert_eq!(Some(Encoding::Gzip), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -519,7 +650,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.5,br;Q=0.8"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -530,7 +661,7 @@ mod tests { http::HeaderValue::from_static(" gzip\t; q=0.5 ,\tbr ;\tq=0.8\t"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Brotli, encoding); + assert_eq!(Some(Encoding::Brotli), encoding); } #[test] @@ -541,7 +672,7 @@ mod tests { http::HeaderValue::from_static("gzip;q =0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -549,7 +680,7 @@ mod tests { http::HeaderValue::from_static("gzip;q= 0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); } #[test] @@ -560,7 +691,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=-0.1"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -568,7 +699,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=00.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -576,7 +707,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=0.5000"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -584,7 +715,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -592,7 +723,7 @@ mod tests { http::HeaderValue::from_static("gzip;q=1.01"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); let mut headers = http::HeaderMap::new(); headers.append( @@ -600,6 +731,146 @@ mod tests { http::HeaderValue::from_static("gzip;q=1.001"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); - assert_eq!(Encoding::Identity, encoding); + assert_eq!(Some(Encoding::Identity), encoding); + } + + #[test] + fn wildcard_alone_picks_best_supported() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + // * with q=1 means all encodings are acceptable; picks the highest-priority supported + assert_eq!(Some(Encoding::Zstd), encoding); + } + + #[test] + fn wildcard_q_zero_with_nothing_else_returns_not_satisfiable() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*;q=0"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + // *;q=0 rejects everything, including identity + assert_eq!(None, encoding); + } + + #[test] + fn wildcard_q_zero_with_gzip_picks_gzip() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*;q=0,gzip"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Some(Encoding::Gzip), encoding); + } + + #[test] + fn identity_q_zero_alone_returns_not_satisfiable() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("identity;q=0"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + // identity;q=0 with no other encoding explicitly listed: the server cannot + // determine what the client accepts, so 406 per RFC 9110 §12.5.3 + assert_eq!(None, encoding); + } + + #[test] + fn identity_q_zero_with_gzip_picks_gzip() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("identity;q=0,gzip"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + assert_eq!(Some(Encoding::Gzip), encoding); + } + + #[test] + fn wildcard_q_zero_identity_q_zero_no_compression_returns_not_satisfiable() { + // *;q=0,identity;q=0 with no explicit compression listed + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*;q=0,identity;q=0"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + // Both wildcard and identity are q=0, and no explicit encoding is listed with q>0 + assert_eq!(None, encoding); + } + + #[test] + fn wildcard_with_low_qvalue() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*;q=0.5,gzip;q=1"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + // gzip is explicitly q=1, everything else gets q=0.5 from wildcard + assert_eq!(Some(Encoding::Gzip), encoding); + } + + #[test] + fn wildcard_q_zero_with_identity_picks_identity() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*;q=0,identity"), + ); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + // *;q=0 rejects all, but identity is explicitly listed with q=1 + assert_eq!(Some(Encoding::Identity), encoding); + } + + #[derive(Copy, Clone)] + struct SupportedGzipOnly; + + impl SupportedEncodings for SupportedGzipOnly { + fn gzip(&self) -> bool { + true + } + fn deflate(&self) -> bool { + false + } + fn br(&self) -> bool { + false + } + fn zstd(&self) -> bool { + false + } + } + + #[test] + fn wildcard_with_partial_server_support_picks_best_available() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*"), + ); + let encoding = Encoding::from_headers(&headers, SupportedGzipOnly); + // Server only supports gzip, so * should pick gzip (not zstd/br) + assert_eq!(Some(Encoding::Gzip), encoding); + } + + #[test] + fn wildcard_q_zero_with_unsupported_encoding_returns_not_satisfiable() { + let mut headers = http::HeaderMap::new(); + headers.append( + http::header::ACCEPT_ENCODING, + http::HeaderValue::from_static("*;q=0,br"), + ); + let encoding = Encoding::from_headers(&headers, SupportedGzipOnly); + // Client wants br, but server only supports gzip. br is not in the + // supported set so it's ignored by encodings(). Wildcard rejects + // everything else. Result: 406. + assert_eq!(None, encoding); } } From 56d5b6b9e785f36dc9e4b8ad7c80570e53f7bfb5 Mon Sep 17 00:00:00 2001 From: Joern Barthel Date: Sat, 30 May 2026 19:37:23 +0200 Subject: [PATCH 3/3] feat(csrf): add cross-origin protection middleware (#699) * feat(csrf): add cross-origin protection middleware Ports the CSRF protection scheme introduced in Go 1.25 (described in Filippo Valsorda's blog post) as a new optional `csrf` feature. The middleware combines `Sec-Fetch-Site`, an `Origin` allow-list, and an `Origin`/`Host` fallback to reject cross-origin state-changing requests without per-request token state. * fix(csrf): strip query string from trace log URIs * feat(csrf): add #[must_use] to Csrf service * chore(csrf): made example slightly better formatted * test(csrf): exercise Service::call() flow via oneshot * chore(csrf): fix style * refactor(csrf): inline is_exempt as a closure inside verify * feat(csrf): allow customizing the rejection response * docs(csrf): document UriExt and TrustedOrigin * test(csrf): cover custom rejection response on allowed request * test(csrf): explain why the safe-method check avoids Method::is_safe * refactor(csrf): match Go with strict byte-level origin comparison * refactor(csrf): bound with_rejection_response on Clone for clearer errors * refactor(csrf): make ProtectionError an opaque struct over a kind enum * refactor(csrf): attach ProtectionError in the service so custom rejections keep it * fix(csrf): prefer request-target authority over Host header in same-origin check * style(csrf): rustfmt service.rs --- tower-http/Cargo.toml | 2 + tower-http/src/csrf/future.rs | 71 +++ tower-http/src/csrf/layer.rs | 147 +++++ tower-http/src/csrf/mod.rs | 1008 +++++++++++++++++++++++++++++++ tower-http/src/csrf/response.rs | 45 ++ tower-http/src/csrf/service.rs | 216 +++++++ tower-http/src/csrf/url.rs | 160 +++++ tower-http/src/lib.rs | 3 + 8 files changed, 1652 insertions(+) create mode 100644 tower-http/src/csrf/future.rs create mode 100644 tower-http/src/csrf/layer.rs create mode 100644 tower-http/src/csrf/mod.rs create mode 100644 tower-http/src/csrf/response.rs create mode 100644 tower-http/src/csrf/service.rs create mode 100644 tower-http/src/csrf/url.rs diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index e908bd56..632c77f3 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -64,6 +64,7 @@ full = [ "catch-panic", "compression-full", "cors", + "csrf", "decompression-full", "follow-redirect", "fs", @@ -89,6 +90,7 @@ add-extension = [] auth = ["base64", "validate-request"] catch-panic = ["tracing", "futures-util/std", "dep:http-body", "dep:http-body-util"] cors = [] +csrf = [] follow-redirect = ["futures-util", "dep:http-body", "dep:url", "tower/util"] fs = ["dep:tokio", "tokio?/fs", "tokio?/io-util", "futures-core", "futures-util", "dep:http-body", "dep:http-body-util", "tokio-util/io", "dep:http-range-header", "mime_guess", "mime", "httpdate", "set-status", "futures-util/alloc"] limit = ["dep:http-body", "dep:http-body-util"] diff --git a/tower-http/src/csrf/future.rs b/tower-http/src/csrf/future.rs new file mode 100644 index 00000000..bb2bfef4 --- /dev/null +++ b/tower-http/src/csrf/future.rs @@ -0,0 +1,71 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pin_project_lite::pin_project; + +pin_project! { + /// Response future for [`Csrf`]. + /// + /// [`Csrf`]: super::Csrf + pub struct ResponseFuture + where + F: Future, + { + #[pin] + kind: Kind, + } +} + +pin_project! { + #[project = KindProj] + enum Kind + where + F: Future, + { + Future { + #[pin] + future: F, + }, + Rejected { + response: Option, + }, + } +} + +impl ResponseFuture +where + F: Future, +{ + pub(super) fn future(future: F) -> Self { + Self { + kind: Kind::Future { future }, + } + } + + pub(super) fn rejected(response: F::Output) -> Self { + Self { + kind: Kind::Rejected { + response: Some(response), + }, + } + } +} + +impl Future for ResponseFuture +where + F: Future, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().kind.project() { + KindProj::Future { future } => future.poll(cx), + KindProj::Rejected { response } => Poll::Ready( + response + .take() + .expect("ResponseFuture polled after completion"), + ), + } + } +} diff --git a/tower-http/src/csrf/layer.rs b/tower-http/src/csrf/layer.rs new file mode 100644 index 00000000..86b5249b --- /dev/null +++ b/tower-http/src/csrf/layer.rs @@ -0,0 +1,147 @@ +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + +use http::{Method, Uri}; +use tower_layer::Layer; + +use super::service::Csrf; +use super::url::UriExt; +use super::{BypassFn, ConfigError, DebugFn, DefaultResponseForProtectionError, Origins}; + +/// Layer that applies the [`Csrf`] middleware. +/// +/// See the [module docs](crate::csrf) for an example. +#[derive(Clone)] +#[must_use] +pub struct CsrfLayer { + insecure_bypass: Option>, + rejection_response: T, + trusted_origins: Origins, +} + +impl Default for CsrfLayer { + fn default() -> Self { + Self { + insecure_bypass: None, + rejection_response: DefaultResponseForProtectionError, + trusted_origins: Origins::default(), + } + } +} + +impl Debug for CsrfLayer { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("CsrfLayer") + .field( + "insecure_bypass", + &self.insecure_bypass.as_ref().map(|_| DebugFn), + ) + .field("trusted_origins", &self.trusted_origins) + .field("rejection_response", &DebugFn) + .finish() + } +} + +impl CsrfLayer { + /// Creates a new `CsrfLayer` with no trusted origins, no bypass, and the + /// default rejection response. + pub fn new() -> Self { + Self::default() + } +} + +impl CsrfLayer { + /// Adds a trusted origin that allows all requests whose `Origin` header + /// matches the given value. + /// + /// The value is matched **byte-for-byte** against the request's `Origin` + /// header — there is no normalization (this mirrors the Go reference). It + /// must therefore be written exactly as a browser sends it: + /// + /// - form `scheme://host[:port]`, where `scheme` is `http` or `https`; + /// - the host lowercased (browsers lowercase it; IDN hosts must be given in + /// punycode, e.g. `xn--exmple-cua.com`); + /// - **default ports omitted** — browsers drop `:80`/`:443`, so an explicit + /// default port (e.g. `https://example.com:443`) will never match; + /// - **no trailing slash**, path, query, or fragment. + /// + /// Inputs that can't represent a browser `Origin` are rejected with a + /// [`ConfigError`]; inputs that parse but aren't in the canonical browser + /// form above are accepted but will silently never match. + /// + /// ``` + /// # use tower_http::csrf::CsrfLayer; + /// // Matches `Origin: https://example.com`: + /// let layer = CsrfLayer::new().add_trusted_origin("https://example.com")?; + /// + /// // Accepted, but never matches a browser Origin (explicit default port): + /// let layer = CsrfLayer::new().add_trusted_origin("https://example.com:443")?; + /// # Ok::<_, tower_http::csrf::ConfigError>(()) + /// ``` + pub fn add_trusted_origin>(mut self, origin: S) -> Result { + let origin = origin.as_ref(); + + // validate the form; the origin is stored and matched verbatim. + Uri::parse_origin(origin)?; + + #[cfg(feature = "tracing")] + tracing::debug!(origin = %origin, "added trusted origin"); + + self.trusted_origins.insert(origin.to_owned()); + + Ok(self) + } + + /// Adds a bypass predicate that returns `true` for requests which should + /// skip CSRF protection. + /// + /// This is an escape hatch for endpoints that legitimately need to accept + /// cross-origin POSTs (e.g. webhook receivers). Bypassed endpoints must + /// have their own protection (signed payloads, authentication tokens, + /// etc.) — otherwise they are CSRF-vulnerable. + pub fn with_insecure_bypass(mut self, predicate: F) -> Self + where + F: Fn(&Method, &Uri) -> bool + Send + Sync + 'static, + { + #[cfg(feature = "tracing")] + tracing::debug!("added insecure bypass"); + + self.insecure_bypass = Some(Arc::new(predicate)); + self + } + + /// Replaces the response builder used when a request is rejected. + /// + /// Accepts any type that implements [`ResponseForProtectionError`](super::ResponseForProtectionError), + /// including a `FnMut(ProtectionError) -> Response + Clone` closure. + /// The default builder returns a `403 Forbidden` with an empty body. + /// Regardless of the builder, [`Csrf`](super::Csrf) attaches the + /// [`ProtectionError`](super::ProtectionError) to the response's extensions, + /// so a custom builder need not re-attach it. + pub fn with_rejection_response(self, rejection_response: R) -> CsrfLayer + where + R: Clone, + { + CsrfLayer { + insecure_bypass: self.insecure_bypass, + trusted_origins: self.trusted_origins, + rejection_response, + } + } +} + +impl Layer for CsrfLayer +where + T: Clone, +{ + type Service = Csrf; + + fn layer(&self, inner: S) -> Self::Service { + Csrf::new( + inner, + self.insecure_bypass.clone(), + self.rejection_response.clone(), + self.trusted_origins.clone(), + ) + } +} diff --git a/tower-http/src/csrf/mod.rs b/tower-http/src/csrf/mod.rs new file mode 100644 index 00000000..720d52a3 --- /dev/null +++ b/tower-http/src/csrf/mod.rs @@ -0,0 +1,1008 @@ +//! Modern protection against [cross-site request forgery] (CSRF) attacks. +//! +//! This middleware implements the CSRF protection scheme [introduced in Go 1.25][go] +//! and described in [Filippo Valsorda's blog post][filippo]. It relies on the +//! [`Sec-Fetch-Site`] and [`Origin`] request headers and requires no +//! per-request token state. +//! +//! Requests are allowed if any of the following hold: +//! +//! 1. The method is `GET`, `HEAD`, or `OPTIONS`. +//! 2. The `Origin` header byte-for-byte matches an allow-listed trusted origin. +//! 3. `Sec-Fetch-Site` is `same-origin` or `none`. +//! 4. Neither `Sec-Fetch-Site` nor `Origin` is present. +//! 5. The `Origin`'s authority (host and any port) matches the request's effective +//! host byte-for-byte (the request-target authority if present, else `Host`). +//! +//! Rejected requests receive a `403 Forbidden` response. The originating +//! [`ProtectionError`] is attached to the response's extensions — on every +//! rejection, including those from a custom builder — so surrounding layers can +//! distinguish explicit cross-origin rejections from conservative fallback +//! rejections (e.g. requests from old browsers without `Sec-Fetch-Site`). Use +//! [`CsrfLayer::with_rejection_response`](CsrfLayer::with_rejection_response) +//! to replace the rejection response with a custom builder. +//! +//! # Deployment caveat +//! +//! The middleware trusts whatever `Origin` and `Host` reach it. Reverse proxies +//! and load balancers that rewrite `Host` (e.g. to an internal hostname) or +//! strip `Origin` silently degrade the protection: the `Origin`/`Host` +//! fallback can no longer match, and `Sec-Fetch-Site` becomes the only +//! remaining line of defense. Configure intermediaries to forward both headers +//! unchanged. +//! +//! # Example +//! +//! ``` +//! use bytes::Bytes; +//! use http::{Request, Response, StatusCode}; +//! use http_body_util::Full; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use tower_http::csrf::CsrfLayer; +//! +//! async fn handle(_: Request>) -> Result>, BoxError> { +//! Ok(Response::new(Full::default())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! let layer = CsrfLayer::new() +//! .add_trusted_origin("https://example.com")?; +//! +//! let mut service = ServiceBuilder::new() +//! .layer(layer) +//! .service_fn(handle); +//! +//! // Safe methods always pass. +//! let request = Request::builder() +//! .method("GET") +//! .uri("/") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service.ready().await?.call(request).await?; +//! +//! assert_eq!(response.status(), StatusCode::OK); +//! +//! // Cross-site POSTs are blocked. +//! let request = Request::builder() +//! .method("POST") +//! .uri("/") +//! .header("host", "example.com") +//! .header("sec-fetch-site", "cross-site") +//! .body(Full::default()) +//! .unwrap(); +//! +//! let response = service.ready().await?.call(request).await?; +//! +//! assert_eq!(response.status(), StatusCode::FORBIDDEN); +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! [cross-site request forgery]: https://developer.mozilla.org/en-US/docs/Glossary/CSRF +//! [filippo]: https://words.filippo.io/csrf/ +//! [go]: https://pkg.go.dev/net/http#CrossOriginProtection +//! [`Sec-Fetch-Site`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site +//! [`Origin`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin + +use std::collections::HashSet; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + +use http::{Method, Uri}; + +mod future; +mod layer; +mod response; +mod service; +mod url; + +pub use self::future::ResponseFuture; +pub use self::layer::CsrfLayer; +pub use self::response::{DefaultResponseForProtectionError, ResponseForProtectionError}; +pub use self::service::Csrf; + +/// Errors that can occur while configuring [`CsrfLayer`]. +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum ConfigError { + /// The origin string could not be parsed as a URI. + InvalidOriginUrl { + /// The offending origin string. + origin: String, + /// The parser error message. + message: String, + }, + + /// An origin URL containing a path, query, or fragment was added as a + /// trusted origin. + InvalidOriginUrlComponents { + /// The offending origin string. + origin: String, + }, + + /// An origin with a scheme other than `http` or `https` (e.g. `file://`, + /// `mailto:`, or a bare host with no scheme) was added as a trusted + /// origin. Such origins can never match a browser-supplied request + /// `Origin`. + OpaqueOrigin { + /// The offending origin string. + origin: String, + }, + + /// A trusted origin contained non-ASCII characters. Browsers send IDN + /// hostnames in punycode form, so the configured value must use the + /// punycode form (e.g. `xn--exmple-cua.com`) to ever match. + NonAsciiHostname { + /// The offending origin string. + origin: String, + }, +} + +impl fmt::Display for ConfigError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + ConfigError::InvalidOriginUrl { origin, message } => { + write!(f, "invalid origin {origin:?}: {message}") + } + ConfigError::InvalidOriginUrlComponents { origin } => write!( + f, + "invalid origin {origin:?}: path, query, and fragment are not allowed" + ), + ConfigError::OpaqueOrigin { origin } => write!( + f, + "invalid origin {origin:?}: scheme must be http or https" + ), + ConfigError::NonAsciiHostname { origin } => write!( + f, + "invalid origin {origin:?}: non-ASCII hostnames must be supplied in punycode (xn--…)" + ), + } + } +} + +impl std::error::Error for ConfigError {} + +/// Reason a request was rejected by [`Csrf`]. +/// +/// Retrieve the category with [`ProtectionError::kind`]. [`Csrf`] attaches it to +/// every `403 Forbidden` rejection response's extensions so surrounding layers +/// can distinguish explicit cross-origin rejections from conservative fallback +/// rejections. +/// +/// This is an opaque struct rather than an enum so future variants can carry +/// additional context without a breaking change; match on [`kind`] instead. +/// +/// [`kind`]: ProtectionError::kind +#[derive(Clone, Debug)] +pub struct ProtectionError { + kind: ProtectionErrorKind, +} + +impl ProtectionError { + pub(crate) fn new(kind: ProtectionErrorKind) -> Self { + Self { kind } + } + + /// The category of rejection. + pub fn kind(&self) -> ProtectionErrorKind { + self.kind + } +} + +impl fmt::Display for ProtectionError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.kind { + ProtectionErrorKind::CrossOriginRequest => f.write_str("Cross-Origin request detected"), + ProtectionErrorKind::CrossOriginRequestFromOldBrowser => { + f.write_str("Cross-Origin request from old browser detected") + } + } + } +} + +impl std::error::Error for ProtectionError {} + +/// The category of a [`ProtectionError`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum ProtectionErrorKind { + /// A cross-origin request was detected via `Sec-Fetch-Site`. + CrossOriginRequest, + + /// A request without `Sec-Fetch-Site` failed the `Origin`/`Host` fallback + /// check. Modern browsers always send `Sec-Fetch-Site`, so this typically + /// means the request came from an old browser or non-browser client. + CrossOriginRequestFromOldBrowser, +} + +type BypassFn = dyn Fn(&Method, &Uri) -> bool + Send + Sync + 'static; + +struct DebugFn; + +impl Debug for DebugFn { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("") + } +} + +#[derive(Clone, Default)] +struct Origins(Arc>>); + +impl Origins { + fn contains(&self, origin: &[u8]) -> bool { + self.0.contains(origin) + } + + fn insert(&mut self, origin: impl Into>) { + Arc::make_mut(&mut self.0).insert(origin.into()); + } +} + +impl Debug for Origins { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // render trusted origins as utf-h strings + write!(f, "Origins(")?; + f.debug_set() + .entries(self.0.iter().map(|o| String::from_utf8_lossy(o))) + .finish()?; + write!(f, ")") + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use http::{Request, Response, StatusCode}; + use tower::{service_fn, ServiceExt}; + use tower_layer::Layer; + + use super::*; + use crate::test_helpers::{to_bytes, Body}; + + impl PartialEq for super::ProtectionError { + fn eq(&self, other: &Self) -> bool { + self.kind == other.kind + } + } + + fn echo_service() -> impl tower::Service< + Request, + Response = Response, + Error = Infallible, + Future = impl std::future::Future, Infallible>>, + > + Clone { + service_fn(|req: Request| async move { + let body: Body = match req.uri().path() { + "/foo" => "foo".into(), + "/bar" => "bar".into(), + _ => Body::empty(), + }; + Ok::<_, Infallible>(Response::new(body)) + }) + } + + #[tokio::test] + async fn test_service_allows_safe_method() { + let svc = CsrfLayer::new() + .add_trusted_origin("https://example.com") + .unwrap() + .layer(echo_service()); + + let req = Request::builder() + .method("GET") + .uri("/foo") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let body = to_bytes(res.into_body()).await.unwrap(); + assert_eq!(&body[..], b"foo"); + } + + #[tokio::test] + async fn test_service_allows_post_from_trusted_origin() { + let svc = CsrfLayer::new() + .add_trusted_origin("https://example.com") + .unwrap() + .layer(echo_service()); + + let req = Request::builder() + .method("POST") + .uri("/bar") + .header("origin", "https://example.com") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let body = to_bytes(res.into_body()).await.unwrap(); + assert_eq!(&body[..], b"bar"); + } + + #[tokio::test] + async fn test_service_rejects_post_from_untrusted_origin() { + let svc = CsrfLayer::new() + .add_trusted_origin("https://example.com") + .unwrap() + .layer(echo_service()); + + let req = Request::builder() + .method("POST") + .uri("/bar") + .header("origin", "https://malicious.example") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::FORBIDDEN); + assert_eq!( + res.extensions().get::(), + Some(&ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser + )), + ); + } + + #[tokio::test] + async fn test_service_uses_custom_rejection_response() { + let svc = CsrfLayer::new() + .with_rejection_response(|_err: ProtectionError| { + let mut res = Response::new(Body::from("denied")); + *res.status_mut() = StatusCode::IM_A_TEAPOT; + res + }) + .layer(echo_service()); + + let req = Request::builder() + .method("POST") + .uri("/bar") + .header("origin", "https://malicious.example") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); + assert_ne!(res.status(), StatusCode::OK); + // The middleware attaches the error even though a custom builder + // produced the response. + assert_eq!( + res.extensions().get::(), + Some(&ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser + )), + ); + + let body = to_bytes(res.into_body()).await.unwrap(); + assert_eq!(&body[..], b"denied"); + } + + #[tokio::test] + async fn test_service_custom_rejection_response_not_invoked_when_allowed() { + let svc = CsrfLayer::new() + .add_trusted_origin("https://example.com") + .unwrap() + .with_rejection_response(|_err: ProtectionError| { + let mut res = Response::new(Body::from("denied")); + *res.status_mut() = StatusCode::IM_A_TEAPOT; + res + }) + .layer(echo_service()); + + let req = Request::builder() + .method("POST") + .uri("/bar") + .header("origin", "https://example.com") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_ne!(res.status(), StatusCode::IM_A_TEAPOT); + assert!(res.extensions().get::().is_none()); + + let body = to_bytes(res.into_body()).await.unwrap(); + assert_eq!(&body[..], b"bar"); + } + + #[test] + fn test_layer_add_trusted_origin() { + // Smoke check that the layer threads parse_origin's Ok and Err + // through; the full validation matrix lives in url.rs. + assert!(CsrfLayer::new() + .add_trusted_origin("https://example.com") + .is_ok()); + assert!(matches!( + CsrfLayer::new().add_trusted_origin("not a valid url"), + Err(ConfigError::InvalidOriginUrl { .. }) + )); + } + + #[test] + fn test_middleware_bypass() { + let layer = CsrfLayer::new() + .with_insecure_bypass(|_method, uri| -> bool { uri.path() == "/bypass" }); + + let middleware = layer.layer(()); + + struct Test { + name: &'static str, + path: &'static str, + sec_fetch_site: Option<&'static str>, + result: Result<(), ProtectionError>, + } + + let tests = [ + Test { + name: "bypass path without sec-fetch-site", + path: "/bypass", + sec_fetch_site: None, + result: Ok(()), + }, + Test { + name: "bypass path with cross-site", + path: "/bypass", + sec_fetch_site: Some("cross-site"), + result: Ok(()), + }, + Test { + name: "non-bypass path without sec-fetch-site", + path: "/api", + sec_fetch_site: None, + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "non-bypass path with cross-site", + path: "/api", + sec_fetch_site: Some("cross-site"), + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + ]; + + for test in tests { + let mut req = Request::builder() + .method("POST") + .header("host", "example.com") + .header("origin", "https://attacker.example") + .uri(format!("https://example.com{}", test.path)); + + if let Some(sec_fetch_site) = test.sec_fetch_site { + req = req.header("sec-fetch-site", sec_fetch_site); + } + + let req = req.body(()).unwrap(); + + assert_eq!(middleware.verify(&req), test.result, "{}", test.name); + } + } + + #[test] + fn test_middleware_bypass_applies_when_origin_unparseable() { + let middleware = CsrfLayer::new() + .with_insecure_bypass(|_method, uri| uri.path() == "/bypass") + .layer(()); + + let req = Request::builder() + .method("POST") + .uri("https://example.com/bypass") + .header("host", "example.com") + .header( + "origin", + http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap(), + ) + .body(()) + .unwrap(); + + assert_eq!(middleware.verify(&req), Ok(())); + } + + #[test] + fn test_middleware_debug_trait() { + let layer = CsrfLayer::new(); + + let middleware = layer + .clone() + .with_insecure_bypass(|method, uri| method == Method::POST && uri.path() == "/bypass") + .layer(()); + + assert_eq!( + format!("{:?}", middleware), + "Csrf { inner: (), insecure_bypass: Some(), trusted_origins: Origins({}), rejection_response: }" + ); + + let middleware = layer.layer(()); + + assert_eq!( + format!("{:?}", middleware), + "Csrf { inner: (), insecure_bypass: None, trusted_origins: Origins({}), rejection_response: }" + ); + } + + #[test] + fn test_middleware_origin_host_port_match() { + let middleware: Csrf<()> = Default::default(); + + struct Test { + name: &'static str, + uri: &'static str, + host: Option<&'static str>, + origin: &'static str, + result: Result<(), ProtectionError>, + } + + let tests = [ + Test { + name: "default port both sides", + uri: "/", + host: Some("example.com"), + origin: "https://example.com", + result: Ok(()), + }, + Test { + name: "same non-default port both sides", + uri: "/", + host: Some("example.com:8443"), + origin: "https://example.com:8443", + result: Ok(()), + }, + Test { + name: "explicit default port both sides", + uri: "/", + host: Some("example.com:443"), + origin: "https://example.com:443", + result: Ok(()), + }, + Test { + name: "mismatched non-default ports", + uri: "/", + host: Some("example.com:8443"), + origin: "https://example.com:8444", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + // Strict byte match: an explicit default port does not equal an + // implicit one (the reference does not normalize ports). + name: "origin has explicit default, host implicit", + uri: "/", + host: Some("example.com"), + origin: "https://example.com:443", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "host has explicit default, origin implicit", + uri: "/", + host: Some("example.com:443"), + origin: "https://example.com", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "host implicit, origin explicit non-default", + uri: "/", + host: Some("example.com"), + origin: "https://example.com:8443", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "missing host, uri authority implicit, origin explicit non-default", + uri: "https://example.com/path", + host: None, + origin: "https://example.com:8443", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + // No request-target authority, so the Host header is the effective + // host, compared verbatim — a malformed Host never matches an Origin. + name: "malformed host header compared verbatim", + uri: "/path", + host: Some("not a valid authority"), + origin: "https://example.com", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + // RFC 7230 §5.3 / Go parity: the request-target authority is the + // effective host (Host header ignored); here it matches Origin. + name: "request-target authority wins over host header (match)", + uri: "https://example.com/path", + host: Some("other.example"), + origin: "https://example.com", + result: Ok(()), + }, + Test { + // Security-relevant: Origin matches the Host header but not the + // winning request-target authority, so it stays cross-origin. + name: "origin matching host header but not authority is rejected", + uri: "https://example.com/path", + host: Some("other.example"), + origin: "https://other.example", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "missing host, uri carries authority (match)", + uri: "https://example.com/path", + host: None, + origin: "https://example.com", + result: Ok(()), + }, + Test { + name: "missing host, uri authority mismatch", + uri: "https://other.example/path", + host: None, + origin: "https://example.com", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "missing host and no uri authority", + uri: "/path", + host: None, + origin: "https://example.com", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "scheme-less origin does not match host even if bytes agree", + uri: "/", + host: Some("example.com:8443"), + origin: "example.com:8443", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "non-http origin scheme does not enter host fallback", + uri: "/", + host: Some("example.com:8443"), + origin: "ftp://example.com:8443", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + ]; + + for test in tests { + let mut req = Request::builder().method(Method::POST).uri(test.uri); + + if let Some(host) = test.host { + req = req.header("host", host); + } + + let req = req.header("origin", test.origin).body(()).unwrap(); + + assert_eq!(middleware.verify(&req), test.result, "{}", test.name); + } + } + + #[test] + fn test_middleware_sec_fetch_site() { + let middleware: Csrf<()> = Default::default(); + + const NON_DECODABLE: &[u8] = &[0xFF, 0xFE]; + assert!( + http::HeaderValue::from_bytes(NON_DECODABLE) + .expect("NON_DECODABLE must be a valid HeaderValue") + .to_str() + .is_err(), + "NON_DECODABLE must fail HeaderValue::to_str()" + ); + + struct Test { + name: &'static str, + method: http::Method, + sec_fetch_site: Option<&'static [u8]>, + origin: Option<&'static [u8]>, + result: Result<(), ProtectionError>, + } + + let tests = [ + Test { + name: "same-origin allowed", + method: Method::GET, + sec_fetch_site: Some(b"same-origin"), + origin: None, + result: Ok(()), + }, + Test { + name: "none allowed", + method: Method::POST, + sec_fetch_site: Some(b"none"), + origin: None, + result: Ok(()), + }, + Test { + name: "cross-site blocked", + method: Method::POST, + sec_fetch_site: Some(b"cross-site"), + origin: None, + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + Test { + name: "same-site blocked", + method: Method::POST, + sec_fetch_site: Some(b"same-site"), + origin: None, + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + Test { + name: "no header with no origin", + method: Method::POST, + sec_fetch_site: None, + origin: None, + result: Ok(()), + }, + Test { + name: "no header with matching origin", + method: Method::POST, + sec_fetch_site: None, + origin: Some(b"https://example.com"), + result: Ok(()), + }, + Test { + name: "no header with mismatched origin", + method: Method::POST, + sec_fetch_site: None, + origin: Some(b"https://attacker.example"), + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "no header with null origin", + method: Method::POST, + sec_fetch_site: None, + origin: Some(b"null"), + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "GET allowed", + method: Method::GET, + sec_fetch_site: Some(b"cross-site"), + origin: None, + result: Ok(()), + }, + Test { + name: "HEAD allowed", + method: Method::HEAD, + sec_fetch_site: Some(b"cross-site"), + origin: None, + result: Ok(()), + }, + Test { + name: "OPTIONS allowed", + method: Method::OPTIONS, + sec_fetch_site: Some(b"cross-site"), + origin: None, + result: Ok(()), + }, + Test { + name: "PUT blocked", + method: Method::PUT, + sec_fetch_site: Some(b"cross-site"), + origin: None, + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + Test { + name: "non-decodable origin without sec-fetch-site rejected", + method: Method::POST, + sec_fetch_site: None, + origin: Some(NON_DECODABLE), + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "non-decodable sec-fetch-site without origin rejected", + method: Method::POST, + sec_fetch_site: Some(NON_DECODABLE), + origin: None, + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + Test { + name: "empty sec-fetch-site without origin allowed", + method: Method::POST, + sec_fetch_site: Some(b""), + origin: None, + result: Ok(()), + }, + Test { + name: "empty origin without sec-fetch-site allowed", + method: Method::POST, + sec_fetch_site: None, + origin: Some(b""), + result: Ok(()), + }, + ]; + + for test in tests { + let mut req = Request::builder() + .method(test.method) + .header("host", "example.com"); + + if let Some(sec_fetch_site) = test.sec_fetch_site { + req = req.header("sec-fetch-site", sec_fetch_site); + } + + if let Some(origin) = test.origin { + req = req.header("origin", origin); + } + + let req = req.body(()).unwrap(); + + assert_eq!(middleware.verify(&req), test.result, "{}", test.name); + } + } + + #[test] + fn test_middleware_trusted_origin_bypass() { + let layer = CsrfLayer::new() + .add_trusted_origin("https://trusted.example") + .unwrap(); + + let middleware = layer.layer(()); + + struct Test { + name: &'static str, + sec_fetch_site: Option<&'static str>, + origin: Option<&'static str>, + result: Result<(), ProtectionError>, + } + + let tests = [ + Test { + name: "trusted origin without sec-fetch-site", + origin: Some("https://trusted.example"), + sec_fetch_site: None, + result: Ok(()), + }, + Test { + name: "trusted origin with cross-site", + origin: Some("https://trusted.example"), + sec_fetch_site: Some("cross-site"), + result: Ok(()), + }, + Test { + name: "untrusted origin without sec-fetch-site", + origin: Some("https://attacker.example"), + sec_fetch_site: None, + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )), + }, + Test { + name: "untrusted origin with cross-site", + origin: Some("https://attacker.example"), + sec_fetch_site: Some("cross-site"), + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + ]; + + for test in tests { + let mut req = Request::builder() + .method("POST") + .header("host", "example.com"); + + if let Some(sec_fetch_site) = test.sec_fetch_site { + req = req.header("sec-fetch-site", sec_fetch_site); + } + + if let Some(origin) = test.origin { + req = req.header("origin", origin); + } + + let req = req.body(()).unwrap(); + + assert_eq!(middleware.verify(&req), test.result, "{}", test.name); + } + } + + #[test] + fn test_middleware_trusted_origin_strict_byte_match() { + // Trusted origins are matched byte-for-byte against the request's Origin + // header (no canonicalization), mirroring the Go reference. Only an exact + // match is trusted; case- and port-form variants are not. + struct Test { + name: &'static str, + trusted: &'static str, + origin: &'static str, + result: Result<(), ProtectionError>, + } + + let tests = [ + Test { + name: "exact match trusted", + trusted: "https://example.com", + origin: "https://example.com", + result: Ok(()), + }, + Test { + name: "exact match with non-default port", + trusted: "https://example.com:8443", + origin: "https://example.com:8443", + result: Ok(()), + }, + Test { + name: "host case mismatch not trusted", + trusted: "https://Example.COM", + origin: "https://example.com", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + Test { + name: "explicit default port not trusted against bare origin", + trusted: "https://example.com:443", + origin: "https://example.com", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + Test { + name: "bare trusted not matched by explicit-default-port origin", + trusted: "https://example.com", + origin: "https://example.com:443", + result: Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )), + }, + ]; + + for test in tests { + let middleware = CsrfLayer::new() + .add_trusted_origin(test.trusted) + .unwrap_or_else(|e| panic!("{}: add_trusted_origin failed: {e}", test.name)) + .layer(()); + + let req = Request::builder() + .method("POST") + .header("host", "other.example") + .header("origin", test.origin) + .header("sec-fetch-site", "cross-site") + .body(()) + .unwrap(); + + assert_eq!(middleware.verify(&req), test.result, "{}", test.name); + } + } +} diff --git a/tower-http/src/csrf/response.rs b/tower-http/src/csrf/response.rs new file mode 100644 index 00000000..cdbad7ab --- /dev/null +++ b/tower-http/src/csrf/response.rs @@ -0,0 +1,45 @@ +use http::{Response, StatusCode}; + +use super::ProtectionError; + +/// Builds the response returned by [`Csrf`] when a request fails CSRF protection. +/// +/// Implemented for any `FnMut(ProtectionError) -> Response + Clone`, so a +/// closure can be passed directly to +/// [`CsrfLayer::with_rejection_response`](super::CsrfLayer::with_rejection_response). +/// +/// [`Csrf`]: super::Csrf +pub trait ResponseForProtectionError: Clone { + /// Builds the response from the rejection error. + fn response_for_protection_error(&mut self, error: ProtectionError) -> Response; +} + +impl ResponseForProtectionError for F +where + F: FnMut(ProtectionError) -> Response + Clone, +{ + fn response_for_protection_error(&mut self, error: ProtectionError) -> Response { + self(error) + } +} + +/// Default [`ResponseForProtectionError`] used by +/// [`CsrfLayer::new`](super::CsrfLayer::new). +/// +/// Produces a `403 Forbidden` response with an empty body. The originating +/// [`ProtectionError`] is attached to the response's extensions by [`Csrf`] +/// itself, so it is present regardless of which builder produced the response. +/// +/// [`Csrf`]: super::Csrf +#[derive(Clone, Copy, Debug, Default)] +#[non_exhaustive] +pub struct DefaultResponseForProtectionError; + +impl ResponseForProtectionError for DefaultResponseForProtectionError { + fn response_for_protection_error(&mut self, _error: ProtectionError) -> Response { + let mut response = Response::new(B::default()); + *response.status_mut() = StatusCode::FORBIDDEN; + + response + } +} diff --git a/tower-http/src/csrf/service.rs b/tower-http/src/csrf/service.rs new file mode 100644 index 00000000..7695603f --- /dev/null +++ b/tower-http/src/csrf/service.rs @@ -0,0 +1,216 @@ +use std::convert::TryFrom; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use http::{Method, Request, Response, Uri}; +use tower_service::Service; + +use super::future::ResponseFuture; +use super::{ + BypassFn, DebugFn, DefaultResponseForProtectionError, Origins, ProtectionError, + ProtectionErrorKind, ResponseForProtectionError, +}; + +/// Middleware that enforces cross-origin request forgery (CSRF) protection. +/// +/// See the [module docs](crate::csrf) for an example. +#[derive(Clone)] +#[must_use] +pub struct Csrf { + inner: S, + insecure_bypass: Option>, + rejection_response: T, + trusted_origins: Origins, +} + +impl Csrf { + pub(super) fn new( + inner: S, + insecure_bypass: Option>, + rejection_response: T, + trusted_origins: Origins, + ) -> Self { + Self { + inner, + insecure_bypass, + rejection_response, + trusted_origins, + } + } + + pub(super) fn verify(&self, req: &Request) -> Result<(), ProtectionError> { + // Deliberately not Method::is_safe: it also treats TRACE as safe, but the + // reference implementation only exempts GET/HEAD/OPTIONS, so we match it here. + if matches!( + req.method(), + &Method::GET | &Method::HEAD | &Method::OPTIONS + ) { + #[cfg(feature = "tracing")] + tracing::trace!(uri = %req.uri().path(), "request passed: safe method"); + return Ok(()); + } + + let origin = req.headers().get("origin").map(|h| h.as_bytes()); + + let origin_uri = origin + .filter(|b| !b.is_empty()) + .and_then(|b| Uri::try_from(b).ok()) + .filter(|u| matches!(u.scheme_str(), Some("http" | "https"))); + + let sec_fetch_site = req.headers().get("sec-fetch-site").map(|h| h.as_bytes()); + + let is_exempt = || -> bool { + let bypass = self + .insecure_bypass + .as_ref() + .map_or(false, |bypass| bypass(req.method(), req.uri())); + + if bypass { + #[cfg(feature = "tracing")] + tracing::trace!(uri = %req.uri().path(), "request passed: bypassed"); + return true; + } + + // Strict byte match of the raw Origin header against the registered + // set, mirroring the Go reference's `trustedOrigins[Origin]`. + let trusted = origin.map_or(false, |b| self.trusted_origins.contains(b)); + + if trusted { + #[cfg(feature = "tracing")] + tracing::trace!(uri = %req.uri().path(), "request passed: trusted origin"); + return true; + } + + false + }; + + // Fetch spec mandates lowercase here; exact byte match is intentional. + match sec_fetch_site { + Some(b"same-origin" | b"none") => { + #[cfg(feature = "tracing")] + tracing::trace!(uri = %req.uri().path(), "request passed: sec-fetch-site is same-origin or none"); + return Ok(()); + } + None | Some(b"") => {} // fall through to Origin check + Some(_) if is_exempt() => return Ok(()), + Some(_) => { + return Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequest, + )) + } + } + + if matches!(origin, None | Some(b"")) { + #[cfg(feature = "tracing")] + tracing::trace!(uri = %req.uri().path(), "request passed: neither sec-fetch-site nor origin header (same-origin or not a browser request)"); + return Ok(()); + } + + let host = req.headers().get("host").map(|h| h.as_bytes()); + + // Mirrors the reference's `url.Parse(origin).Host == req.Host`. Per RFC 7230 + // §5.3, req.Host is the request-target authority (absolute-form URI / HTTP/2 + // `:authority`) if present, else the Host header. Byte-exact and scheme-blind, + // so an http→https mismatch can't be caught here — we fail open (HSTS helps). + let effective_host = req + .uri() + .authority() + .map(|a| a.as_str().as_bytes()) + .or(host); + + if let (Some(uri), Some(effective_host)) = (&origin_uri, effective_host) { + if uri.authority().map(|a| a.as_str().as_bytes()) == Some(effective_host) { + #[cfg(feature = "tracing")] + tracing::trace!(uri = %req.uri().path(), "request passed: origin is same as host"); + return Ok(()); + } + } + + if is_exempt() { + return Ok(()); + } + + Err(ProtectionError::new( + ProtectionErrorKind::CrossOriginRequestFromOldBrowser, + )) + } +} + +impl Default for Csrf +where + S: Default, + T: Default, +{ + fn default() -> Self { + Self { + inner: S::default(), + insecure_bypass: None, + rejection_response: T::default(), + trusted_origins: Origins::default(), + } + } +} + +impl Debug for Csrf { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Csrf") + .field("inner", &self.inner) + .field( + "insecure_bypass", + &self.insecure_bypass.as_ref().map(|_| DebugFn), + ) + .field("trusted_origins", &self.trusted_origins) + .field("rejection_response", &DebugFn) + .finish() + } +} + +impl Service> for Csrf +where + S: Service, Response = Response>, + T: ResponseForProtectionError, +{ + type Error = S::Error; + type Future = ResponseFuture; + type Response = Response; + + fn call(&mut self, req: Request) -> Self::Future { + match self.verify(&req) { + Ok(_) => ResponseFuture::future(self.inner.call(req)), + Err(err) => { + #[cfg(feature = "tracing")] + tracing::trace!(uri = %req.uri().path(), error = %err, "request rejected"); + + let mut response = self + .rejection_response + .response_for_protection_error(err.clone()); + + response.extensions_mut().insert(err); + + ResponseFuture::rejected(Ok(response)) + } + } + } + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Guards the comment in `verify`: `Method::is_safe` exempts more than the + // GET/HEAD/OPTIONS set the reference implementation uses, so we can't rely on it. + #[test] + fn method_is_safe_covers_more_than_get_head_options() { + for method in [&Method::GET, &Method::HEAD, &Method::OPTIONS] { + assert!(method.is_safe()); + } + + // TRACE is "safe" per RFC 7231 but is not in the reference implementation's set. + assert!(Method::TRACE.is_safe()); + } +} diff --git a/tower-http/src/csrf/url.rs b/tower-http/src/csrf/url.rs new file mode 100644 index 00000000..2458edaf --- /dev/null +++ b/tower-http/src/csrf/url.rs @@ -0,0 +1,160 @@ +use http::Uri; + +use super::ConfigError; + +/// Internal extension methods on [`http::Uri`] used by the CSRF middleware to +/// validate trusted-origin strings. +pub(crate) trait UriExt: Sized { + /// Parses a trusted-origin string of the form `scheme://host[:port]`. + /// + /// Rejects inputs that can't represent a browser `Origin`: + /// + /// - unparseable URIs ([`ConfigError::InvalidOriginUrl`]); + /// - non-`http`/`https` schemes or missing host ([`ConfigError::OpaqueOrigin`]); + /// - any path, query, or fragment component + /// ([`ConfigError::InvalidOriginUrlComponents`] — including a bare trailing + /// `/` and fragments that `http::Uri` would otherwise silently strip); + /// - non-ASCII hostnames ([`ConfigError::NonAsciiHostname`] — IDN hosts + /// must be supplied in punycode, since that's what browsers send). + /// + /// The returned [`Uri`] is parsed but not normalized; the origin is matched + /// against the request's `Origin` header byte-for-byte. + fn parse_origin(input: &str) -> Result; +} + +impl UriExt for Uri { + fn parse_origin(input: &str) -> Result { + if input.contains('#') { + return Err(ConfigError::InvalidOriginUrlComponents { + origin: input.to_owned(), + }); + } + + // browsers will send punycode anyways + if !input.is_ascii() { + return Err(ConfigError::NonAsciiHostname { + origin: input.to_owned(), + }); + } + + let uri: Uri = + input + .parse() + .map_err(|e: http::uri::InvalidUri| ConfigError::InvalidOriginUrl { + origin: input.to_owned(), + message: e.to_string(), + })?; + + if !matches!(uri.scheme_str(), Some("http" | "https")) + || uri.host().map_or(true, |h| h.is_empty()) + { + return Err(ConfigError::OpaqueOrigin { + origin: input.to_owned(), + }); + } + + // Reject any path/query (fragments are rejected above). `http::Uri` + // reports `path()` as "/" for both `scheme://host` and `scheme://host/`, + // so detect a path from the raw input (everything after "://") to reach + // parity with Go, which rejects a non-empty path — including a bare "/". + let after_scheme = input.split_once("://").map_or("", |(_, rest)| rest); + + if after_scheme.contains('/') || uri.query().is_some() { + return Err(ConfigError::InvalidOriginUrlComponents { + origin: input.to_owned(), + }); + } + + Ok(uri) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_origin_accepts() { + for input in [ + "https://example.com", + "http://example.com", + "https://example.com:8443", + "HTTPS://Example.COM", + ] { + assert!( + Uri::parse_origin(input).is_ok(), + "expected Ok for {input:?}, got {:?}", + Uri::parse_origin(input) + ); + } + } + + #[test] + fn test_parse_origin_rejects() { + // Each row maps an input to the expected ConfigError variant. + // Marker functions over closures because PartialEq on the enum already + // makes equality the easy assertion shape. + type Check = fn(&ConfigError) -> bool; + let cases: &[(&str, Check)] = &[ + // http::Uri rejects these outright at parse time. + ("not a valid url", |e| { + matches!(e, ConfigError::InvalidOriginUrl { .. }) + }), + ("https://", |e| { + matches!(e, ConfigError::InvalidOriginUrl { .. }) + }), + ("file:///", |e| { + matches!(e, ConfigError::InvalidOriginUrl { .. }) + }), + // Parse OK but scheme is not http/https (or absent). + ("example.com", |e| { + matches!(e, ConfigError::OpaqueOrigin { .. }) + }), + ("file://host/path", |e| { + matches!(e, ConfigError::OpaqueOrigin { .. }) + }), + ("mailto:x@y.z", |e| { + matches!(e, ConfigError::OpaqueOrigin { .. }) + }), + ("javascript:alert(1)", |e| { + matches!(e, ConfigError::OpaqueOrigin { .. }) + }), + // Path/query/fragment not allowed on a trusted origin. A bare + // trailing slash is a (non-empty) path too — rejected, matching Go. + ("https://example.com/", |e| { + matches!(e, ConfigError::InvalidOriginUrlComponents { .. }) + }), + ("https://example.com/path", |e| { + matches!(e, ConfigError::InvalidOriginUrlComponents { .. }) + }), + ("https://example.com/path?query=value", |e| { + matches!(e, ConfigError::InvalidOriginUrlComponents { .. }) + }), + ("https://example.com/path#fragment", |e| { + matches!(e, ConfigError::InvalidOriginUrlComponents { .. }) + }), + // http::Uri silently strips fragments; the `contains('#')` pre-check + // surfaces these as component errors instead of letting them slip in. + ("https://example.com#fragment", |e| { + matches!(e, ConfigError::InvalidOriginUrlComponents { .. }) + }), + ("https://example.com/#fragment", |e| { + matches!(e, ConfigError::InvalidOriginUrlComponents { .. }) + }), + // IDN hosts must be supplied in punycode. + ("https://ümlaut.de", |e| { + matches!(e, ConfigError::NonAsciiHostname { .. }) + }), + ("https://日本.jp", |e| { + matches!(e, ConfigError::NonAsciiHostname { .. }) + }), + ]; + + for (input, predicate) in cases { + match Uri::parse_origin(input) { + Err(e) if predicate(&e) => {} + other => panic!("unexpected result for {:?}: {:?}", input, other), + } + } + } +} diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index c91515b5..5ab4b68f 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -296,6 +296,9 @@ pub mod metrics; #[cfg(feature = "cors")] pub mod cors; +#[cfg(feature = "csrf")] +pub mod csrf; + #[cfg(feature = "request-id")] pub mod request_id;