Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ impl StreamableHttpClient for reqwest::Client {
}

request = apply_custom_headers(request, custom_headers)?;
let session_was_attached = session_id.is_some();
if let Some(session_id) = session_id {
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
}
Expand Down Expand Up @@ -186,6 +187,9 @@ impl StreamableHttpClient for reqwest::Client {
) {
return Ok(StreamableHttpPostResponse::Accepted);
}
if status == reqwest::StatusCode::NOT_FOUND && session_was_attached {
return Err(StreamableHttpError::SessionExpired);
}
if !status.is_success() {
let body = response
.text()
Expand Down
237 changes: 230 additions & 7 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
use crate::{
RoleClient,
model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
model::{
ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage,
ServerResult,
},
transport::{
common::client_side_sse::SseAutoReconnectStream,
worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
Expand Down Expand Up @@ -79,6 +82,8 @@
InsufficientScope(InsufficientScopeError),
#[error("Header name '{0}' is reserved and conflicts with default headers")]
ReservedHeaderConflict(String),
#[error("Session expired (HTTP 404)")]
SessionExpired,
}

#[derive(Debug, Clone, Error)]
Expand Down Expand Up @@ -307,6 +312,69 @@
}
Ok(())
}

/// Performs a transparent re-initialization handshake after a session-expired 404.
///
/// Takes an owned clone of the client (avoiding `&self` across `.await` so the
/// future remains `Send` without requiring `C: Sync`). POSTs the saved
/// initialize request without a session ID, extracts the new session ID and
/// protocol version, sends `notifications/initialized`, and returns the new
/// `(session_id, protocol_headers)` pair. The init result message is **not**
/// forwarded to the handler because the handler already processed the original
/// initialization.
async fn perform_reinitialization(
client: C,
saved_init_request: ClientJsonRpcMessage,
uri: Arc<str>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<(Option<Arc<str>>, HashMap<HeaderName, HeaderValue>), StreamableHttpError<C::Error>>
{
let (init_msg, new_session_id_str) = client
.post_message(
uri.clone(),
saved_init_request,
None,
auth_header.clone(),
custom_headers.clone(),
)
.await?
.expect_initialized::<C::Error>()
.await?;

let new_session_id: Option<Arc<str>> = new_session_id_str.map(|s| Arc::from(s.as_str()));

Check failure

Code scanning / CodeQL

Cleartext logging of sensitive information High

This operation writes
new_session_id_str
to a log file.

Copilot Autofix

AI about 4 hours ago

In general, to fix cleartext logging of sensitive information, you either (a) remove the sensitive data from the log output entirely, or (b) mask/replace it with a redacted or hashed version that cannot be used to compromise security. You only log high‑level status (e.g., “session reinitialized”) rather than the exact session identifier or token.

For this specific code, the data in question is new_session_id_str / new_session_id, which likely contains a session identifier. The safe approach is to ensure that, when we log reinitialization, we do not ever include the actual session id. Since the only logging machinery visible in this file is tracing::debug, and CodeQL says that the mapping operation leads to a logging sink, the least‑intrusive fix is to add a dedicated debug message that explicitly avoids printing the session id, and not log new_session_id_str directly at all. To make the intent clear and to satisfy the analyzer, we can log only whether a session id was obtained (e.g., Some vs None), without including its content. This preserves existing functionality (session handling logic is unchanged) while removing any potential for accidentally logging the raw session identifier.

Concretely, within perform_reinitialization in crates/rmcp/src/transport/streamable_http_client.rs, right after we convert new_session_id_str into new_session_id, we’ll add a debug! call that reports only the presence/absence of the ID, not the ID value. We do not need any new imports (the file already imports tracing::debug). We avoid any change to how new_session_id is used elsewhere.

Suggested changeset 1
crates/rmcp/src/transport/streamable_http_client.rs

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs
--- a/crates/rmcp/src/transport/streamable_http_client.rs
+++ b/crates/rmcp/src/transport/streamable_http_client.rs
@@ -343,6 +343,10 @@
             .await?;
 
         let new_session_id: Option<Arc<str>> = new_session_id_str.map(|s| Arc::from(s.as_str()));
+        debug!(
+            "Reinitialization completed; new session id obtained: {}",
+            if new_session_id.is_some() { "yes" } else { "no" }
+        );
 
         // Start from custom_headers, then inject the negotiated MCP-Protocol-Version
         // so all subsequent requests carry the right version (MCP 2025-06-18 spec).
EOF
@@ -343,6 +343,10 @@
.await?;

let new_session_id: Option<Arc<str>> = new_session_id_str.map(|s| Arc::from(s.as_str()));
debug!(
"Reinitialization completed; new session id obtained: {}",
if new_session_id.is_some() { "yes" } else { "no" }
);

// Start from custom_headers, then inject the negotiated MCP-Protocol-Version
// so all subsequent requests carry the right version (MCP 2025-06-18 spec).
Copilot is powered by AI and may make mistakes. Always verify output.

// Start from custom_headers, then inject the negotiated MCP-Protocol-Version
// so all subsequent requests carry the right version (MCP 2025-06-18 spec).
let mut new_protocol_headers = custom_headers;
if let ServerJsonRpcMessage::Response(response) = &init_msg {
if let ServerResult::InitializeResult(init_result) = &response.result {
if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
new_protocol_headers
.insert(HeaderName::from_static("mcp-protocol-version"), hv);
}
}
}

let initialized_notification = ClientJsonRpcMessage::notification(
ClientNotification::InitializedNotification(InitializedNotification {
method: Default::default(),
extensions: Default::default(),
}),
);
client
.post_message(
uri,
initialized_notification,
new_session_id.clone(),
auth_header,
new_protocol_headers.clone(),
)
.await?
.expect_accepted_or_json::<C::Error>()?;

Ok((new_session_id, new_protocol_headers))
}
}

impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
Expand Down Expand Up @@ -338,14 +406,15 @@
responder,
message: initialize_request,
} = context.recv_from_handler().await?;
let saved_init_request = initialize_request.clone();
let (message, session_id) = match self
.client
.post_message(
config.uri.clone(),
initialize_request,
None,
self.config.auth_header,
self.config.custom_headers,
config.auth_header.clone(),
config.custom_headers.clone(),
)
.await
{
Expand All @@ -364,7 +433,7 @@
));
}
};
let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
let mut session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
Some(session_id.into())
} else {
if !self.config.allow_stateless {
Expand All @@ -378,7 +447,7 @@
// Extract the negotiated protocol version from the init response
// and build a custom headers map that includes MCP-Protocol-Version
// for all subsequent HTTP requests (per MCP 2025-06-18 spec).
let protocol_headers = {
let mut protocol_headers = {
let mut headers = config.custom_headers.clone();
if let ServerJsonRpcMessage::Response(response) = &message {
if let ServerResult::InitializeResult(init_result) = &response.result {
Expand All @@ -392,7 +461,7 @@
};

// Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
let mut session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
client: self.client.clone(),
uri: config.uri.clone(),
session_id: sid.clone(),
Expand Down Expand Up @@ -516,17 +585,171 @@
match event {
Event::ClientMessage(send_request) => {
let WorkerSendRequest { message, responder } = send_request;
// Pass a clone to the first attempt so `message` is retained for a
// potential re-init retry. `post_message` takes ownership and the
// trait cannot be changed, so the clone is unavoidable.
let response = self
.client
.post_message(
config.uri.clone(),
message,
message.clone(),
session_id.clone(),
config.auth_header.clone(),
protocol_headers.clone(),
)
.await;
let send_result = match response {
Err(StreamableHttpError::SessionExpired) => {
// The server discarded the session (HTTP 404). Perform a
// fresh handshake once and replay the original message.
tracing::info!(
"session expired (HTTP 404), attempting transparent re-initialization"
);
match Self::perform_reinitialization(
self.client.clone(),
saved_init_request.clone(),
config.uri.clone(),
config.auth_header.clone(),
config.custom_headers.clone(),
)
.await
{
Ok((new_session_id, new_protocol_headers)) => {
// Old streams hold the stale session ID; abort them
// so the new standalone SSE stream takes over.
streams.abort_all();

session_id = new_session_id;
protocol_headers = new_protocol_headers;
session_cleanup_info =
session_id.as_ref().map(|sid| SessionCleanupInfo {
client: self.client.clone(),
uri: config.uri.clone(),
session_id: sid.clone(),
auth_header: config.auth_header.clone(),
protocol_headers: protocol_headers.clone(),
});

if let Some(new_sid) = &session_id {
let client = self.client.clone();
let uri = config.uri.clone();
let new_sid = new_sid.clone();
let auth_header = config.auth_header.clone();
let retry_config = self.config.retry_config.clone();
let sse_tx = sse_worker_tx.clone();
let task_ct = transport_task_ct.clone();
let config_uri = config.uri.clone();
let config_auth = config.auth_header.clone();
let spawn_headers = protocol_headers.clone();
streams.spawn(async move {
match client
.get_stream(
uri,
new_sid.clone(),
None,
auth_header.clone(),
spawn_headers.clone(),
)
.await
{
Ok(stream) => {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: client.clone(),
session_id: new_sid,
uri: config_uri,
auth_header: config_auth,
custom_headers: spawn_headers,
},
retry_config,
);
Self::execute_sse_stream(
sse_stream,
sse_tx,
false,
task_ct.child_token(),
)
.await
}
Err(StreamableHttpError::ServerDoesNotSupportSse) => {
tracing::debug!(
"server doesn't support sse after re-init"
);
Ok(())
}
Err(e) => {
tracing::error!(
"fail to get common stream after re-init: {e}"
);
Err(e)
}
}
});
}

let retry_response = self
.client
.post_message(
config.uri.clone(),
message,
session_id.clone(),
config.auth_header.clone(),
protocol_headers.clone(),
)
.await;
match retry_response {
Err(e) => Err(e),
Ok(StreamableHttpPostResponse::Accepted) => {
tracing::trace!(
"client message accepted after re-init"
);
Ok(())
}
Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
context.send_to_handler(msg).await?;
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
if let Some(sid) = &session_id {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: sid.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
custom_headers: protocol_headers.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
} else {
let sse_stream =
SseAutoReconnectStream::never_reconnect(
stream,
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
}
tracing::trace!("got new sse stream after re-init");
Ok(())
}
}
}
Err(reinit_err) => Err(reinit_err),
}
}
Err(e) => Err(e),
Ok(StreamableHttpPostResponse::Accepted) => {
tracing::trace!("client message accepted");
Expand Down
Loading