diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index b82731582..fb67d4365 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1248,6 +1248,7 @@ export abstract class Protocol { } const cancel = (reason: unknown) => { + options?.signal?.removeEventListener('abort', onAbort); this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); this._cleanupTimeout(messageId); @@ -1272,6 +1273,8 @@ export abstract class Protocol { }; this._responseHandlers.set(messageId, response => { + options?.signal?.removeEventListener('abort', onAbort); + if (options?.signal?.aborted) { return; } @@ -1292,9 +1295,10 @@ export abstract class Protocol { } }); - options?.signal?.addEventListener('abort', () => { + const onAbort = () => { cancel(options?.signal?.reason); - }); + }; + options?.signal?.addEventListener('abort', onAbort, { once: true }); const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; const timeoutHandler = () => cancel(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout })); @@ -1321,6 +1325,7 @@ export abstract class Protocol { message: jsonrpcRequest, timestamp: Date.now() }).catch(error => { + options?.signal?.removeEventListener('abort', onAbort); this._cleanupTimeout(messageId); reject(error); }); @@ -1330,6 +1335,7 @@ export abstract class Protocol { } else { // No related task - send through transport normally this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + options?.signal?.removeEventListener('abort', onAbort); this._cleanupTimeout(messageId); reject(error); }); diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 8675c1e03..fcc8bdcfb 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -5723,3 +5723,155 @@ describe('Error handling for missing resolvers', () => { }); }); }); + +describe('Abort signal listener cleanup', () => { + let protocol: Protocol; + let transport: MockTransport; + + beforeEach(() => { + vi.useFakeTimers(); + transport = new MockTransport(); + vi.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } + protected assertTaskHandlerCapability(): void {} + })(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + test('should remove abort listener when request completes successfully', async () => { + await protocol.connect(transport); + + const abortController = new AbortController(); + const removeEventListenerSpy = vi.spyOn(abortController.signal, 'removeEventListener'); + + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + + const requestPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + timeout: 5000, + signal: abortController.signal + }); + + // Simulate a successful response + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 0, + result: { result: 'success' } + }); + } + + await expect(requestPromise).resolves.toEqual({ result: 'success' }); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', expect.any(Function)); + }); + + test('should remove abort listener when request times out', async () => { + await protocol.connect(transport); + + const abortController = new AbortController(); + const removeEventListenerSpy = vi.spyOn(abortController.signal, 'removeEventListener'); + + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + + const requestPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + timeout: 100, + signal: abortController.signal + }); + + vi.advanceTimersByTime(101); + + await expect(requestPromise).rejects.toThrow('Request timed out'); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', expect.any(Function)); + }); + + test('should not accumulate listeners across multiple requests on the same signal', async () => { + await protocol.connect(transport); + + const abortController = new AbortController(); + const addEventListenerSpy = vi.spyOn(abortController.signal, 'addEventListener'); + const removeEventListenerSpy = vi.spyOn(abortController.signal, 'removeEventListener'); + + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + + // Make 3 sequential requests on the same signal + for (let i = 0; i < 3; i++) { + const requestPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + timeout: 5000, + signal: abortController.signal + }); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: i, + result: { result: 'success' } + }); + } + + await expect(requestPromise).resolves.toEqual({ result: 'success' }); + } + + // Each request should have added and removed exactly one listener + expect(addEventListenerSpy).toHaveBeenCalledTimes(3); + expect(removeEventListenerSpy).toHaveBeenCalledTimes(3); + }); + + test('should remove abort listener when abort signal is triggered', async () => { + await protocol.connect(transport); + + const abortController = new AbortController(); + const removeEventListenerSpy = vi.spyOn(abortController.signal, 'removeEventListener'); + + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + + const requestPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + timeout: 5000, + signal: abortController.signal + }); + + abortController.abort('User cancelled'); + + await expect(requestPromise).rejects.toThrow(); + // cancel() calls removeEventListener even though once:true also cleans up + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', expect.any(Function)); + }); + + test('should remove abort listener when transport.send fails', async () => { + await protocol.connect(transport); + + const abortController = new AbortController(); + const removeEventListenerSpy = vi.spyOn(abortController.signal, 'removeEventListener'); + + // Make transport.send reject + vi.spyOn(transport, 'send').mockRejectedValueOnce(new Error('Transport failure')); + + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + + const requestPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + timeout: 5000, + signal: abortController.signal + }); + + await expect(requestPromise).rejects.toThrow('Transport failure'); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', expect.any(Function)); + }); +});