Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/ModelContextProtocol.AspNetCore/ISessionMigrationHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Microsoft.AspNetCore.Http;
using ModelContextProtocol.Protocol;

namespace ModelContextProtocol.AspNetCore;

/// <summary>
/// Provides hooks for persisting and restoring MCP session initialization data,
/// enabling session migration across server instances.
/// </summary>
/// <remarks>
/// <para>
/// When an MCP server is horizontally scaled, stateful sessions are bound to a single process.
/// If that process restarts or scales down, the session is lost. By implementing this interface
/// and registering it with DI, you can persist the initialization handshake data and restore it
/// when a client reconnects to a different server instance with its existing <c>Mcp-Session-Id</c>.
/// </para>
/// <para>
/// This does <strong>not</strong> solve the session-affinity problem for in-flight server-to-client
/// requests (such as sampling or elicitation). Responses to those requests must still be routed to
/// the process that created the request. This interface only enables migration of idle sessions
/// by persisting the data established during the initialization handshake.
/// </para>
/// </remarks>
public interface ISessionMigrationHandler
{
/// <summary>
/// Called after a session has been successfully initialized via the MCP initialization handshake.
/// </summary>
/// <remarks>
/// Use this to persist the <paramref name="initializeParams"/> (which includes client capabilities,
/// client info, and protocol version) to an external store so the session can be migrated to
/// another server instance later via <see cref="AllowSessionMigrationAsync"/>.
/// </remarks>
/// <param name="context">The <see cref="HttpContext"/> for the initialization request.</param>
/// <param name="sessionId">The unique identifier for the session.</param>
/// <param name="initializeParams">The initialization parameters sent by the client during the handshake.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>A <see cref="ValueTask"/> representing the asynchronous operation.</returns>
ValueTask OnSessionInitializedAsync(HttpContext context, string sessionId, InitializeRequestParams initializeParams, CancellationToken cancellationToken);

/// <summary>
/// Called when a request arrives with an <c>Mcp-Session-Id</c> that the current server doesn't recognize.
/// </summary>
/// <remarks>
/// <para>
/// Return the original <see cref="InitializeRequestParams"/> to allow the session to be migrated
/// to this server instance, or <see langword="null"/> to reject the request (returning a 404 to the client).
/// </para>
/// <para>
/// Implementations should validate that the request is authorized, for example by checking
/// <see cref="HttpContext.User"/>, to ensure the caller is permitted to migrate the session.
/// </para>
/// </remarks>
/// <param name="context">The <see cref="HttpContext"/> for the request with the unrecognized session ID.</param>
/// <param name="sessionId">The session ID from the request that was not found on this server.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>
/// The original <see cref="InitializeRequestParams"/> if migration is allowed,
/// or <see langword="null"/> to reject the request.
/// </returns>
ValueTask<InitializeRequestParams?> AllowSessionMigrationAsync(HttpContext context, string sessionId, CancellationToken cancellationToken);
}
125 changes: 107 additions & 18 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.Net.Http.Headers;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using System.Collections.Concurrent;
using System.Security.Claims;
using System.Security.Cryptography;
using System.Text.Json.Serialization.Metadata;
Expand All @@ -20,14 +21,17 @@ internal sealed class StreamableHttpHandler(
StatefulSessionManager sessionManager,
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider applicationServices,
ILoggerFactory loggerFactory)
ILoggerFactory loggerFactory,
ISessionMigrationHandler? sessionMigrationHandler = null)
{
private const string McpSessionIdHeaderName = "Mcp-Session-Id";
private const string LastEventIdHeaderName = "Last-Event-ID";

private static readonly JsonTypeInfo<JsonRpcMessage> s_messageTypeInfo = GetRequiredJsonTypeInfo<JsonRpcMessage>();
private static readonly JsonTypeInfo<JsonRpcError> s_errorTypeInfo = GetRequiredJsonTypeInfo<JsonRpcError>();

private readonly ConcurrentDictionary<string, SemaphoreSlim> _migrationLocks = new(StringComparer.Ordinal);

public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value;

public async Task HandlePostRequestAsync(HttpContext context)
Expand All @@ -45,14 +49,6 @@ await WriteJsonRpcErrorAsync(context,
return;
}

var session = await GetOrCreateSessionAsync(context);
if (session is null)
{
return;
}

await using var _ = await session.AcquireReferenceAsync(context.RequestAborted);

var message = await ReadJsonRpcMessageAsync(context);
if (message is null)
{
Expand All @@ -62,6 +58,14 @@ await WriteJsonRpcErrorAsync(context,
return;
}

var session = await GetOrCreateSessionAsync(context, message);
if (session is null)
{
return;
}

await using var _ = await session.AcquireReferenceAsync(context.RequestAborted);

InitializeSseResponse(context);
var wroteResponse = await session.Transport.HandlePostRequestAsync(message, context.Response.Body, context.RequestAborted);
if (!wroteResponse)
Expand Down Expand Up @@ -188,12 +192,18 @@ public async Task HandleDeleteRequestAsync(HttpContext context)

if (!sessionManager.TryGetValue(sessionId, out var session))
{
// -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
// One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
// JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
// https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields
await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, -32001);
return null;
// Session not found locally. Attempt migration if a handler is registered.
session = await TryMigrateSessionAsync(context, sessionId);

if (session is null)
{
// -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
// One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
// JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
// https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields
await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, -32001);
return null;
}
}

