Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 62 additions & 23 deletions packages/server/src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ interface StreamMapping {
cleanup: () => void;
}

const REPLAY_STREAM_CLOSED_ERROR = 'ERR_MCP_REPLAY_STREAM_CLOSED';

/**
* Configuration options for {@linkcode WebStandardStreamableHTTPServerTransport}
*/
Expand Down Expand Up @@ -306,6 +308,24 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
);
}

private createReplayStreamClosedError(): Error {
const error = new Error('Replay stream closed');
error.name = REPLAY_STREAM_CLOSED_ERROR;
return error;
}

private isReplayStreamClosedError(error: unknown): error is Error {
return error instanceof Error && error.name === REPLAY_STREAM_CLOSED_ERROR;
}

private closeStreamController(controller: ReadableStreamDefaultController<Uint8Array>): void {
try {
controller.close();
} catch {
// Controller might already be closed
}
}

/**
* Validates request headers for DNS rebinding protection.
* @returns Error response if validation fails, `undefined` if validation passes.
Expand Down Expand Up @@ -513,44 +533,63 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
// Create a ReadableStream with controller for SSE
const encoder = new TextEncoder();
let streamController: ReadableStreamDefaultController<Uint8Array>;
let replayedStreamId = '';
let streamClosed = false;

const closeReplayStream = () => {
if (streamClosed) {
return;
}

streamClosed = true;

if (replayedStreamId) {
this._streamMapping.delete(replayedStreamId);
}

this.closeStreamController(streamController!);
};

const readable = new ReadableStream<Uint8Array>({
start: controller => {
streamController = controller;
},
cancel: () => {
// Stream was cancelled by client
// Cleanup will be handled by the mapping
closeReplayStream();
}
});

// Replay events - returns the streamId for backwards compatibility
const replayedStreamId = await this._eventStore.replayEventsAfter(lastEventId, {
send: async (eventId: string, message: JSONRPCMessage) => {
const success = this.writeSSEEvent(streamController!, encoder, message, eventId);
if (!success) {
this.onerror?.(new Error('Failed replay events'));
try {
streamController!.close();
} catch {
// Controller might already be closed
try {
replayedStreamId = await this._eventStore.replayEventsAfter(lastEventId, {
send: async (eventId: string, message: JSONRPCMessage) => {
if (streamClosed) {
throw this.createReplayStreamClosedError();
}

const success = this.writeSSEEvent(streamController!, encoder, message, eventId);
if (!success) {
closeReplayStream();
throw this.createReplayStreamClosedError();
}
}
});
} catch (error) {
if (!this.isReplayStreamClosedError(error)) {
throw error;
}
});
}

this._streamMapping.set(replayedStreamId, {
controller: streamController!,
encoder,
cleanup: () => {
this._streamMapping.delete(replayedStreamId);
try {
streamController!.close();
} catch {
// Controller might already be closed
if (!streamClosed) {
this._streamMapping.set(replayedStreamId, {
controller: streamController!,
encoder,
cleanup: () => {
this._streamMapping.delete(replayedStreamId);
closeReplayStream();
}
}
});
});
}

return new Response(readable, { headers });
} catch (error) {
Expand Down
69 changes: 69 additions & 0 deletions packages/server/test/server/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import { randomUUID } from 'node:crypto';
import type { CallToolResult, JSONRPCErrorResponse, JSONRPCMessage } from '@modelcontextprotocol/core';
import * as z from 'zod/v4';

import { vi } from 'vitest';

import { McpServer } from '../../src/server/mcp.js';
import type { EventId, EventStore, StreamId } from '../../src/server/streamableHttp.js';
import { WebStandardStreamableHTTPServerTransport } from '../../src/server/streamableHttp.js';
Expand Down Expand Up @@ -705,6 +707,73 @@ describe('Zod v4', () => {
// Should have id: field in the SSE event
expect(text).toContain('id:');
});

it('should stop replay when the replay stream closes mid-flight', async () => {
const nativeReadableStream = ReadableStream;
let controller: ReadableStreamDefaultController<Uint8Array> | undefined;

class TrackingReadableStream extends nativeReadableStream<Uint8Array> {
constructor(source: {
cancel?: (reason?: unknown) => PromiseLike<void> | void;
start?: (controller: ReadableStreamDefaultController<Uint8Array>) => PromiseLike<void> | void;
}) {
super({
cancel: source.cancel,
start: activeController => {
controller = activeController;
return source.start?.(activeController);
}
});
}
}

vi.stubGlobal('ReadableStream', TrackingReadableStream);

const replayErrors: Error[] = [];

const replayEventStore: EventStore = {
async storeEvent(): Promise<EventId> {
throw new Error('storeEvent should not be called during replay');
},
async replayEventsAfter(
_lastEventId: EventId,
{ send }: { send: (eventId: EventId, message: JSONRPCMessage) => Promise<void> }
): Promise<StreamId> {
await send('evt-1', TEST_MESSAGES.toolsList);
controller?.close();
await send('evt-2', TEST_MESSAGES.initialize);
return 'stream-1';
}
};

const replayTransport = new WebStandardStreamableHTTPServerTransport({
sessionIdGenerator: undefined,
eventStore: replayEventStore
});
replayTransport.onerror = error => replayErrors.push(error);

try {
const response = await replayTransport.handleRequest(
createRequest('GET', undefined, {
extraHeaders: {
'last-event-id': 'evt-0'
}
})
);

expect(response.status).toBe(200);
expect(response.headers.get('mcp-session-id')).toBeNull();

const text = await response.text();
expect(text).toContain('id: evt-1');
expect(text).not.toContain('id: evt-2');
expect((replayTransport as unknown as { _streamMapping: Map<string, unknown> })._streamMapping.size).toBe(0);
expect(replayErrors).toEqual([]);
} finally {
vi.unstubAllGlobals();
await replayTransport.close();
}
});
});

describe('HTTPServerTransport - Protocol Version Validation', () => {
Expand Down
Loading