diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index a3b85da1..8fca86fb 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -144,6 +144,7 @@ impl StreamableHttpClient for reqwest::Client { } request = apply_custom_headers(request, custom_headers)?; + let session_was_attached = session_id.is_some(); if let Some(session_id) = session_id { request = request.header(HEADER_SESSION_ID, session_id.as_ref()); } @@ -186,6 +187,9 @@ impl StreamableHttpClient for reqwest::Client { ) { return Ok(StreamableHttpPostResponse::Accepted); } + if status == reqwest::StatusCode::NOT_FOUND && session_was_attached { + return Err(StreamableHttpError::SessionExpired); + } if !status.is_success() { let body = response .text() diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 85915c97..bbb98bf3 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -11,7 +11,10 @@ use tracing::debug; use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect}; use crate::{ RoleClient, - model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult}, + model::{ + ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage, + ServerResult, + }, transport::{ common::client_side_sse::SseAutoReconnectStream, worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport}, @@ -79,6 +82,8 @@ pub enum StreamableHttpError { InsufficientScope(InsufficientScopeError), #[error("Header name '{0}' is reserved and conflicts with default headers")] ReservedHeaderConflict(String), + #[error("Session expired (HTTP 404)")] + SessionExpired, } #[derive(Debug, Clone, Error)] @@ -307,6 +312,69 @@ impl StreamableHttpClientWorker { } Ok(()) } + + /// Performs a transparent re-initialization handshake after a session-expired 404. + /// + /// Takes an owned clone of the client (avoiding `&self` across `.await` so the + /// future remains `Send` without requiring `C: Sync`). POSTs the saved + /// initialize request without a session ID, extracts the new session ID and + /// protocol version, sends `notifications/initialized`, and returns the new + /// `(session_id, protocol_headers)` pair. The init result message is **not** + /// forwarded to the handler because the handler already processed the original + /// initialization. + async fn perform_reinitialization( + client: C, + saved_init_request: ClientJsonRpcMessage, + uri: Arc, + auth_header: Option, + custom_headers: HashMap, + ) -> Result<(Option>, HashMap), StreamableHttpError> + { + let (init_msg, new_session_id_str) = client + .post_message( + uri.clone(), + saved_init_request, + None, + auth_header.clone(), + custom_headers.clone(), + ) + .await? + .expect_initialized::() + .await?; + + let new_session_id: Option> = new_session_id_str.map(|s| Arc::from(s.as_str())); + + // Start from custom_headers, then inject the negotiated MCP-Protocol-Version + // so all subsequent requests carry the right version (MCP 2025-06-18 spec). + let mut new_protocol_headers = custom_headers; + if let ServerJsonRpcMessage::Response(response) = &init_msg { + if let ServerResult::InitializeResult(init_result) = &response.result { + if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) { + new_protocol_headers + .insert(HeaderName::from_static("mcp-protocol-version"), hv); + } + } + } + + let initialized_notification = ClientJsonRpcMessage::notification( + ClientNotification::InitializedNotification(InitializedNotification { + method: Default::default(), + extensions: Default::default(), + }), + ); + client + .post_message( + uri, + initialized_notification, + new_session_id.clone(), + auth_header, + new_protocol_headers.clone(), + ) + .await? + .expect_accepted_or_json::()?; + + Ok((new_session_id, new_protocol_headers)) + } } impl Worker for StreamableHttpClientWorker { @@ -338,14 +406,15 @@ impl Worker for StreamableHttpClientWorker { responder, message: initialize_request, } = context.recv_from_handler().await?; + let saved_init_request = initialize_request.clone(); let (message, session_id) = match self .client .post_message( config.uri.clone(), initialize_request, None, - self.config.auth_header, - self.config.custom_headers, + config.auth_header.clone(), + config.custom_headers.clone(), ) .await { @@ -364,7 +433,7 @@ impl Worker for StreamableHttpClientWorker { )); } }; - let session_id: Option> = if let Some(session_id) = session_id { + let mut session_id: Option> = if let Some(session_id) = session_id { Some(session_id.into()) } else { if !self.config.allow_stateless { @@ -378,7 +447,7 @@ impl Worker for StreamableHttpClientWorker { // Extract the negotiated protocol version from the init response // and build a custom headers map that includes MCP-Protocol-Version // for all subsequent HTTP requests (per MCP 2025-06-18 spec). - let protocol_headers = { + let mut protocol_headers = { let mut headers = config.custom_headers.clone(); if let ServerJsonRpcMessage::Response(response) = &message { if let ServerResult::InitializeResult(init_result) = &response.result { @@ -392,7 +461,7 @@ impl Worker for StreamableHttpClientWorker { }; // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns) - let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo { + let mut session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo { client: self.client.clone(), uri: config.uri.clone(), session_id: sid.clone(), @@ -516,17 +585,171 @@ impl Worker for StreamableHttpClientWorker { match event { Event::ClientMessage(send_request) => { let WorkerSendRequest { message, responder } = send_request; + // Pass a clone to the first attempt so `message` is retained for a + // potential re-init retry. `post_message` takes ownership and the + // trait cannot be changed, so the clone is unavoidable. let response = self .client .post_message( config.uri.clone(), - message, + message.clone(), session_id.clone(), config.auth_header.clone(), protocol_headers.clone(), ) .await; let send_result = match response { + Err(StreamableHttpError::SessionExpired) => { + // The server discarded the session (HTTP 404). Perform a + // fresh handshake once and replay the original message. + tracing::info!( + "session expired (HTTP 404), attempting transparent re-initialization" + ); + match Self::perform_reinitialization( + self.client.clone(), + saved_init_request.clone(), + config.uri.clone(), + config.auth_header.clone(), + config.custom_headers.clone(), + ) + .await + { + Ok((new_session_id, new_protocol_headers)) => { + // Old streams hold the stale session ID; abort them + // so the new standalone SSE stream takes over. + streams.abort_all(); + + session_id = new_session_id; + protocol_headers = new_protocol_headers; + session_cleanup_info = + session_id.as_ref().map(|sid| SessionCleanupInfo { + client: self.client.clone(), + uri: config.uri.clone(), + session_id: sid.clone(), + auth_header: config.auth_header.clone(), + protocol_headers: protocol_headers.clone(), + }); + + if let Some(new_sid) = &session_id { + let client = self.client.clone(); + let uri = config.uri.clone(); + let new_sid = new_sid.clone(); + let auth_header = config.auth_header.clone(); + let retry_config = self.config.retry_config.clone(); + let sse_tx = sse_worker_tx.clone(); + let task_ct = transport_task_ct.clone(); + let config_uri = config.uri.clone(); + let config_auth = config.auth_header.clone(); + let spawn_headers = protocol_headers.clone(); + streams.spawn(async move { + match client + .get_stream( + uri, + new_sid.clone(), + None, + auth_header.clone(), + spawn_headers.clone(), + ) + .await + { + Ok(stream) => { + let sse_stream = SseAutoReconnectStream::new( + stream, + StreamableHttpClientReconnect { + client: client.clone(), + session_id: new_sid, + uri: config_uri, + auth_header: config_auth, + custom_headers: spawn_headers, + }, + retry_config, + ); + Self::execute_sse_stream( + sse_stream, + sse_tx, + false, + task_ct.child_token(), + ) + .await + } + Err(StreamableHttpError::ServerDoesNotSupportSse) => { + tracing::debug!( + "server doesn't support sse after re-init" + ); + Ok(()) + } + Err(e) => { + tracing::error!( + "fail to get common stream after re-init: {e}" + ); + Err(e) + } + } + }); + } + + let retry_response = self + .client + .post_message( + config.uri.clone(), + message, + session_id.clone(), + config.auth_header.clone(), + protocol_headers.clone(), + ) + .await; + match retry_response { + Err(e) => Err(e), + Ok(StreamableHttpPostResponse::Accepted) => { + tracing::trace!( + "client message accepted after re-init" + ); + Ok(()) + } + Ok(StreamableHttpPostResponse::Json(msg, ..)) => { + context.send_to_handler(msg).await?; + Ok(()) + } + Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { + if let Some(sid) = &session_id { + let sse_stream = SseAutoReconnectStream::new( + stream, + StreamableHttpClientReconnect { + client: self.client.clone(), + session_id: sid.clone(), + uri: config.uri.clone(), + auth_header: config.auth_header.clone(), + custom_headers: protocol_headers.clone(), + }, + self.config.retry_config.clone(), + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); + } else { + let sse_stream = + SseAutoReconnectStream::never_reconnect( + stream, + StreamableHttpError::::UnexpectedEndOfStream, + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + )); + } + tracing::trace!("got new sse stream after re-init"); + Ok(()) + } + } + } + Err(reinit_err) => Err(reinit_err), + } + } Err(e) => Err(e), Ok(StreamableHttpPostResponse::Accepted) => { tracing::trace!("client message accepted"); diff --git a/crates/rmcp/tests/test_streamable_http_stale_session.rs b/crates/rmcp/tests/test_streamable_http_stale_session.rs index a37a0895..11f1a4da 100644 --- a/crates/rmcp/tests/test_streamable_http_stale_session.rs +++ b/crates/rmcp/tests/test_streamable_http_stale_session.rs @@ -7,9 +7,13 @@ use std::{collections::HashMap, sync::Arc}; use rmcp::{ + ServiceExt, model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, transport::{ - streamable_http_client::{StreamableHttpClient, StreamableHttpError}, + StreamableHttpClientTransport, + streamable_http_client::{ + StreamableHttpClient, StreamableHttpClientTransportConfig, StreamableHttpError, + }, streamable_http_server::{ StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, }, @@ -76,18 +80,10 @@ async fn test_stale_session_id_returns_status_aware_error() -> anyhow::Result<() assert_eq!(raw_response.status(), reqwest::StatusCode::NOT_FOUND); match result { - Err(StreamableHttpError::UnexpectedServerResponse(message)) => { - let message = message.to_string(); - assert!( - message.contains("404"), - "error should include HTTP status code, got: {message}" - ); - assert!( - message.to_ascii_lowercase().contains("session not found"), - "error should include session-not-found hint, got: {message}" - ); + Err(StreamableHttpError::SessionExpired) => { + // Expected: post_message detects 404 with a session ID and returns SessionExpired } - other => panic!("expected UnexpectedServerResponse, got: {other:?}"), + other => panic!("expected SessionExpired, got: {other:?}"), } ct.cancel(); @@ -95,3 +91,82 @@ async fn test_stale_session_id_returns_status_aware_error() -> anyhow::Result<() Ok(()) } + +/// Verify that when the server loses a session (returns HTTP 404), the client +/// transparently re-initializes and the original request succeeds. +#[tokio::test] +async fn test_transparent_reinitialization_on_session_expiry() -> anyhow::Result<()> { + let ct = CancellationToken::new(); + let session_manager = Arc::new(LocalSessionManager::default()); + + let service = StreamableHttpService::new( + || Ok(Calculator::new()), + session_manager.clone(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + ..Default::default() + }, + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let server_handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + // Connect a full client transport (this performs initialize + notifications/initialized) + let transport = StreamableHttpClientTransport::from_config( + StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), + ); + let client = ().serve(transport).await?; + + // Verify the session is established: list_all_resources() succeeds + let _resources = client.list_all_resources().await?; + + // Capture the current session ID from the server + let original_session_id = { + let sessions = session_manager.sessions.read().await; + sessions + .keys() + .next() + .cloned() + .expect("session should exist") + }; + + // Force session expiry by removing all sessions from the server-side manager + { + let mut sessions = session_manager.sessions.write().await; + sessions.clear(); + } + + // This call should trigger transparent re-initialization and still succeed + let _resources_after = client.list_all_resources().await?; + + // Verify the server created a new session with a different ID + { + let sessions = session_manager.sessions.read().await; + let new_session_id = sessions + .keys() + .next() + .expect("new session should exist after re-initialization"); + assert_ne!( + new_session_id, &original_session_id, + "new session ID should differ from the original" + ); + } + + let _ = client.cancel().await; + ct.cancel(); + server_handle.await?; + + Ok(()) +}