if (!session.HasSameUserId(context.User))
Expand All @@ -209,12 +219,60 @@ await WriteJsonRpcErrorAsync(context,
return session;
}

private async ValueTask<StreamableHttpSession?> GetOrCreateSessionAsync(HttpContext context)
private async ValueTask<StreamableHttpSession?> TryMigrateSessionAsync(HttpContext context, string sessionId)
{
if (sessionMigrationHandler is not { } handler)
{
return null;
}

var migrationLock = _migrationLocks.GetOrAdd(sessionId, static _ => new SemaphoreSlim(1, 1));
await migrationLock.WaitAsync(context.RequestAborted);
try
{
// Re-check after acquiring the lock - another thread may have already completed migration.
if (sessionManager.TryGetValue(sessionId, out var session))
{
return session;
}

var initParams = await handler.AllowSessionMigrationAsync(context, sessionId, context.RequestAborted);
if (initParams is null)
{
return null;
}

var migratedSession = await MigrateSessionAsync(context, sessionId, initParams);

// Register the session with the session manager while still holding the lock
// so concurrent requests for the same session ID find it via sessionManager.TryGetValue.
await migratedSession.EnsureStartedAsync(context.RequestAborted);

return migratedSession;
}
finally
{
migrationLock.Release();
_migrationLocks.TryRemove(sessionId, out _);
}
}

private async ValueTask<StreamableHttpSession?> GetOrCreateSessionAsync(HttpContext context, JsonRpcMessage message)
{
var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString();

if (string.IsNullOrEmpty(sessionId))
{
// In stateful mode, only allow creating new sessions for initialize requests.
// In stateless mode, every request is independent, so we always create a new session.
if (!HttpServerTransportOptions.Stateless && message is not JsonRpcRequest { Method: RequestMethods.Initialize })
{
await WriteJsonRpcErrorAsync(context,
"Bad Request: A new session can only be created by an initialize request. Include a valid Mcp-Session-Id header for non-initialize requests.",
StatusCodes.Status400BadRequest);
return null;
}

return await StartNewSessionAsync(context);
}
else if (HttpServerTransportOptions.Stateless)
Expand Down Expand Up @@ -243,7 +301,11 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
SessionId = sessionId,
FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
EventStreamStore = HttpServerTransportOptions.EventStreamStore,
OnSessionInitialized = sessionMigrationHandler is { } handler
? (initParams, ct) => handler.OnSessionInitializedAsync(context, sessionId, initParams, ct)
: null,
};

