diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..97e6195b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Preserve LF line endings in test fixtures so tests pass on Windows. +test-files/** eol=lf diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..5ace4600 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c0447164..f02987f8 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -6,6 +6,13 @@ on: - main pull_request: {} +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: check: runs-on: ubuntu-latest @@ -22,6 +29,9 @@ jobs: steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@nightly + - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: cargo doc env: RUSTDOCFLAGS: "-D rustdoc::broken_intra_doc_links --cfg docsrs" @@ -35,7 +45,9 @@ jobs: - uses: Swatinem/rust-cache@v2 with: save-if: ${{ github.ref == 'refs/heads/main' }} - - uses: taiki-e/install-action@cargo-hack + - uses: taiki-e/install-action@v2 + with: + tool: cargo-hack - name: cargo hack check env: RUSTFLAGS: "-D unused_imports -D dead_code -D unused_variables" @@ -59,6 +71,22 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - run: cargo test --workspace --all-features + test-os: + # Test on macOS and Windows to catch platform-specific issues. + needs: check + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest, windows-latest] + steps: + - uses: actions/checkout@v6 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + - run: cargo test --workspace --all-features + test-msrv: needs: check runs-on: ubuntu-latest @@ -80,6 +108,59 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - run: cargo check -p tower-http --all-features + minimal-versions: + needs: check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: dtolnay/rust-toolchain@nightly + - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + - name: Install cargo-hack + uses: taiki-e/install-action@v2 + with: + tool: cargo-hack + - name: Check with minimal versions + run: | + cargo hack --remove-dev-deps + cargo update -Z minimal-versions + cargo update -p crc32fast --precise 1.2.0 + cargo check -p tower-http --all-features + + semver-checks: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + - name: Install cargo-semver-checks + uses: taiki-e/install-action@v2 + with: + tool: cargo-semver-checks + - name: Check semver compatibility + id: checks + run: | + set +e + OUTPUT=$(cargo semver-checks check-release -p tower-http --all-features \ + --baseline-rev origin/${{ github.base_ref }} --color never 2>&1) + STATUS=$? + echo "$OUTPUT" + if [ $STATUS -ne 0 ]; then + echo "$OUTPUT" > semver-output.txt + fi + - name: Upload results + if: always() && hashFiles('semver-output.txt') != '' + uses: actions/upload-artifact@v7 + with: + name: semver-checks-output + path: semver-output.txt + style: needs: check runs-on: ubuntu-latest @@ -124,7 +205,7 @@ jobs: with: toolchain: nightly-2025-10-18 - name: Install cargo-check-external-types - uses: taiki-e/cache-cargo-install-action@v2 + uses: taiki-e/cache-cargo-install-action@v3 with: tool: cargo-check-external-types@0.4.0 - uses: Swatinem/rust-cache@v2 diff --git a/.github/workflows/semver-comment.yml b/.github/workflows/semver-comment.yml new file mode 100644 index 00000000..df2e8495 --- /dev/null +++ b/.github/workflows/semver-comment.yml @@ -0,0 +1,58 @@ +name: Semver Checks Comment + +on: + workflow_run: + workflows: ["CI"] + types: + - completed + +permissions: + actions: read + pull-requests: write + +jobs: + comment: + runs-on: ubuntu-latest + if: github.event.workflow_run.event == 'pull_request' + steps: + - name: Download results + id: download + uses: actions/download-artifact@v8 + with: + name: semver-checks-output + run-id: ${{ github.event.workflow_run.id }} + github-token: ${{ github.token }} + continue-on-error: true + + - name: Build comment + id: build + if: steps.download.outcome == 'success' + run: | + echo "comment<> "$GITHUB_OUTPUT" + echo "### ⚠️ Breaking API changes detected" >> "$GITHUB_OUTPUT" + echo "" >> "$GITHUB_OUTPUT" + echo "Please make sure these are intentional and noted in the changelog." >> "$GITHUB_OUTPUT" + echo "" >> "$GITHUB_OUTPUT" + echo "
cargo semver-checks output" >> "$GITHUB_OUTPUT" + echo "" >> "$GITHUB_OUTPUT" + echo '```' >> "$GITHUB_OUTPUT" + cat semver-output.txt >> "$GITHUB_OUTPUT" + echo '```' >> "$GITHUB_OUTPUT" + echo "
" >> "$GITHUB_OUTPUT" + echo "EOF" >> "$GITHUB_OUTPUT" + + - name: Post comment + if: steps.build.outputs.comment + uses: marocchino/sticky-pull-request-comment@v3 + with: + number: ${{ github.event.workflow_run.pull_requests[0].number }} + header: semver-checks + message: ${{ steps.build.outputs.comment }} + + - name: Hide comment + if: steps.download.outcome == 'failure' + uses: marocchino/sticky-pull-request-comment@v3 + with: + number: ${{ github.event.workflow_run.pull_requests[0].number }} + header: semver-checks + hide: true diff --git a/test-files/missing_precompressed.txt b/test-files/missing_precompressed.txt index 1cbaf90b..524acfff 100644 --- a/test-files/missing_precompressed.txt +++ b/test-files/missing_precompressed.txt @@ -1 +1 @@ -Test file! +Test file diff --git a/test-files/only_gzipped.txt.gz b/test-files/only_gzipped.txt.gz index da92e795..895000a4 100644 Binary files a/test-files/only_gzipped.txt.gz and b/test-files/only_gzipped.txt.gz differ diff --git a/test-files/precompressed.txt b/test-files/precompressed.txt index e073b8dc..524acfff 100644 --- a/test-files/precompressed.txt +++ b/test-files/precompressed.txt @@ -1 +1 @@ -"This is a test file!" +Test file diff --git a/test-files/precompressed.txt.br b/test-files/precompressed.txt.br index ca313e78..62bc30a9 100644 --- a/test-files/precompressed.txt.br +++ b/test-files/precompressed.txt.br @@ -1,2 +1,2 @@ - "This is a test file!" +Test file  \ No newline at end of file diff --git a/test-files/precompressed.txt.gz b/test-files/precompressed.txt.gz index f8913e21..895000a4 100644 Binary files a/test-files/precompressed.txt.gz and b/test-files/precompressed.txt.gz differ diff --git a/test-files/precompressed.txt.zst b/test-files/precompressed.txt.zst index 813fc955..e22efdf3 100644 Binary files a/test-files/precompressed.txt.zst and b/test-files/precompressed.txt.zst differ diff --git a/test-files/precompressed.txt.zz b/test-files/precompressed.txt.zz index 69a8dee6..3975e2f4 100644 Binary files a/test-files/precompressed.txt.zz and b/test-files/precompressed.txt.zz differ diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 284a3198..54c8ad59 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -7,6 +7,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +## Changed + +- `trace`: `DefaultOnRequest`, `DefaultOnResponse`, `DefaultOnFailure`, and + `DefaultOnEos` now explicitly parent their tracing events to the request span + rather than relying on the ambient span context. This fixes intermittent cases + where events could appear without their request span attached ([#655]) +- The implicit `tokio` and `async-compression` features are removed (BREAKING). + These were kept as no-op features in 0.6.x for backwards compatibility after + the switch to `dep:` syntax in [#642]. Downstream crates that activate + `tower-http/tokio` or `tower-http/async-compression` should remove those + feature entries; the underlying dependencies are still pulled in transitively + by the features that need them (e.g. `compression-gzip`, `fs`, `timeout`). + ([#628]) + +[#628]: https://github.com/tower-rs/tower-http/pull/628 +[#642]: https://github.com/tower-rs/tower-http/pull/642 + # 0.6.11 ## Added @@ -51,6 +68,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#408]: https://github.com/tower-rs/tower-http/pull/408 [#506]: https://github.com/tower-rs/tower-http/pull/506 [#587]: https://github.com/tower-rs/tower-http/pull/587 +[#655]: https://github.com/tower-rs/tower-http/issues/655 [#672]: https://github.com/tower-rs/tower-http/pull/672 [#675]: https://github.com/tower-rs/tower-http/pull/675 [#677]: https://github.com/tower-rs/tower-http/pull/677 diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 9a23c707..e908bd56 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -120,11 +120,6 @@ decompression-full = ["decompression-br", "decompression-deflate", "decompressio decompression-gzip = ["dep:async-compression", "async-compression?/gzip", "futures-core", "dep:http-body", "dep:http-body-util", "tokio-util", "dep:tokio"] decompression-zstd = ["dep:async-compression", "async-compression?/zstd", "futures-core", "dep:http-body", "dep:http-body-util", "tokio-util", "dep:tokio"] -# FIXME: rip this out come 0.7.0. -# ref: https://github.com/tower-rs/tower-http/pull/666#issuecomment-4382555061 -tokio = [] -async-compression = [] - [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index 420a9e88..c32fc579 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -94,6 +94,7 @@ mod tests { use super::*; use crate::test_helpers::{Body, WithTrailers}; use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder}; + use bytes::Bytes; use flate2::read::GzDecoder; use http::header::{ ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE, @@ -106,7 +107,7 @@ mod tests { use std::sync::{Arc, RwLock}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_util::io::StreamReader; - use tower::{service_fn, Service, ServiceExt}; + use tower::{service_fn, BoxError, Service, ServiceExt}; // Compression filter allows every other request to be compressed #[derive(Clone)] @@ -522,6 +523,88 @@ mod tests { assert_eq!(decompressed, "Hello, World!"); } + #[tokio::test] + async fn trailers_with_empty_body() { + let svc = service_fn(|_req: Request| async { + let mut trailers = HeaderMap::new(); + trailers.insert("grpc-status", "0".parse().unwrap()); + trailers.insert("grpc-message", "OK".parse().unwrap()); + let body = Body::empty().with_trailers(trailers); + Ok::<_, Infallible>(Response::builder().body(body).unwrap()) + }); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + assert_eq!(trailers["grpc-status"], "0"); + assert_eq!(trailers["grpc-message"], "OK"); + } + + #[tokio::test] + async fn trailers_with_streamed_body() { + // Simulate a gRPC-like streamed response: multiple data frames followed by trailers + let svc = service_fn(|_req: Request| async { + let stream = futures_util::stream::iter(vec![ + Ok::<_, BoxError>(Bytes::from("chunk1")), + Ok(Bytes::from("chunk2")), + Ok(Bytes::from("chunk3")), + ]); + let mut trailers = HeaderMap::new(); + trailers.insert("grpc-status", "0".parse().unwrap()); + let body = Body::from_stream(stream).with_trailers(trailers); + Ok::<_, Infallible>(Response::builder().body(body).unwrap()) + }); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let compressed_data = collected.to_bytes(); + + let mut decoder = GzDecoder::new(&compressed_data[..]); + let mut decompressed = String::new(); + decoder.read_to_string(&mut decompressed).unwrap(); + + assert_eq!(decompressed, "chunk1chunk2chunk3"); + assert_eq!(trailers["grpc-status"], "0"); + } + + #[tokio::test] + async fn trailers_with_grpc_web_content_type() { + let svc = service_fn(|_req: Request| async { + let mut trailers = HeaderMap::new(); + trailers.insert("grpc-status", "0".parse().unwrap()); + let body = Body::from("a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize)) + .with_trailers(trailers); + let mut res = Response::new(body); + res.headers_mut() + .insert(CONTENT_TYPE, "application/grpc-web+proto".parse().unwrap()); + Ok::<_, Infallible>(res) + }); + let mut svc = Compression::new(svc).compress_when(Always); + + let req = Request::builder() + .header("accept-encoding", "gzip") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + assert_eq!(trailers["grpc-status"], "0"); + } + #[tokio::test] async fn size_hint_identity() { let msg = "Hello, world!"; diff --git a/tower-http/src/compression_utils.rs b/tower-http/src/compression_utils.rs index 1fbccb85..9c5cdc20 100644 --- a/tower-http/src/compression_utils.rs +++ b/tower-http/src/compression_utils.rs @@ -238,19 +238,9 @@ where // poll any remaining frames, such as trailers let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut(); match ready!(body.poll_frame(cx)) { - Some(Ok(frame)) if frame.is_trailers() => Poll::Ready(Some(Ok( + Some(Ok(frame)) => Poll::Ready(Some(Ok( frame.map_data(|mut data| data.copy_to_bytes(data.remaining())) ))), - Some(Ok(frame)) => { - if let Ok(bytes) = frame.into_data() { - if bytes.has_remaining() { - return Poll::Ready(Some(Err( - "there are extra bytes after body has been decompressed".into(), - ))); - } - } - Poll::Ready(None) - } Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), } diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 543fd6a2..c91515b5 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -197,7 +197,7 @@ clippy::match_like_matches_macro, clippy::type_complexity )] -#![forbid(unsafe_code)] +#![deny(unsafe_code)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] diff --git a/tower-http/src/on_early_drop/body.rs b/tower-http/src/on_early_drop/body.rs index 0d76268e..3dd476b3 100644 --- a/tower-http/src/on_early_drop/body.rs +++ b/tower-http/src/on_early_drop/body.rs @@ -58,8 +58,8 @@ where self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { - let this = self.project(); - let result = ready!(this.inner.poll_frame(cx)); + let mut this = self.project(); + let result = ready!(this.inner.as_mut().poll_frame(cx)); // End-of-stream (Ready(None)) or body-level error (Ready(Some(Err))) // both mean the body will not yield more frames. Suppress the guard // in either case; service-level errors are out of scope for this @@ -67,6 +67,12 @@ where if matches!(result, None | Some(Err(_))) { this.guard.completed(); } + // If the inner body signals end-of-stream after this frame, mark + // completed now since the consumer may not poll again (e.g. when + // Content-Length is exact). + if matches!(result, Some(Ok(_))) && this.inner.is_end_stream() { + this.guard.completed(); + } Poll::Ready(result) } diff --git a/tower-http/src/on_early_drop/service.rs b/tower-http/src/on_early_drop/service.rs index 229a989c..0a4874b3 100644 --- a/tower-http/src/on_early_drop/service.rs +++ b/tower-http/src/on_early_drop/service.rs @@ -369,6 +369,41 @@ mod tests { assert_clone(&service); } + #[tokio::test] + async fn body_drop_suppressed_when_is_end_stream_after_data() { + // When Content-Length is exact, the consumer stops polling after + // receiving all data bytes. The guard must still be suppressed via + // is_end_stream() so the drop callback does not fire. + let fired = Arc::new(AtomicBool::new(false)); + let fired_clone = fired.clone(); + + let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new( + move |_req: &Request<()>| { + let fired = fired_clone.clone(); + move |_parts: &http::response::Parts| { + let fired = fired.clone(); + move || { + fired.store(true, Ordering::Relaxed); + } + } + }, + )); + let service = layer.layer(ok_service()); + let response = service.oneshot(request()).await.unwrap(); + let mut body = response.into_body(); + + // Poll only the data frame, then drop (simulating content-length consumer) + use http_body::Body as _; + let frame = std::future::poll_fn(|cx| std::pin::Pin::new(&mut body).poll_frame(cx)).await; + assert!(frame.unwrap().unwrap().data_ref().is_some()); + drop(body); + + assert!( + !fired.load(Ordering::Relaxed), + "body drop must not fire when is_end_stream() is true after last data frame", + ); + } + #[tokio::test] async fn body_drop_suppressed_when_is_end_stream_at_construction() { let fired = Arc::new(AtomicBool::new(false)); diff --git a/tower-http/src/services/fs/serve_dir/open_file.rs b/tower-http/src/services/fs/serve_dir/open_file.rs index ff1d7e46..54d778de 100644 --- a/tower-http/src/services/fs/serve_dir/open_file.rs +++ b/tower-http/src/services/fs/serve_dir/open_file.rs @@ -239,7 +239,9 @@ async fn open_file_with_fallback( let encoding = preferred_encoding(&mut path, &negotiated_encoding); match (File::open(&path).await, encoding) { (Ok(file), maybe_encoding) => break (file, maybe_encoding), - (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => { + (Err(err), Some(encoding)) + if err.kind() == io::ErrorKind::NotFound && encoding != Encoding::Identity => + { // Remove the extension corresponding to a precompressed file (.gz, .br, .zz) // to reset the path before the next iteration. path.set_extension(OsStr::new("")); @@ -265,7 +267,9 @@ async fn file_metadata_with_fallback( let encoding = preferred_encoding(&mut path, &negotiated_encoding); match (tokio::fs::metadata(&path).await, encoding) { (Ok(file), maybe_encoding) => break (file, maybe_encoding), - (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => { + (Err(err), Some(encoding)) + if err.kind() == io::ErrorKind::NotFound && encoding != Encoding::Identity => + { // Remove the extension corresponding to a precompressed file (.gz, .br, .zz) // to reset the path before the next iteration. path.set_extension(OsStr::new("")); diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index 3023bdf8..72c567a6 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -13,9 +13,19 @@ use std::fs; use std::io::Read; use tower::{service_fn, ServiceExt}; +/// Expected prefix of the decompressed content in precompressed test files. +const EXPECTED_CONTENT_PREFIX: &str = "Test file"; + +/// Root of the repository, relative to the working directory of the test binary. +const REPO_ROOT: &str = ".."; +/// Directory containing test fixture files. +const TEST_FILES_DIR: &str = "../test-files"; +/// Path to the repository README, used as a large test fixture. +const README_PATH: &str = "../README.md"; + #[tokio::test] async fn basic() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") @@ -28,13 +38,13 @@ async fn basic() { let body = body_into_text(res.into_body()).await; - let contents = std::fs::read_to_string("../README.md").unwrap(); + let contents = std::fs::read_to_string(README_PATH).unwrap(); assert_eq!(body, contents); } #[tokio::test] async fn basic_with_index() { - let svc = ServeDir::new("../test-files"); + let svc = ServeDir::new(TEST_FILES_DIR); let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); @@ -48,7 +58,7 @@ async fn basic_with_index() { #[tokio::test] async fn head_request() { - let svc = ServeDir::new("../test-files"); + let svc = ServeDir::new(TEST_FILES_DIR); let req = Request::builder() .uri("/precompressed.txt") @@ -59,14 +69,14 @@ async fn head_request() { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); - assert_eq!(res.headers()["content-length"], "23"); + assert_eq!(res.headers()["content-length"], "10"); assert!(res.into_body().frame().await.is_none()); } #[tokio::test] async fn precompresed_head_request() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let req = Request::builder() .uri("/precompressed.txt") @@ -78,14 +88,14 @@ async fn precompresed_head_request() { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); - assert_eq!(res.headers()["content-length"], "59"); + assert_eq!(res.headers()["content-length"], "30"); assert!(res.into_body().frame().await.is_none()); } #[tokio::test] async fn with_custom_chunk_size() { - let svc = ServeDir::new("..").with_buf_chunk_size(1024 * 32); + let svc = ServeDir::new(REPO_ROOT).with_buf_chunk_size(1024 * 32); let req = Request::builder() .uri("/README.md") @@ -98,13 +108,13 @@ async fn with_custom_chunk_size() { let body = body_into_text(res.into_body()).await; - let contents = std::fs::read_to_string("../README.md").unwrap(); + let contents = std::fs::read_to_string(README_PATH).unwrap(); assert_eq!(body, contents); } #[tokio::test] async fn precompressed_gzip() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let req = Request::builder() .uri("/precompressed.txt") @@ -120,12 +130,12 @@ async fn precompressed_gzip() { let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn precompressed_br() { - let svc = ServeDir::new("../test-files").precompressed_br(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_br(); let req = Request::builder() .uri("/precompressed.txt") @@ -141,12 +151,12 @@ async fn precompressed_br() { let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn precompressed_deflate() { - let svc = ServeDir::new("../test-files").precompressed_deflate(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_deflate(); let request = Request::builder() .uri("/precompressed.txt") .header("Accept-Encoding", "deflate,br") @@ -161,12 +171,12 @@ async fn precompressed_deflate() { let mut decoder = DeflateDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let request = Request::builder() .uri("/precompressed.txt") @@ -180,12 +190,12 @@ async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); - assert!(body.starts_with("\"This is a test file!\"")); + assert!(body.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn only_precompressed_variant_existing() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let request = Request::builder() .uri("/only_gzipped.txt") @@ -210,12 +220,12 @@ async fn only_precompressed_variant_existing() { let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - assert!(decompressed.starts_with("\"This is a test file\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let request = Request::builder() .uri("/missing_precompressed.txt") @@ -230,12 +240,12 @@ async fn missing_precompressed_variant_fallbacks_to_uncompressed() { let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); - assert!(body.starts_with("Test file!")); + assert!(body.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed_for_head_request() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let request = Request::builder() .uri("/missing_precompressed.txt") @@ -246,7 +256,7 @@ async fn missing_precompressed_variant_fallbacks_to_uncompressed_for_head_reques let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); - assert_eq!(res.headers()["content-length"], "11"); + assert_eq!(res.headers()["content-length"], "10"); // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); @@ -255,7 +265,7 @@ async fn missing_precompressed_variant_fallbacks_to_uncompressed_for_head_reques #[tokio::test] async fn precompressed_without_extension() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let request = Request::builder() .uri("/extensionless_precompressed") @@ -274,13 +284,14 @@ async fn precompressed_without_extension() { let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - let correct = fs::read_to_string("../test-files/extensionless_precompressed").unwrap(); + let correct = + fs::read_to_string(format!("{TEST_FILES_DIR}/extensionless_precompressed")).unwrap(); assert_eq!(decompressed, correct); } #[tokio::test] async fn missing_precompressed_without_extension_fallbacks_to_uncompressed() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let request = Request::builder() .uri("/extensionless_precompressed_missing") @@ -297,13 +308,16 @@ async fn missing_precompressed_without_extension_fallbacks_to_uncompressed() { let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); - let correct = fs::read_to_string("../test-files/extensionless_precompressed_missing").unwrap(); + let correct = fs::read_to_string(format!( + "{TEST_FILES_DIR}/extensionless_precompressed_missing" + )) + .unwrap(); assert_eq!(body, correct); } #[tokio::test] async fn access_to_sub_dirs() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/tower-http/Cargo.toml") @@ -322,7 +336,7 @@ async fn access_to_sub_dirs() { #[tokio::test] async fn not_found() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/not-found") @@ -340,7 +354,7 @@ async fn not_found() { #[cfg(unix)] #[tokio::test] async fn not_found_when_not_a_directory() { - let svc = ServeDir::new("../test-files"); + let svc = ServeDir::new(TEST_FILES_DIR); // `index.html` is a file, and we are trying to request // it as a directory. @@ -360,7 +374,7 @@ async fn not_found_when_not_a_directory() { #[tokio::test] async fn not_found_precompressed() { - let svc = ServeDir::new("../test-files").precompressed_gzip(); + let svc = ServeDir::new(TEST_FILES_DIR).precompressed_gzip(); let req = Request::builder() .uri("/not-found") @@ -378,7 +392,7 @@ async fn not_found_precompressed() { #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found_for_head_request() { - let svc = ServeDir::new("../test-files") + let svc = ServeDir::new(TEST_FILES_DIR) .precompressed_gzip() .precompressed_br(); @@ -399,7 +413,7 @@ async fn fallbacks_to_different_precompressed_variant_if_not_found_for_head_requ #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found() { - let svc = ServeDir::new("../test-files") + let svc = ServeDir::new(TEST_FILES_DIR) .precompressed_gzip() .precompressed_br(); @@ -417,7 +431,7 @@ async fn fallbacks_to_different_precompressed_variant_if_not_found() { let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); - assert!(decompressed.starts_with("Test file")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] @@ -449,7 +463,7 @@ async fn empty_directory_without_index() { #[tokio::test] async fn empty_directory_without_index_no_information_leak() { - let svc = ServeDir::new("..").append_index_html_on_directories(false); + let svc = ServeDir::new(REPO_ROOT).append_index_html_on_directories(false); let req = Request::builder() .uri("/test-files") @@ -478,7 +492,7 @@ async fn access_cjk_percent_encoded_uri_path() { // percent encoding present of 你好世界.txt let cjk_filename_encoded = "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C.txt"; - let svc = ServeDir::new("../test-files"); + let svc = ServeDir::new(TEST_FILES_DIR); let req = Request::builder() .uri(format!("/{}", cjk_filename_encoded)) @@ -494,7 +508,7 @@ async fn access_cjk_percent_encoded_uri_path() { async fn access_space_percent_encoded_uri_path() { let encoded_filename = "filename%20with%20space.txt"; - let svc = ServeDir::new("../test-files"); + let svc = ServeDir::new(TEST_FILES_DIR); let req = Request::builder() .uri(format!("/{}", encoded_filename)) @@ -508,7 +522,7 @@ async fn access_space_percent_encoded_uri_path() { #[tokio::test] async fn read_partial_empty() { - let svc = ServeDir::new("../test-files"); + let svc = ServeDir::new(TEST_FILES_DIR); let req = Request::builder() .uri("/empty.txt") @@ -527,7 +541,7 @@ async fn read_partial_empty() { #[tokio::test] async fn read_partial_in_bounds() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let bytes_start_incl = 9; let bytes_end_incl = 1023; @@ -541,7 +555,7 @@ async fn read_partial_in_bounds() { .unwrap(); let res = svc.oneshot(req).await.unwrap(); - let file_contents = std::fs::read("../README.md").unwrap(); + let file_contents = std::fs::read(README_PATH).unwrap(); assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); assert_eq!( res.headers()["content-length"], @@ -565,7 +579,7 @@ async fn read_partial_in_bounds() { #[tokio::test] async fn read_partial_accepts_out_of_bounds_range() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let bytes_start_incl = 0; let bytes_end_excl = 9999999; let requested_len = bytes_end_excl - bytes_start_incl; @@ -581,7 +595,7 @@ async fn read_partial_accepts_out_of_bounds_range() { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); - let file_contents = std::fs::read("../README.md").unwrap(); + let file_contents = std::fs::read(README_PATH).unwrap(); // Out of bounds range gives all bytes assert_eq!( res.headers()["content-range"], @@ -595,7 +609,7 @@ async fn read_partial_accepts_out_of_bounds_range() { #[tokio::test] async fn read_partial_errs_on_garbage_header() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .header("Range", "bad_format") @@ -603,7 +617,7 @@ async fn read_partial_errs_on_garbage_header() { .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); - let file_contents = std::fs::read("../README.md").unwrap(); + let file_contents = std::fs::read(README_PATH).unwrap(); assert_eq!( res.headers()["content-range"], &format!("bytes */{}", file_contents.len()) @@ -612,7 +626,7 @@ async fn read_partial_errs_on_garbage_header() { #[tokio::test] async fn read_partial_errs_on_bad_range() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .header("Range", "bytes=-1-15") @@ -620,7 +634,7 @@ async fn read_partial_errs_on_bad_range() { .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); - let file_contents = std::fs::read("../README.md").unwrap(); + let file_contents = std::fs::read(README_PATH).unwrap(); assert_eq!( res.headers()["content-range"], &format!("bytes */{}", file_contents.len()) @@ -629,7 +643,7 @@ async fn read_partial_errs_on_bad_range() { #[tokio::test] async fn accept_encoding_identity() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .header("Accept-Encoding", "identity") @@ -643,7 +657,7 @@ async fn accept_encoding_identity() { #[tokio::test] async fn last_modified() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .body(Body::empty()) @@ -658,7 +672,7 @@ async fn last_modified() { // -- If-Modified-Since - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .header(header::IF_MODIFIED_SINCE, last_modified) @@ -669,7 +683,7 @@ async fn last_modified() { assert_eq!(res.status(), StatusCode::NOT_MODIFIED); assert!(res.into_body().frame().await.is_none()); - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") @@ -684,7 +698,7 @@ async fn last_modified() { // -- If-Unmodified-Since - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .header(header::IF_UNMODIFIED_SINCE, last_modified) @@ -696,7 +710,7 @@ async fn last_modified() { let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), readme_bytes); - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .uri("/README.md") .header(header::IF_UNMODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") @@ -717,7 +731,7 @@ async fn with_fallback_svc() { )))) } - let svc = ServeDir::new("..").fallback(tower::service_fn(fallback)); + let svc = ServeDir::new(REPO_ROOT).fallback(tower::service_fn(fallback)); let req = Request::builder() .uri("/doesnt-exist") @@ -733,7 +747,7 @@ async fn with_fallback_svc() { #[tokio::test] async fn with_fallback_serve_file() { - let svc = ServeDir::new("..").fallback(ServeFile::new("../README.md")); + let svc = ServeDir::new(REPO_ROOT).fallback(ServeFile::new(README_PATH)); let req = Request::builder() .uri("/doesnt-exist") @@ -746,13 +760,13 @@ async fn with_fallback_serve_file() { let body = body_into_text(res.into_body()).await; - let contents = std::fs::read_to_string("../README.md").unwrap(); + let contents = std::fs::read_to_string(README_PATH).unwrap(); assert_eq!(body, contents); } #[tokio::test] async fn method_not_allowed() { - let svc = ServeDir::new(".."); + let svc = ServeDir::new(REPO_ROOT); let req = Request::builder() .method(Method::POST) @@ -774,7 +788,7 @@ async fn calling_fallback_on_not_allowed() { )))) } - let svc = ServeDir::new("..") + let svc = ServeDir::new(REPO_ROOT) .call_fallback_on_method_not_allowed(true) .fallback(tower::service_fn(fallback)); @@ -793,7 +807,7 @@ async fn calling_fallback_on_not_allowed() { #[tokio::test] async fn method_not_allowed_without_fallback() { - let svc = ServeDir::new("..").call_fallback_on_method_not_allowed(true); + let svc = ServeDir::new(REPO_ROOT).call_fallback_on_method_not_allowed(true); let req = Request::builder() .method(Method::POST) @@ -815,7 +829,7 @@ async fn with_fallback_svc_and_not_append_index_html_on_directories() { )))) } - let svc = ServeDir::new("..") + let svc = ServeDir::new(REPO_ROOT) .append_index_html_on_directories(false) .fallback(tower::service_fn(fallback)); @@ -838,7 +852,7 @@ async fn calls_fallback_on_invalid_paths() { Ok(res) } - let svc = ServeDir::new("..").fallback(service_fn(fallback)); + let svc = ServeDir::new(REPO_ROOT).fallback(service_fn(fallback)); let req = Request::builder() .uri("/weird_%c3%28_path") @@ -860,7 +874,7 @@ async fn calls_fallback_on_invalid_filenames() { Ok(res) } - let svc = ServeDir::new("..").fallback(service_fn(fallback)); + let svc = ServeDir::new(REPO_ROOT).fallback(service_fn(fallback)); let req = Request::builder() .uri("/invalid|path") @@ -881,7 +895,7 @@ async fn calls_fallback_on_null() { Ok(res) } - let svc = ServeDir::new("..").fallback(service_fn(fallback)); + let svc = ServeDir::new(REPO_ROOT).fallback(service_fn(fallback)); let req = Request::builder() .uri("/invalid-path%00") @@ -895,7 +909,7 @@ async fn calls_fallback_on_null() { #[tokio::test] async fn not_found_when_file_requested_with_trailing_slash() { - let svc = ServeDir::new("../test-files"); + let svc = ServeDir::new(TEST_FILES_DIR); let req = Request::builder() .uri("/index.html/") @@ -919,7 +933,7 @@ async fn file_requested_with_trailing_slash_with_fallback() { )))) } - let svc = ServeDir::new("../test-files").fallback(tower::service_fn(fallback)); + let svc = ServeDir::new(TEST_FILES_DIR).fallback(tower::service_fn(fallback)); let req = Request::builder() .uri("/index.html/") @@ -935,7 +949,7 @@ async fn file_requested_with_trailing_slash_with_fallback() { #[tokio::test] async fn directory_with_trailing_slash_appends_index_html() { - let svc = ServeDir::new("../test-files").append_index_html_on_directories(true); + let svc = ServeDir::new(TEST_FILES_DIR).append_index_html_on_directories(true); let req = Request::builder().uri("/foo/").body(Body::empty()).unwrap(); let res = svc.oneshot(req).await.unwrap(); @@ -948,7 +962,7 @@ async fn directory_with_trailing_slash_appends_index_html() { #[tokio::test] async fn root_with_trailing_slash_serves_appends_index_html() { - let svc = ServeDir::new("../test-files").append_index_html_on_directories(true); + let svc = ServeDir::new(TEST_FILES_DIR).append_index_html_on_directories(true); let req = Request::builder().uri("/").body(Body::empty()).unwrap(); let res = svc.oneshot(req).await.unwrap(); @@ -960,6 +974,7 @@ async fn root_with_trailing_slash_serves_appends_index_html() { } #[cfg(windows)] +#[allow(unsafe_code)] fn verify_windows_device(name: &str, is_positive: bool) { use std::fs::OpenOptions; use std::os::windows::io::AsRawHandle; @@ -1100,3 +1115,34 @@ fn test_build_and_validate_path_reserved_dos_names() { } } } + +// Regression test for https://github.com/tower-rs/tower-http/issues/664 +// Accept-Encoding: identity should not cause extension stripping +#[tokio::test] +async fn identity_encoding_does_not_strip_extension() { + let svc = ServeDir::new("../test-files"); + + let req = Request::builder() + .uri("/extensionless_precompressed.foobar") + .header("Accept-Encoding", "identity") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn identity_encoding_does_not_strip_extension_head_request() { + let svc = ServeDir::new("../test-files"); + + let req = Request::builder() + .uri("/extensionless_precompressed.foobar") + .method(Method::HEAD) + .header("Accept-Encoding", "identity") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} diff --git a/tower-http/src/services/fs/serve_file.rs b/tower-http/src/services/fs/serve_file.rs index ade3cd15..d3b5e2f0 100644 --- a/tower-http/src/services/fs/serve_file.rs +++ b/tower-http/src/services/fs/serve_file.rs @@ -157,9 +157,17 @@ mod tests { use tokio::io::AsyncReadExt; use tower::ServiceExt; + /// Expected prefix of the decompressed content in precompressed test files. + const EXPECTED_CONTENT_PREFIX: &str = "Test file"; + + /// Directory containing test fixture files. + const TEST_FILES_DIR: &str = "../test-files"; + /// Path to the repository README, used as a large test fixture. + const README_PATH: &str = "../README.md"; + #[tokio::test] async fn basic() { - let svc = ServeFile::new("../README.md"); + let svc = ServeFile::new(README_PATH); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); @@ -173,7 +181,7 @@ mod tests { #[tokio::test] async fn basic_with_mime() { - let svc = ServeFile::new_with_mime("../README.md", &Mime::from_str("image/jpg").unwrap()); + let svc = ServeFile::new_with_mime(README_PATH, &Mime::from_str("image/jpg").unwrap()); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); @@ -187,21 +195,22 @@ mod tests { #[tokio::test] async fn head_request() { - let svc = ServeFile::new("../test-files/precompressed.txt"); + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")); let mut request = Request::new(Body::empty()); *request.method_mut() = Method::HEAD; let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); - assert_eq!(res.headers()["content-length"], "23"); + assert_eq!(res.headers()["content-length"], "10"); assert!(res.into_body().frame().await.is_none()); } #[tokio::test] async fn precompresed_head_request() { - let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); + let svc = + ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")).precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "gzip") @@ -212,14 +221,15 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); - assert_eq!(res.headers()["content-length"], "59"); + assert_eq!(res.headers()["content-length"], "30"); assert!(res.into_body().frame().await.is_none()); } #[tokio::test] async fn precompressed_gzip() { - let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); + let svc = + ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")).precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "gzip") @@ -234,12 +244,13 @@ mod tests { let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { - let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); + let svc = + ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")).precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "br") @@ -252,12 +263,13 @@ mod tests { let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); - assert!(body.starts_with("\"This is a test file!\"")); + assert!(body.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed() { - let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip(); + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/missing_precompressed.txt")) + .precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "gzip") @@ -271,12 +283,13 @@ mod tests { let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); - assert!(body.starts_with("Test file!")); + assert!(body.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed_head_request() { - let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip(); + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/missing_precompressed.txt")) + .precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "gzip") @@ -286,7 +299,7 @@ mod tests { let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); - assert_eq!(res.headers()["content-length"], "11"); + assert_eq!(res.headers()["content-length"], "10"); // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); @@ -295,7 +308,7 @@ mod tests { #[tokio::test] async fn only_precompressed_variant_existing() { - let svc = ServeFile::new("../test-files/only_gzipped.txt").precompressed_gzip(); + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/only_gzipped.txt")).precompressed_gzip(); let request = Request::builder().body(Body::empty()).unwrap(); let res = svc.clone().oneshot(request).await.unwrap(); @@ -316,12 +329,12 @@ mod tests { let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - assert!(decompressed.starts_with("\"This is a test file\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn precompressed_br() { - let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_br(); + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")).precompressed_br(); let request = Request::builder() .header("Accept-Encoding", "gzip,br") @@ -336,12 +349,13 @@ mod tests { let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn precompressed_deflate() { - let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_deflate(); + let svc = + ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")).precompressed_deflate(); let request = Request::builder() .header("Accept-Encoding", "deflate,br") .body(Body::empty()) @@ -355,12 +369,13 @@ mod tests { let mut decoder = DeflateDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn precompressed_zstd() { - let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_zstd(); + let svc = + ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")).precompressed_zstd(); let request = Request::builder() .header("Accept-Encoding", "zstd,br") .body(Body::empty()) @@ -374,12 +389,12 @@ mod tests { let mut decoder = ZstdDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).await.unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn multi_precompressed() { - let svc = ServeFile::new("../test-files/precompressed.txt") + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/precompressed.txt")) .precompressed_gzip() .precompressed_br(); @@ -396,7 +411,7 @@ mod tests { let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); let request = Request::builder() .header("Accept-Encoding", "br") @@ -411,12 +426,12 @@ mod tests { let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); - assert!(decompressed.starts_with("\"This is a test file!\"")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn with_custom_chunk_size() { - let svc = ServeFile::new("../README.md").with_buf_chunk_size(1024 * 32); + let svc = ServeFile::new(README_PATH).with_buf_chunk_size(1024 * 32); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); @@ -430,7 +445,7 @@ mod tests { #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found() { - let svc = ServeFile::new("../test-files/precompressed_br.txt") + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/precompressed_br.txt")) .precompressed_gzip() .precompressed_deflate() .precompressed_br(); @@ -448,12 +463,12 @@ mod tests { let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); - assert!(decompressed.starts_with("Test file")); + assert!(decompressed.starts_with(EXPECTED_CONTENT_PREFIX)); } #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found_head_request() { - let svc = ServeFile::new("../test-files/precompressed_br.txt") + let svc = ServeFile::new(format!("{TEST_FILES_DIR}/precompressed_br.txt")) .precompressed_gzip() .precompressed_deflate() .precompressed_br(); @@ -498,7 +513,7 @@ mod tests { #[tokio::test] async fn last_modified() { - let svc = ServeFile::new("../README.md"); + let svc = ServeFile::new(README_PATH); let req = Request::builder().body(Body::empty()).unwrap(); let res = svc.oneshot(req).await.unwrap(); @@ -512,7 +527,7 @@ mod tests { // -- If-Modified-Since - let svc = ServeFile::new("../README.md"); + let svc = ServeFile::new(README_PATH); let req = Request::builder() .header(header::IF_MODIFIED_SINCE, last_modified) .body(Body::empty()) @@ -522,7 +537,7 @@ mod tests { assert_eq!(res.status(), StatusCode::NOT_MODIFIED); assert!(res.into_body().frame().await.is_none()); - let svc = ServeFile::new("../README.md"); + let svc = ServeFile::new(README_PATH); let req = Request::builder() .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") .body(Body::empty()) @@ -536,7 +551,7 @@ mod tests { // -- If-Unmodified-Since - let svc = ServeFile::new("../README.md"); + let svc = ServeFile::new(README_PATH); let req = Request::builder() .header(header::IF_UNMODIFIED_SINCE, last_modified) .body(Body::empty()) @@ -547,7 +562,7 @@ mod tests { let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), readme_bytes); - let svc = ServeFile::new("../README.md"); + let svc = ServeFile::new(README_PATH); let req = Request::builder() .header(header::IF_UNMODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") .body(Body::empty()) diff --git a/tower-http/src/trace/body.rs b/tower-http/src/trace/body.rs index df82561c..861e5c46 100644 --- a/tower-http/src/trace/body.rs +++ b/tower-http/src/trace/body.rs @@ -43,9 +43,9 @@ where self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { - let this = self.project(); + let mut this = self.project(); let _guard = this.span.enter(); - let result = ready!(this.inner.poll_frame(cx)); + let result = ready!(this.inner.as_mut().poll_frame(cx)); let latency = this.start.elapsed(); *this.start = Instant::now(); @@ -77,6 +77,22 @@ where Err(frame) => frame, }; + // If the inner body signals end-of-stream after this frame, + // fire on_eos now since the consumer may not poll again (e.g. + // when Content-Length is exact). + if this.inner.is_end_stream() { + if let Some((classify_eos, mut on_failure)) = + this.classify_eos.take().zip(this.on_failure.take()) + { + if let Err(failure_class) = classify_eos.classify_eos(None) { + on_failure.on_failure(failure_class, latency, this.span); + } + } + if let Some((on_eos, stream_start)) = this.on_eos.take() { + on_eos.on_eos(None, stream_start.elapsed(), this.span); + } + } + Poll::Ready(Some(Ok(frame))) } Some(Err(err)) => { diff --git a/tower-http/src/trace/mod.rs b/tower-http/src/trace/mod.rs index 286d7cd3..2c9aca70 100644 --- a/tower-http/src/trace/mod.rs +++ b/tower-http/src/trace/mod.rs @@ -412,47 +412,79 @@ pub type HttpMakeClassifier = SharedClassifier; pub type GrpcMakeClassifier = SharedClassifier; macro_rules! event_dynamic_lvl { - ( $(target: $target:expr,)? $(parent: $parent:expr,)? $lvl:expr, $($tt:tt)* ) => { + ( target: $target:expr, parent: $parent:expr, $lvl:expr, $($tt:tt)* ) => { match $lvl { tracing::Level::ERROR => { - tracing::event!( - $(target: $target,)? - $(parent: $parent,)? - tracing::Level::ERROR, - $($tt)* - ); + tracing::event!(target: $target, parent: $parent, tracing::Level::ERROR, $($tt)*); } tracing::Level::WARN => { - tracing::event!( - $(target: $target,)? - $(parent: $parent,)? - tracing::Level::WARN, - $($tt)* - ); + tracing::event!(target: $target, parent: $parent, tracing::Level::WARN, $($tt)*); } tracing::Level::INFO => { - tracing::event!( - $(target: $target,)? - $(parent: $parent,)? - tracing::Level::INFO, - $($tt)* - ); + tracing::event!(target: $target, parent: $parent, tracing::Level::INFO, $($tt)*); } tracing::Level::DEBUG => { - tracing::event!( - $(target: $target,)? - $(parent: $parent,)? - tracing::Level::DEBUG, - $($tt)* - ); + tracing::event!(target: $target, parent: $parent, tracing::Level::DEBUG, $($tt)*); } tracing::Level::TRACE => { - tracing::event!( - $(target: $target,)? - $(parent: $parent,)? - tracing::Level::TRACE, - $($tt)* - ); + tracing::event!(target: $target, parent: $parent, tracing::Level::TRACE, $($tt)*); + } + } + }; + ( target: $target:expr, $lvl:expr, $($tt:tt)* ) => { + match $lvl { + tracing::Level::ERROR => { + tracing::event!(target: $target, tracing::Level::ERROR, $($tt)*); + } + tracing::Level::WARN => { + tracing::event!(target: $target, tracing::Level::WARN, $($tt)*); + } + tracing::Level::INFO => { + tracing::event!(target: $target, tracing::Level::INFO, $($tt)*); + } + tracing::Level::DEBUG => { + tracing::event!(target: $target, tracing::Level::DEBUG, $($tt)*); + } + tracing::Level::TRACE => { + tracing::event!(target: $target, tracing::Level::TRACE, $($tt)*); + } + } + }; + ( parent: $parent:expr, $lvl:expr, $($tt:tt)* ) => { + match $lvl { + tracing::Level::ERROR => { + tracing::event!(parent: $parent, tracing::Level::ERROR, $($tt)*); + } + tracing::Level::WARN => { + tracing::event!(parent: $parent, tracing::Level::WARN, $($tt)*); + } + tracing::Level::INFO => { + tracing::event!(parent: $parent, tracing::Level::INFO, $($tt)*); + } + tracing::Level::DEBUG => { + tracing::event!(parent: $parent, tracing::Level::DEBUG, $($tt)*); + } + tracing::Level::TRACE => { + tracing::event!(parent: $parent, tracing::Level::TRACE, $($tt)*); + } + } + }; + ( $lvl:expr, $($tt:tt)* ) => { + match $lvl { + tracing::Level::ERROR => { + tracing::event!(tracing::Level::ERROR, $($tt)*); + } + tracing::Level::WARN => { + tracing::event!(tracing::Level::WARN, $($tt)*); + } + tracing::Level::INFO => { + tracing::event!(tracing::Level::INFO, $($tt)*); + } + tracing::Level::DEBUG => { + tracing::event!(tracing::Level::DEBUG, $($tt)*); + } + tracing::Level::TRACE => { + tracing::event!(tracing::Level::TRACE, $($tt)*); } } }; @@ -717,6 +749,109 @@ mod tests { assert_eq!(1, ON_FAILURE.load(Ordering::SeqCst), "failure"); } + #[tokio::test] + async fn on_eos_fires_for_content_length_body() { + // Simulates the scenario where a consumer stops polling after receiving + // all bytes (as hyper does when Content-Length is exact). We poll only + // the data frame and never poll to None. + use http_body_util::BodyExt; + + static ON_BODY_CHUNK_COUNT: Lazy = Lazy::new(|| AtomicU32::new(0)); + static ON_EOS: Lazy = Lazy::new(|| AtomicU32::new(0)); + + let trace_layer = TraceLayer::new_for_http() + .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| { + ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst); + }) + .on_eos( + |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| { + ON_EOS.fetch_add(1, Ordering::SeqCst); + }, + ); + + let mut svc = ServiceBuilder::new().layer(trace_layer).service_fn(echo); + + let res = svc + .ready() + .await + .unwrap() + .call(Request::new(Body::from("hello"))) + .await + .unwrap(); + + let mut body = res.into_body(); + + // Poll only the data frame (simulating a content-length aware consumer) + let frame = body.frame().await.unwrap().unwrap(); + assert!(frame.data_ref().is_some()); + + // on_eos should have fired immediately after the data frame since + // is_end_stream() is true for Full bodies after yielding their data. + assert_eq!(1, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); + assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos"); + } + + #[tokio::test] + async fn on_eos_fires_for_streaming_body_on_none() { + // Streaming bodies (no content-length) don't report is_end_stream() + // until polled to None. Verify on_eos still fires via the None path. + static ON_EOS: Lazy = Lazy::new(|| AtomicU32::new(0)); + + let trace_layer = TraceLayer::new_for_http().on_eos( + |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| { + ON_EOS.fetch_add(1, Ordering::SeqCst); + }, + ); + + let mut svc = ServiceBuilder::new() + .layer(trace_layer) + .service_fn(streaming_body); + + let res = svc + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap(); + + crate::test_helpers::to_bytes(res.into_body()) + .await + .unwrap(); + assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos"); + } + + #[tokio::test] + async fn on_eos_not_called_twice() { + // When is_end_stream() fires on_eos after a data frame, a subsequent + // poll returning None must not fire on_eos again. + static ON_EOS: Lazy = Lazy::new(|| AtomicU32::new(0)); + + let trace_layer = TraceLayer::new_for_http().on_eos( + |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| { + ON_EOS.fetch_add(1, Ordering::SeqCst); + }, + ); + + let mut svc = ServiceBuilder::new().layer(trace_layer).service_fn(echo); + + let res = svc + .ready() + .await + .unwrap() + .call(Request::new(Body::from("hello"))) + .await + .unwrap(); + + // Consume the body fully (polls data frame then None) + crate::test_helpers::to_bytes(res.into_body()) + .await + .unwrap(); + + // on_eos should fire exactly once, not twice + assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos"); + } + async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } @@ -814,4 +949,131 @@ mod tests { "error" } } + + /// Regression test for https://github.com/tower-rs/tower-http/issues/655 + /// + /// Reproduces the reported bug: when a subscriber's filter disables the + /// request span but still enables events, the events appear without any + /// span context. This happens because `Span::enter()` on a disabled span + /// is a no-op, so events relying on ambient context have no parent. + /// + /// The fix (using explicit `parent: span`) ensures events always reference + /// the request span, even when it's disabled. A subscriber that records + /// disabled spans will still see the correct parent relationship. + #[test] + fn events_have_span_context_when_span_is_disabled() { + use std::sync::{Arc, Mutex}; + use tracing::subscriber::with_default; + use tracing_subscriber::{layer::SubscriberExt, registry::LookupSpan, Layer as _}; + + /// A filter that disables spans (by rejecting at the span level) + /// but allows all events through. This simulates the scenario where + /// a per-layer EnvFilter disables the request span's callsite. + struct DisableSpansFilter; + + impl tracing_subscriber::layer::Filter for DisableSpansFilter { + fn enabled( + &self, + meta: &tracing::Metadata<'_>, + _cx: &tracing_subscriber::layer::Context<'_, S>, + ) -> bool { + // Disable spans, keep events + !meta.is_span() + } + } + + /// Records (event_message, has_any_parent) pairs. + #[derive(Clone)] + struct RecordingLayer { + events: Arc>>, + } + + impl tracing_subscriber::Layer for RecordingLayer + where + S: tracing::Subscriber + for<'a> LookupSpan<'a>, + { + fn on_event( + &self, + event: &tracing::Event<'_>, + ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + let mut msg = String::new(); + event.record(&mut MessageVisitor(&mut msg)); + + // Check if the event has ANY parent: explicit or contextual + let has_parent = event.parent().is_some() || ctx.event_span(event).is_some(); + + self.events.lock().unwrap().push((msg, has_parent)); + } + } + + struct MessageVisitor<'a>(&'a mut String); + impl tracing::field::Visit for MessageVisitor<'_> { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { + if field.name() == "message" { + *self.0 = format!("{:?}", value); + } + } + } + + let events = Arc::new(Mutex::new(Vec::new())); + let layer = RecordingLayer { + events: events.clone(), + }; + let subscriber = tracing_subscriber::registry().with(layer.with_filter(DisableSpansFilter)); + + // Use with_default to guarantee cleanup even on panic, avoiding + // cross-test subscriber pollution. + with_default(subscriber, || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + let mut svc = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .service_fn(echo); + + let res = svc + .ready() + .await + .unwrap() + .call(Request::new(Body::from("test"))) + .await + .unwrap(); + + crate::test_helpers::to_bytes(res.into_body()) + .await + .unwrap(); + }); + }); + + let events = events.lock().unwrap(); + let request_events: Vec<_> = events + .iter() + .filter(|(msg, _)| { + msg.contains("started processing request") + || msg.contains("finished processing request") + }) + .collect(); + + assert!( + request_events.len() >= 2, + "expected on_request and on_response events to fire" + ); + + // The bug: without explicit parent, these events have no span context + // at all when the request span is disabled. With the fix, they still + // reference the span (even though it's disabled). + for (msg, has_parent) in &request_events { + assert!( + *has_parent, + "event {:?} has no span context. When the request span is \ + disabled by a filter, events must still reference it via \ + explicit parent so subscribers can associate them correctly.", + msg + ); + } + } } diff --git a/tower-http/src/trace/on_eos.rs b/tower-http/src/trace/on_eos.rs index 95788d7e..92abb7eb 100644 --- a/tower-http/src/trace/on_eos.rs +++ b/tower-http/src/trace/on_eos.rs @@ -84,7 +84,7 @@ impl DefaultOnEos { } impl OnEos for DefaultOnEos { - fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span) { + fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, span: &Span) { let stream_duration = Latency { unit: self.latency_unit, duration: stream_duration, @@ -100,6 +100,6 @@ impl OnEos for DefaultOnEos { } }); - event_dynamic_lvl!(self.level, %stream_duration, status, "end of stream"); + event_dynamic_lvl!(parent: span, self.level, %stream_duration, status, "end of stream"); } } diff --git a/tower-http/src/trace/on_failure.rs b/tower-http/src/trace/on_failure.rs index 7dfa186d..4ed67f5b 100644 --- a/tower-http/src/trace/on_failure.rs +++ b/tower-http/src/trace/on_failure.rs @@ -85,12 +85,13 @@ impl OnFailure for DefaultOnFailure where FailureClass: fmt::Display, { - fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, _: &Span) { + fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, span: &Span) { let latency = Latency { unit: self.latency_unit, duration: latency, }; event_dynamic_lvl!( + parent: span, self.level, classification = %failure_classification, %latency, diff --git a/tower-http/src/trace/on_request.rs b/tower-http/src/trace/on_request.rs index 07de1893..c0dcb56a 100644 --- a/tower-http/src/trace/on_request.rs +++ b/tower-http/src/trace/on_request.rs @@ -76,7 +76,7 @@ impl DefaultOnRequest { } impl OnRequest for DefaultOnRequest { - fn on_request(&mut self, _: &Request, _: &Span) { - event_dynamic_lvl!(self.level, "started processing request"); + fn on_request(&mut self, _: &Request, span: &Span) { + event_dynamic_lvl!(parent: span, self.level, "started processing request"); } } diff --git a/tower-http/src/trace/on_response.rs b/tower-http/src/trace/on_response.rs index edcf498b..a4d2249c 100644 --- a/tower-http/src/trace/on_response.rs +++ b/tower-http/src/trace/on_response.rs @@ -102,7 +102,7 @@ impl DefaultOnResponse { } impl OnResponse for DefaultOnResponse { - fn on_response(self, response: &Response, latency: Duration, _: &Span) { + fn on_response(self, response: &Response, latency: Duration, span: &Span) { let latency = Latency { unit: self.latency_unit, duration: latency, @@ -112,6 +112,7 @@ impl OnResponse for DefaultOnResponse { .then(|| tracing::field::debug(response.headers())); event_dynamic_lvl!( + parent: span, self.level, %latency, status = status(response),