diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index 74e689892..11c011a7a 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -67,6 +67,8 @@ interface StreamMapping { cleanup: () => void; } +const REPLAY_STREAM_CLOSED_ERROR = 'ERR_MCP_REPLAY_STREAM_CLOSED'; + /** * Configuration options for {@linkcode WebStandardStreamableHTTPServerTransport} */ @@ -306,6 +308,24 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { ); } + private createReplayStreamClosedError(): Error { + const error = new Error('Replay stream closed'); + error.name = REPLAY_STREAM_CLOSED_ERROR; + return error; + } + + private isReplayStreamClosedError(error: unknown): error is Error { + return error instanceof Error && error.name === REPLAY_STREAM_CLOSED_ERROR; + } + + private closeStreamController(controller: ReadableStreamDefaultController): void { + try { + controller.close(); + } catch { + // Controller might already be closed + } + } + /** * Validates request headers for DNS rebinding protection. * @returns Error response if validation fails, `undefined` if validation passes. @@ -513,44 +533,63 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { // Create a ReadableStream with controller for SSE const encoder = new TextEncoder(); let streamController: ReadableStreamDefaultController; + let replayedStreamId = ''; + let streamClosed = false; + + const closeReplayStream = () => { + if (streamClosed) { + return; + } + + streamClosed = true; + + if (replayedStreamId) { + this._streamMapping.delete(replayedStreamId); + } + + this.closeStreamController(streamController!); + }; const readable = new ReadableStream({ start: controller => { streamController = controller; }, cancel: () => { - // Stream was cancelled by client - // Cleanup will be handled by the mapping + closeReplayStream(); } }); // Replay events - returns the streamId for backwards compatibility - const replayedStreamId = await this._eventStore.replayEventsAfter(lastEventId, { - send: async (eventId: string, message: JSONRPCMessage) => { - const success = this.writeSSEEvent(streamController!, encoder, message, eventId); - if (!success) { - this.onerror?.(new Error('Failed replay events')); - try { - streamController!.close(); - } catch { - // Controller might already be closed + try { + replayedStreamId = await this._eventStore.replayEventsAfter(lastEventId, { + send: async (eventId: string, message: JSONRPCMessage) => { + if (streamClosed) { + throw this.createReplayStreamClosedError(); + } + + const success = this.writeSSEEvent(streamController!, encoder, message, eventId); + if (!success) { + closeReplayStream(); + throw this.createReplayStreamClosedError(); } } + }); + } catch (error) { + if (!this.isReplayStreamClosedError(error)) { + throw error; } - }); + } - this._streamMapping.set(replayedStreamId, { - controller: streamController!, - encoder, - cleanup: () => { - this._streamMapping.delete(replayedStreamId); - try { - streamController!.close(); - } catch { - // Controller might already be closed + if (!streamClosed) { + this._streamMapping.set(replayedStreamId, { + controller: streamController!, + encoder, + cleanup: () => { + this._streamMapping.delete(replayedStreamId); + closeReplayStream(); } - } - }); + }); + } return new Response(readable, { headers }); } catch (error) { diff --git a/packages/server/test/server/streamableHttp.test.ts b/packages/server/test/server/streamableHttp.test.ts index ab6f22342..a2b6fb3f1 100644 --- a/packages/server/test/server/streamableHttp.test.ts +++ b/packages/server/test/server/streamableHttp.test.ts @@ -3,6 +3,8 @@ import { randomUUID } from 'node:crypto'; import type { CallToolResult, JSONRPCErrorResponse, JSONRPCMessage } from '@modelcontextprotocol/core'; import * as z from 'zod/v4'; +import { vi } from 'vitest'; + import { McpServer } from '../../src/server/mcp.js'; import type { EventId, EventStore, StreamId } from '../../src/server/streamableHttp.js'; import { WebStandardStreamableHTTPServerTransport } from '../../src/server/streamableHttp.js'; @@ -705,6 +707,73 @@ describe('Zod v4', () => { // Should have id: field in the SSE event expect(text).toContain('id:'); }); + + it('should stop replay when the replay stream closes mid-flight', async () => { + const nativeReadableStream = ReadableStream; + let controller: ReadableStreamDefaultController | undefined; + + class TrackingReadableStream extends nativeReadableStream { + constructor(source: { + cancel?: (reason?: unknown) => PromiseLike | void; + start?: (controller: ReadableStreamDefaultController) => PromiseLike | void; + }) { + super({ + cancel: source.cancel, + start: activeController => { + controller = activeController; + return source.start?.(activeController); + } + }); + } + } + + vi.stubGlobal('ReadableStream', TrackingReadableStream); + + const replayErrors: Error[] = []; + + const replayEventStore: EventStore = { + async storeEvent(): Promise { + throw new Error('storeEvent should not be called during replay'); + }, + async replayEventsAfter( + _lastEventId: EventId, + { send }: { send: (eventId: EventId, message: JSONRPCMessage) => Promise } + ): Promise { + await send('evt-1', TEST_MESSAGES.toolsList); + controller?.close(); + await send('evt-2', TEST_MESSAGES.initialize); + return 'stream-1'; + } + }; + + const replayTransport = new WebStandardStreamableHTTPServerTransport({ + sessionIdGenerator: undefined, + eventStore: replayEventStore + }); + replayTransport.onerror = error => replayErrors.push(error); + + try { + const response = await replayTransport.handleRequest( + createRequest('GET', undefined, { + extraHeaders: { + 'last-event-id': 'evt-0' + } + }) + ); + + expect(response.status).toBe(200); + expect(response.headers.get('mcp-session-id')).toBeNull(); + + const text = await response.text(); + expect(text).toContain('id: evt-1'); + expect(text).not.toContain('id: evt-2'); + expect((replayTransport as unknown as { _streamMapping: Map })._streamMapping.size).toBe(0); + expect(replayErrors).toEqual([]); + } finally { + vi.unstubAllGlobals(); + await replayTransport.close(); + } + }); }); describe('HTTPServerTransport - Protocol Version Validation', () => {