diff --git a/src/ModelContextProtocol.AspNetCore/ISessionMigrationHandler.cs b/src/ModelContextProtocol.AspNetCore/ISessionMigrationHandler.cs new file mode 100644 index 000000000..9eaf0902d --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/ISessionMigrationHandler.cs @@ -0,0 +1,62 @@ +using Microsoft.AspNetCore.Http; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Provides hooks for persisting and restoring MCP session initialization data, +/// enabling session migration across server instances. +/// +/// +/// +/// When an MCP server is horizontally scaled, stateful sessions are bound to a single process. +/// If that process restarts or scales down, the session is lost. By implementing this interface +/// and registering it with DI, you can persist the initialization handshake data and restore it +/// when a client reconnects to a different server instance with its existing Mcp-Session-Id. +/// +/// +/// This does not solve the session-affinity problem for in-flight server-to-client +/// requests (such as sampling or elicitation). Responses to those requests must still be routed to +/// the process that created the request. This interface only enables migration of idle sessions +/// by persisting the data established during the initialization handshake. +/// +/// +public interface ISessionMigrationHandler +{ + /// + /// Called after a session has been successfully initialized via the MCP initialization handshake. + /// + /// + /// Use this to persist the (which includes client capabilities, + /// client info, and protocol version) to an external store so the session can be migrated to + /// another server instance later via . + /// + /// The for the initialization request. + /// The unique identifier for the session. + /// The initialization parameters sent by the client during the handshake. + /// A cancellation token. + /// A representing the asynchronous operation. + ValueTask OnSessionInitializedAsync(HttpContext context, string sessionId, InitializeRequestParams initializeParams, CancellationToken cancellationToken); + + /// + /// Called when a request arrives with an Mcp-Session-Id that the current server doesn't recognize. + /// + /// + /// + /// Return the original to allow the session to be migrated + /// to this server instance, or to reject the request (returning a 404 to the client). + /// + /// + /// Implementations should validate that the request is authorized, for example by checking + /// , to ensure the caller is permitted to migrate the session. + /// + /// + /// The for the request with the unrecognized session ID. + /// The session ID from the request that was not found on this server. + /// A cancellation token. + /// + /// The original if migration is allowed, + /// or to reject the request. + /// + ValueTask AllowSessionMigrationAsync(HttpContext context, string sessionId, CancellationToken cancellationToken); +} diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 22c861326..a7983185e 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -7,6 +7,7 @@ using Microsoft.Net.Http.Headers; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Collections.Concurrent; using System.Security.Claims; using System.Security.Cryptography; using System.Text.Json.Serialization.Metadata; @@ -20,7 +21,8 @@ internal sealed class StreamableHttpHandler( StatefulSessionManager sessionManager, IHostApplicationLifetime hostApplicationLifetime, IServiceProvider applicationServices, - ILoggerFactory loggerFactory) + ILoggerFactory loggerFactory, + ISessionMigrationHandler? sessionMigrationHandler = null) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; private const string LastEventIdHeaderName = "Last-Event-ID"; @@ -28,6 +30,8 @@ internal sealed class StreamableHttpHandler( private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); + private readonly ConcurrentDictionary _migrationLocks = new(StringComparer.Ordinal); + public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; public async Task HandlePostRequestAsync(HttpContext context) @@ -45,14 +49,6 @@ await WriteJsonRpcErrorAsync(context, return; } - var session = await GetOrCreateSessionAsync(context); - if (session is null) - { - return; - } - - await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); - var message = await ReadJsonRpcMessageAsync(context); if (message is null) { @@ -62,6 +58,14 @@ await WriteJsonRpcErrorAsync(context, return; } + var session = await GetOrCreateSessionAsync(context, message); + if (session is null) + { + return; + } + + await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); + InitializeSseResponse(context); var wroteResponse = await session.Transport.HandlePostRequestAsync(message, context.Response.Body, context.RequestAborted); if (!wroteResponse) @@ -188,12 +192,18 @@ public async Task HandleDeleteRequestAsync(HttpContext context) if (!sessionManager.TryGetValue(sessionId, out var session)) { - // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. - // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this - // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound - // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields - await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, -32001); - return null; + // Session not found locally. Attempt migration if a handler is registered. + session = await TryMigrateSessionAsync(context, sessionId); + + if (session is null) + { + // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. + // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this + // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound + // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields + await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, -32001); + return null; + } } if (!session.HasSameUserId(context.User)) @@ -209,12 +219,60 @@ await WriteJsonRpcErrorAsync(context, return session; } - private async ValueTask GetOrCreateSessionAsync(HttpContext context) + private async ValueTask TryMigrateSessionAsync(HttpContext context, string sessionId) + { + if (sessionMigrationHandler is not { } handler) + { + return null; + } + + var migrationLock = _migrationLocks.GetOrAdd(sessionId, static _ => new SemaphoreSlim(1, 1)); + await migrationLock.WaitAsync(context.RequestAborted); + try + { + // Re-check after acquiring the lock - another thread may have already completed migration. + if (sessionManager.TryGetValue(sessionId, out var session)) + { + return session; + } + + var initParams = await handler.AllowSessionMigrationAsync(context, sessionId, context.RequestAborted); + if (initParams is null) + { + return null; + } + + var migratedSession = await MigrateSessionAsync(context, sessionId, initParams); + + // Register the session with the session manager while still holding the lock + // so concurrent requests for the same session ID find it via sessionManager.TryGetValue. + await migratedSession.EnsureStartedAsync(context.RequestAborted); + + return migratedSession; + } + finally + { + migrationLock.Release(); + _migrationLocks.TryRemove(sessionId, out _); + } + } + + private async ValueTask GetOrCreateSessionAsync(HttpContext context, JsonRpcMessage message) { var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); if (string.IsNullOrEmpty(sessionId)) { + // In stateful mode, only allow creating new sessions for initialize requests. + // In stateless mode, every request is independent, so we always create a new session. + if (!HttpServerTransportOptions.Stateless && message is not JsonRpcRequest { Method: RequestMethods.Initialize }) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: A new session can only be created by an initialize request. Include a valid Mcp-Session-Id header for non-initialize requests.", + StatusCodes.Status400BadRequest); + return null; + } + return await StartNewSessionAsync(context); } else if (HttpServerTransportOptions.Stateless) @@ -243,7 +301,11 @@ private async ValueTask StartNewSessionAsync(HttpContext SessionId = sessionId, FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, EventStreamStore = HttpServerTransportOptions.EventStreamStore, + OnSessionInitialized = sessionMigrationHandler is { } handler + ? (initParams, ct) => handler.OnSessionInitializedAsync(context, sessionId, initParams, ct) + : null, }; + context.Response.Headers[McpSessionIdHeaderName] = sessionId; } else @@ -264,11 +326,12 @@ private async ValueTask StartNewSessionAsync(HttpContext private async ValueTask CreateSessionAsync( HttpContext context, StreamableHttpServerTransport transport, - string sessionId) + string sessionId, + Action? configureOptions = null) { var mcpServerServices = applicationServices; var mcpServerOptions = mcpServerOptionsSnapshot.Value; - if (HttpServerTransportOptions.Stateless || HttpServerTransportOptions.ConfigureSessionOptions is not null) + if (HttpServerTransportOptions.Stateless || HttpServerTransportOptions.ConfigureSessionOptions is not null || configureOptions is not null) { mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName); @@ -279,6 +342,8 @@ private async ValueTask CreateSessionAsync( mcpServerOptions.ScopeRequests = false; } + configureOptions?.Invoke(mcpServerOptions); + if (HttpServerTransportOptions.ConfigureSessionOptions is { } configureSessionOptions) { await configureSessionOptions(context, mcpServerOptions, context.RequestAborted); @@ -297,6 +362,30 @@ private async ValueTask CreateSessionAsync( return session; } + private async ValueTask MigrateSessionAsync( + HttpContext context, + string sessionId, + InitializeRequestParams initializeParams) + { + var transport = new StreamableHttpServerTransport(loggerFactory) + { + SessionId = sessionId, + FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, + EventStreamStore = HttpServerTransportOptions.EventStreamStore, + }; + + // Initialize the transport with the migrated session's init params. + await transport.HandleInitRequestAsync(initializeParams); + + context.Response.Headers[McpSessionIdHeaderName] = sessionId; + + return await CreateSessionAsync(context, transport, sessionId, options => + { + options.KnownClientInfo = initializeParams.ClientInfo; + options.KnownClientCapabilities = initializeParams.Capabilities; + }); + } + private async ValueTask GetEventStreamReaderAsync(HttpContext context, string lastEventId) { if (HttpServerTransportOptions.EventStreamStore is not { } eventStreamStore) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs index e3226b57d..5065ddcfb 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -74,6 +74,31 @@ public async ValueTask AcquireReferenceAsync(CancellationToken return new UnreferenceDisposable(this); } + /// + /// Ensures the session is registered with the session manager without acquiring a reference. + /// No-ops if the session is already started. + /// + public async ValueTask EnsureStartedAsync(CancellationToken cancellationToken) + { + bool needsStart; + lock (_stateLock) + { + needsStart = _state == SessionState.Uninitialized; + if (needsStart) + { + _state = SessionState.Started; + } + } + + if (needsStart) + { + await sessionManager.StartNewSessionAsync(this, cancellationToken); + + // Session is registered with 0 references (idle), so reflect that in the idle count. + sessionManager.IncrementIdleSessionCount(); + } + } + public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0; public bool HasSameUserId(ClaimsPrincipal user) => userId == StreamableHttpHandler.GetUserIdClaim(user); diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index 04f329437..30b3dff6b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -75,6 +75,7 @@ public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFact } _clientInfo = options.KnownClientInfo; + _clientCapabilities = options.KnownClientCapabilities; UpdateEndpointNameWithClientInfo(); _notificationHandlers = new(); diff --git a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs index e13b4d3b5..1af473da9 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs @@ -81,6 +81,18 @@ public sealed class McpServerOptions /// public Implementation? KnownClientInfo { get; set; } + /// + /// Gets or sets preexisting knowledge about the client's capabilities to support session migration + /// scenarios where the client will not re-send the initialize request. + /// + /// + /// + /// When not specified, this information is sourced from the client's initialize request. + /// This is typically set during session migration in conjunction with . + /// + /// + public ClientCapabilities? KnownClientCapabilities { get; set; } + /// /// Gets the filter collections for MCP server handlers. /// diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 58227757b..be639b15d 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Net.ServerSentEvents; -using System.Security.Claims; using System.Threading.Channels; namespace ModelContextProtocol.Server; @@ -42,6 +41,7 @@ public sealed partial class StreamableHttpServerTransport : ITransport private SseEventWriter? _httpSseWriter; private ISseEventStreamWriter? _storeSseWriter; private TaskCompletionSource? _httpResponseTcs; + private string? _negotiatedProtocolVersion; private bool _getHttpRequestStarted; private bool _getHttpResponseCompleted; @@ -82,9 +82,13 @@ public StreamableHttpServerTransport(ILoggerFactory? loggerFactory = null) public ISseEventStreamStore? EventStreamStore { get; init; } /// - /// Gets or sets the negotiated protocol version for this session. + /// Gets or sets an optional callback invoked after the initialization handshake completes. /// - internal string? NegotiatedProtocolVersion { get; private set; } + /// + /// When set, this callback is invoked with the after a successful + /// initialization handshake. This can be used to persist session data for cross-instance migration. + /// + public Func? OnSessionInitialized { get; init; } /// public ChannelReader MessageReader => _incomingChannel.Reader; @@ -92,12 +96,23 @@ public StreamableHttpServerTransport(ILoggerFactory? loggerFactory = null) internal ChannelWriter MessageWriter => _incomingChannel.Writer; /// - /// Handles the initialize request by capturing the protocol version and invoking the user callback. + /// Handles initialization by capturing the negotiated protocol version and optionally invoking + /// so session data can be persisted. /// - internal async ValueTask HandleInitRequestAsync(InitializeRequestParams? initParams) + /// + /// This is called automatically when an initialize request is processed via + /// . It can also be called + /// directly when restoring a migrated session with known . + /// + /// The initialization parameters from the client, or if unavailable. + public async ValueTask HandleInitRequestAsync(InitializeRequestParams? initParams) { - // Capture the negotiated protocol version for resumability checks - NegotiatedProtocolVersion = initParams?.ProtocolVersion; + _negotiatedProtocolVersion = initParams?.ProtocolVersion; + + if (initParams is not null && OnSessionInitialized is { } callback) + { + await callback(initParams, _transportDisposedCts.Token).ConfigureAwait(false); + } } /// @@ -165,7 +180,7 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo /// /// or is . /// - /// If an authenticated sent the message, that can be included in the . + /// If an authenticated user sent the message, that can be included in the . /// No other part of the context should be set. /// public async Task HandlePostRequestAsync(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default) @@ -266,7 +281,7 @@ public async ValueTask DisposeAsync() internal async ValueTask TryCreateEventStreamAsync(string streamId, CancellationToken cancellationToken) { - if (EventStreamStore is null || !McpSessionHandler.SupportsPrimingEvent(NegotiatedProtocolVersion)) + if (EventStreamStore is null || !McpSessionHandler.SupportsPrimingEvent(_negotiatedProtocolVersion)) { return null; } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SessionMigrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SessionMigrationTests.cs new file mode 100644 index 000000000..a06a5d129 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SessionMigrationTests.cs @@ -0,0 +1,341 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Net; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class SessionMigrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable +{ + private static McpServerTool[] Tools { get; } = [McpServerTool.Create(EchoAsync), McpServerTool.Create(GetClientInfoAsync)]; + + private WebApplication? _app; + + private static string InitializeRequest => """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} + """; + + private long _lastRequestId = 1; + private string MakeEchoRequest() + { + var id = Interlocked.Increment(ref _lastRequestId); + return $$$$""" + {"jsonrpc":"2.0","id":{{{{id}}}},"method":"tools/call","params":{"name":"echo","arguments":{"message":"Hello world! ({{{{id}}}})"}}} + """; + } + + [Fact] + public async Task OnSessionInitializedAsync_IsCalled_AfterInitializeHandshake() + { + InitializeRequestParams? capturedParams = null; + string? capturedSessionId = null; + + var handler = new TestMigrationHandler + { + OnInitialized = (context, sessionId, initParams, ct) => + { + capturedSessionId = sessionId; + capturedParams = initParams; + return default; + }, + }; + + await StartAsync(handler); + + var sessionId = await CallInitializeAndValidateAsync(); + + Assert.NotNull(capturedParams); + Assert.Equal(sessionId, capturedSessionId); + Assert.Equal("IntegrationTestClient", capturedParams.ClientInfo.Name); + Assert.Equal("1.0.0", capturedParams.ClientInfo.Version); + Assert.NotNull(capturedParams.Capabilities); + } + + [Fact] + public async Task AllowSessionMigrationAsync_IsCalled_WhenSessionNotFound() + { + string? requestedSessionId = null; + + var handler = new TestMigrationHandler + { + OnInitialized = (_, _, _, _) => default, + OnMigration = (context, sessionId, ct) => + { + requestedSessionId = sessionId; + return new ValueTask(new InitializeRequestParams + { + ProtocolVersion = "2025-03-26", + Capabilities = new ClientCapabilities(), + ClientInfo = new Implementation { Name = "MigratedClient", Version = "2.0.0" }, + }); + }, + }; + + await StartAsync(handler); + + // Send a request with a fake session ID that the server doesn't know about. + SetSessionId("migratable-session-id"); + await CallEchoAndValidateAsync(); + + Assert.Equal("migratable-session-id", requestedSessionId); + + // Verify the migrated client info was applied to the session. + var clientInfo = await CallGetClientInfoAsync(); + Assert.NotNull(clientInfo); + Assert.Equal("MigratedClient", clientInfo.Name); + Assert.Equal("2.0.0", clientInfo.Version); + } + + [Fact] + public async Task MigratedSession_PreservesSessionId() + { + var handler = new TestMigrationHandler + { + OnInitialized = (_, _, _, _) => default, + OnMigration = (context, sessionId, ct) => + { + return new ValueTask(new InitializeRequestParams + { + ProtocolVersion = "2025-03-26", + Capabilities = new ClientCapabilities(), + ClientInfo = new Implementation { Name = "MigratedClient", Version = "2.0.0" }, + }); + }, + }; + + await StartAsync(handler); + + const string OriginalSessionId = "preserved-session-id"; + SetSessionId(OriginalSessionId); + + using var response = await HttpClient.PostAsync("", JsonContent(MakeEchoRequest()), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // The response should echo back the same session ID. + var returnedSessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + Assert.Equal(OriginalSessionId, returnedSessionId); + } + + [Fact] + public async Task MigratedSession_CanHandleSubsequentRequests() + { + var migrationCount = 0; + var handler = new TestMigrationHandler + { + OnInitialized = (_, _, _, _) => default, + OnMigration = (context, sessionId, ct) => + { + Interlocked.Increment(ref migrationCount); + return new ValueTask(new InitializeRequestParams + { + ProtocolVersion = "2025-03-26", + Capabilities = new ClientCapabilities(), + ClientInfo = new Implementation { Name = "MigratedClient", Version = "2.0.0" }, + }); + }, + }; + + await StartAsync(handler); + + SetSessionId("multi-request-session"); + + // First request triggers migration + await CallEchoAndValidateAsync(); + + // Second request should use the now-local session without triggering another migration. + await CallEchoAndValidateAsync(); + + Assert.Equal(1, migrationCount); + } + + [Fact] + public async Task AllowSessionMigrationAsync_ReturnsNull_ResultsIn404() + { + var handler = new TestMigrationHandler + { + OnInitialized = (_, _, _, _) => default, + OnMigration = (context, sessionId, ct) => + new ValueTask((InitializeRequestParams?)null), + }; + + await StartAsync(handler); + + SetSessionId("non-migratable-session"); + + using var response = await HttpClient.PostAsync("", JsonContent(MakeEchoRequest()), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task NoMigrationHandler_UnknownSession_Returns404() + { + // Start without any migration handler — backward compatibility. + await StartAsync(migrationHandler: null); + + SetSessionId("unknown-session"); + + using var response = await HttpClient.PostAsync("", JsonContent(MakeEchoRequest()), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task GetRequest_WithMigratedSession_Works() + { + var handler = new TestMigrationHandler + { + OnInitialized = (_, _, _, _) => default, + OnMigration = (context, sessionId, ct) => + { + return new ValueTask(new InitializeRequestParams + { + ProtocolVersion = "2025-03-26", + Capabilities = new ClientCapabilities(), + ClientInfo = new Implementation { Name = "MigratedClient", Version = "2.0.0" }, + }); + }, + }; + + await StartAsync(handler); + + // Migrate session via POST first + SetSessionId("get-test-session"); + await CallEchoAndValidateAsync(); + + // Now the GET request should work with the migrated session + using var getResponse = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, getResponse.StatusCode); + } + + private async Task StartAsync(ISessionMigrationHandler? migrationHandler = null) + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new Implementation + { + Name = "SessionMigrationTestServer", + Version = "1.0.0", + }; + }).WithTools(Tools).WithHttpTransport(); + + if (migrationHandler is not null) + { + Builder.Services.AddSingleton(migrationHandler); + } + + _app = Builder.Build(); + _app.MapMcp(); + await _app.StartAsync(TestContext.Current.CancellationToken); + + HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json")); + HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); + } + + public async ValueTask DisposeAsync() + { + if (_app is not null) + { + await _app.DisposeAsync(); + } + base.Dispose(); + } + + private static StringContent JsonContent(string json) => new(json, Encoding.UTF8, "application/json"); + private static JsonTypeInfo GetJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + + private async Task CallInitializeAndValidateAsync() + { + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + + var sessionId = Assert.Single(response.Headers.GetValues("mcp-session-id")); + SetSessionId(sessionId); + return sessionId; + } + + private void SetSessionId(string sessionId) + { + HttpClient.DefaultRequestHeaders.Remove("mcp-session-id"); + HttpClient.DefaultRequestHeaders.Add("mcp-session-id", sessionId); + } + + private async Task CallEchoAndValidateAsync() + { + using var response = await HttpClient.PostAsync("", JsonContent(MakeEchoRequest()), TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + + var callToolResult = JsonSerializer.Deserialize(rpcResponse.Result, GetJsonTypeInfo()); + Assert.NotNull(callToolResult); + var content = Assert.Single(callToolResult.Content); + Assert.IsType(content); + } + + private async Task CallGetClientInfoAsync() + { + var id = Interlocked.Increment(ref _lastRequestId); + var request = $$$$""" + {"jsonrpc":"2.0","id":{{{{id}}}},"method":"tools/call","params":{"name":"getClientInfo","arguments":{}}} + """; + + using var response = await HttpClient.PostAsync("", JsonContent(request), TestContext.Current.CancellationToken); + var rpcResponse = await AssertSingleSseResponseAsync(response); + + var callToolResult = JsonSerializer.Deserialize(rpcResponse.Result, GetJsonTypeInfo()); + Assert.NotNull(callToolResult); + var textContent = Assert.IsType(Assert.Single(callToolResult.Content)); + return JsonSerializer.Deserialize(textContent.Text, GetJsonTypeInfo()); + } + + private static async Task AssertSingleSseResponseAsync(HttpResponseMessage response) + { + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("text/event-stream", response.Content.Headers.ContentType?.MediaType); + + var sseItems = new List(); + var responseStream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var sseItem in SseParser.Create(responseStream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + if (sseItem.EventType == "message") + { + sseItems.Add(sseItem.Data); + } + } + + var data = Assert.Single(sseItems); + var jsonRpcResponse = JsonSerializer.Deserialize(data, GetJsonTypeInfo()); + Assert.NotNull(jsonRpcResponse); + return jsonRpcResponse; + } + + [McpServerTool(Name = "echo")] + private static async Task EchoAsync(string message) + { + await Task.Yield(); + return message; + } + + [McpServerTool(Name = "getClientInfo")] + private static string GetClientInfoAsync(McpServer server) + { + return JsonSerializer.Serialize(server.ClientInfo!, GetJsonTypeInfo()); + } + + private sealed class TestMigrationHandler : ISessionMigrationHandler + { + public Func? OnInitialized { get; set; } + public Func>? OnMigration { get; set; } + + public ValueTask OnSessionInitializedAsync(HttpContext context, string sessionId, InitializeRequestParams initializeParams, CancellationToken cancellationToken) + => OnInitialized?.Invoke(context, sessionId, initializeParams, cancellationToken) ?? default; + + public ValueTask AllowSessionMigrationAsync(HttpContext context, string sessionId, CancellationToken cancellationToken) + => OnMigration?.Invoke(context, sessionId, cancellationToken) ?? new ValueTask((InitializeRequestParams?)null); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 35e17be84..b5deb264e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -175,6 +175,15 @@ public async Task PostRequest_IsNotFound_WithUnrecognizedSessionId() Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); } + [Fact] + public async Task PostWithoutSessionId_NonInitializeRequest_Returns400() + { + await StartAsync(); + + using var response = await HttpClient.PostAsync("", JsonContent(ListToolsRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + [Fact] public async Task InitializeRequest_Matches_CustomRoute() { @@ -660,6 +669,10 @@ private static async Task AssertSingleSseResponseAsync(HttpResp {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} """; + private static string ListToolsRequest => """ + {"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}} + """; + private long _lastRequestId = 1; private string EchoRequest {