context.Response.Headers[McpSessionIdHeaderName] = sessionId;
}
else
Expand All @@ -264,11 +326,12 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
private async ValueTask<StreamableHttpSession> CreateSessionAsync(
HttpContext context,
StreamableHttpServerTransport transport,
string sessionId)
string sessionId,
Action<McpServerOptions>? configureOptions = null)
{
var mcpServerServices = applicationServices;
var mcpServerOptions = mcpServerOptionsSnapshot.Value;
if (HttpServerTransportOptions.Stateless || HttpServerTransportOptions.ConfigureSessionOptions is not null)
if (HttpServerTransportOptions.Stateless || HttpServerTransportOptions.ConfigureSessionOptions is not null || configureOptions is not null)
{
mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName);

Expand All @@ -279,6 +342,8 @@ private async ValueTask<StreamableHttpSession> CreateSessionAsync(
mcpServerOptions.ScopeRequests = false;
}

configureOptions?.Invoke(mcpServerOptions);

if (HttpServerTransportOptions.ConfigureSessionOptions is { } configureSessionOptions)
{
await configureSessionOptions(context, mcpServerOptions, context.RequestAborted);
Expand All @@ -297,6 +362,30 @@ private async ValueTask<StreamableHttpSession> CreateSessionAsync(
return session;
}

private async ValueTask<StreamableHttpSession> MigrateSessionAsync(
HttpContext context,
string sessionId,
InitializeRequestParams initializeParams)
{
var transport = new StreamableHttpServerTransport(loggerFactory)
{
SessionId = sessionId,
FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
EventStreamStore = HttpServerTransportOptions.EventStreamStore,
};

// Initialize the transport with the migrated session's init params.
await transport.HandleInitRequestAsync(initializeParams);

context.Response.Headers[McpSessionIdHeaderName] = sessionId;

return await CreateSessionAsync(context, transport, sessionId, options =>
{
options.KnownClientInfo = initializeParams.ClientInfo;
options.KnownClientCapabilities = initializeParams.Capabilities;
});
}

private async ValueTask<ISseEventStreamReader?> GetEventStreamReaderAsync(HttpContext context, string lastEventId)
{
if (HttpServerTransportOptions.EventStreamStore is not { } eventStreamStore)
Expand Down
25 changes: 25 additions & 0 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,31 @@ public async ValueTask<IAsyncDisposable> AcquireReferenceAsync(CancellationToken
return new UnreferenceDisposable(this);
}

/// <summary>
/// Ensures the session is registered with the session manager without acquiring a reference.
/// No-ops if the session is already started.
/// </summary>
public async ValueTask EnsureStartedAsync(CancellationToken cancellationToken)
{
bool needsStart;
lock (_stateLock)
{
needsStart = _state == SessionState.Uninitialized;
if (needsStart)
{
_state = SessionState.Started;
}
}

if (needsStart)
{
await sessionManager.StartNewSessionAsync(this, cancellationToken);

// Session is registered with 0 references (idle), so reflect that in the idle count.
sessionManager.IncrementIdleSessionCount();
}
}

public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0;
public bool HasSameUserId(ClaimsPrincipal user) => userId == StreamableHttpHandler.GetUserIdClaim(user);

Expand Down
1 change: 1 addition & 0 deletions src/ModelContextProtocol.Core/Server/McpServerImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFact
}

_clientInfo = options.KnownClientInfo;
_clientCapabilities = options.KnownClientCapabilities;
UpdateEndpointNameWithClientInfo();

_notificationHandlers = new();
Expand Down
12 changes: 12 additions & 0 deletions src/ModelContextProtocol.Core/Server/McpServerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ public sealed class McpServerOptions
/// </remarks>
public Implementation? KnownClientInfo { get; set; }

/// <summary>
/// Gets or sets preexisting knowledge about the client's capabilities to support session migration
/// scenarios where the client will not re-send the initialize request.
/// </summary>
/// <remarks>
/// <para>
/// When not specified, this information is sourced from the client's initialize request.
/// This is typically set during session migration in conjunction with <see cref="KnownClientInfo"/>.
/// </para>
/// </remarks>
public ClientCapabilities? KnownClientCapabilities { get; set; }

/// <summary>
/// Gets the filter collections for MCP server handlers.
/// </summary>
Expand Down
Loading
Loading