diff --git a/lambda-runtime/Cargo.toml b/lambda-runtime/Cargo.toml index e829a72a..9b02931e 100644 --- a/lambda-runtime/Cargo.toml +++ b/lambda-runtime/Cargo.toml @@ -73,6 +73,8 @@ idna_adapter = "=1.2.0" lambda_runtime = { path = ".", features = ["tracing", "graceful-shutdown"] } pin-project-lite = { workspace = true } tracing-appender = "0.2" +tracing-capture = "0.1.0" +tracing-subscriber = { version = "0.3", features = ["registry"] } [package.metadata.docs.rs] all-features = true diff --git a/lambda-runtime/src/runtime.rs b/lambda-runtime/src/runtime.rs index 1175b023..e9a6bb27 100644 --- a/lambda-runtime/src/runtime.rs +++ b/lambda-runtime/src/runtime.rs @@ -908,4 +908,176 @@ mod endpoint_tests { server_handle.abort(); Ok(()) } + + #[tokio::test] + #[cfg(feature = "experimental-concurrency")] + async fn test_concurrent_structured_logging_isolation() -> Result<(), Error> { + use std::collections::HashSet; + use tracing::info; + use tracing_capture::{CaptureLayer, SharedStorage}; + use tracing_subscriber::layer::SubscriberExt; + + let storage = SharedStorage::default(); + let subscriber = tracing_subscriber::registry().with(CaptureLayer::new(&storage)); + tracing::subscriber::set_global_default(subscriber).unwrap(); + + let request_count = Arc::new(AtomicUsize::new(0)); + let done = Arc::new(tokio::sync::Notify::new()); + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let base: http::Uri = format!("http://{addr}").parse()?; + + let server_handle = { + let request_count = request_count.clone(); + let done = done.clone(); + tokio::spawn(async move { + loop { + let (tcp, _) = match listener.accept().await { + Ok(v) => v, + Err(_) => return, + }; + + let request_count = request_count.clone(); + let done = done.clone(); + let service = service_fn(move |req: Request| { + let request_count = request_count.clone(); + let done = done.clone(); + async move { + let (parts, body) = req.into_parts(); + if parts.method == Method::POST { + let _ = body.collect().await; + } + + if parts.method == Method::GET && parts.uri.path() == "/2018-06-01/runtime/invocation/next" + { + let count = request_count.fetch_add(1, Ordering::SeqCst); + if count < 300 { + let request_id = format!("test-request-{}", count + 1); + let res = Response::builder() + .status(StatusCode::OK) + .header("lambda-runtime-aws-request-id", &request_id) + .header("lambda-runtime-deadline-ms", "9999999999999") + .body(Full::new(Bytes::from_static(b"{}"))) + .unwrap(); + return Ok::<_, Infallible>(res); + } else { + done.notify_one(); + let res = Response::builder() + .status(StatusCode::NO_CONTENT) + .body(Full::new(Bytes::new())) + .unwrap(); + return Ok::<_, Infallible>(res); + } + } + + if parts.method == Method::POST && parts.uri.path().contains("/response") { + let res = Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::new())) + .unwrap(); + return Ok::<_, Infallible>(res); + } + + let res = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Full::new(Bytes::new())) + .unwrap(); + Ok::<_, Infallible>(res) + } + }); + + let io = TokioIo::new(tcp); + tokio::spawn(async move { + let _ = ServerBuilder::new(TokioExecutor::new()) + .serve_connection(io, service) + .await; + }); + } + }) + }; + + async fn test_handler(event: crate::LambdaEvent) -> Result<(), Error> { + let request_id = &event.context.request_id; + info!(observed_request_id = request_id); + tokio::time::sleep(Duration::from_millis(100)).await; + Ok(()) + } + + let handler = crate::service_fn(test_handler); + let client = Arc::new(Client::builder().with_endpoint(base).build()?); + + // Add tracing layer to capture span fields + use crate::layers::trace::TracingLayer; + use tower::ServiceBuilder; + let service = ServiceBuilder::new() + .layer(TracingLayer::new()) + .service(wrap_handler(handler, client.clone())); + + let runtime = Runtime { + client: client.clone(), + config: Arc::new(Config { + function_name: "test_fn".to_string(), + memory: 128, + version: "1".to_string(), + log_stream: "test_stream".to_string(), + log_group: "test_log".to_string(), + }), + service, + concurrency_limit: 3, + }; + + let runtime_handle = tokio::spawn(async move { runtime.run_concurrent().await }); + + done.notified().await; + // Give handlers time to complete after server signals done + tokio::time::sleep(Duration::from_millis(500)).await; + + runtime_handle.abort(); + server_handle.abort(); + + let storage = storage.lock(); + let events: Vec<_> = storage + .all_events() + .filter(|e| e.value("observed_request_id").is_some()) + .collect(); + + assert!( + events.len() >= 300, + "Should have at least 300 log entries, got {}", + events.len() + ); + + let mut seen_ids = HashSet::new(); + for event in &events { + let observed_id = event["observed_request_id"].as_str().unwrap(); + + // Find the parent "Lambda runtime invoke" span and get its requestId + let span_request_id = event + .ancestors() + .find(|s| s.metadata().name() == "Lambda runtime invoke") + .and_then(|s| s.value("requestId")) + .and_then(|v| v.as_str()) + .expect("Event should have a Lambda runtime invoke ancestor with requestId"); + + assert!( + observed_id.starts_with("test-request-"), + "Request ID should match pattern: {}", + observed_id + ); + assert!( + seen_ids.insert(observed_id.to_string()), + "Request ID should be unique: {}", + observed_id + ); + + // Verify span request ID matches logged request ID + assert_eq!( + observed_id, span_request_id, + "Span request ID should match logged request ID: span={}, logged={}", + span_request_id, observed_id + ); + } + + Ok(()) + } }