From 2bf500d40406ba1322d949fe2ee30430f094cb6e Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Sun, 17 May 2026 15:55:10 +1000 Subject: [PATCH 01/19] Add session methods and SSE stream functions --- src/protocol.test.ts | 109 ++++++++++++++++++++++++++++++++++ src/protocol.ts | 59 ++++++++++++++++++ src/sse.test.ts | 138 +++++++++++++++++++++++++++++++++++++++++++ src/sse.ts | 105 ++++++++++++++++++++++++++++++++ 4 files changed, 411 insertions(+) create mode 100644 src/protocol.test.ts create mode 100644 src/protocol.ts create mode 100644 src/sse.test.ts create mode 100644 src/sse.ts diff --git a/src/protocol.test.ts b/src/protocol.test.ts new file mode 100644 index 0000000..d27199b --- /dev/null +++ b/src/protocol.test.ts @@ -0,0 +1,109 @@ +import { describe, expect, it } from "vitest"; + +import { AGENT_METHODS } from "./schema/index.js"; +import { + methodRequiresSessionHeader, + sessionIdFromParams, + isInitializeRequest, + messageIdKey, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + EVENT_STREAM_MIME_TYPE, + JSON_MIME_TYPE, +} from "./protocol.js"; + +import type { AnyMessage } from "./jsonrpc.js"; + +describe("protocol transport helpers", () => { + it("exports HTTP transport constants", () => { + expect(HEADER_CONNECTION_ID).toBe("Acp-Connection-Id"); + expect(HEADER_SESSION_ID).toBe("Acp-Session-Id"); + expect(EVENT_STREAM_MIME_TYPE).toBe("text/event-stream"); + expect(JSON_MIME_TYPE).toBe("application/json"); + }); + + it("requires a session header for existing-session methods", () => { + expect(methodRequiresSessionHeader(AGENT_METHODS.session_cancel)).toBe( + true, + ); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_close)).toBe(true); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_load)).toBe(true); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_prompt)).toBe( + true, + ); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_resume)).toBe( + true, + ); + expect( + methodRequiresSessionHeader(AGENT_METHODS.session_set_config_option), + ).toBe(true); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_set_mode)).toBe( + true, + ); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_set_model)).toBe( + true, + ); + }); + + it("does not require a session header for connection-level or unsupported methods", () => { + expect(methodRequiresSessionHeader(AGENT_METHODS.initialize)).toBe(false); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_new)).toBe(false); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_list)).toBe(false); + expect(methodRequiresSessionHeader(AGENT_METHODS.session_fork)).toBe(false); + expect(methodRequiresSessionHeader(AGENT_METHODS.nes_start)).toBe(false); + expect(methodRequiresSessionHeader(AGENT_METHODS.nes_suggest)).toBe(false); + expect(methodRequiresSessionHeader(AGENT_METHODS.nes_close)).toBe(false); + }); + + it("extracts a top-level string session ID from params", () => { + expect(sessionIdFromParams({ sessionId: "session-1" })).toBe("session-1"); + }); + + it("returns undefined when params do not contain a top-level string session ID", () => { + expect(sessionIdFromParams(undefined)).toBeUndefined(); + expect(sessionIdFromParams(null)).toBeUndefined(); + expect(sessionIdFromParams("session-1")).toBeUndefined(); + expect(sessionIdFromParams({})).toBeUndefined(); + expect(sessionIdFromParams({ sessionId: 1 })).toBeUndefined(); + expect( + sessionIdFromParams({ nested: { sessionId: "session-1" } }), + ).toBeUndefined(); + }); + + it("detects initialize requests", () => { + const request: AnyMessage = { + jsonrpc: "2.0", + id: 1, + method: AGENT_METHODS.initialize, + params: { protocolVersion: 1, clientCapabilities: {} }, + }; + + expect(isInitializeRequest(request)).toBe(true); + }); + + it("rejects non-initialize messages", () => { + const notification: AnyMessage = { + jsonrpc: "2.0", + method: AGENT_METHODS.initialize, + params: { protocolVersion: 1, clientCapabilities: {} }, + }; + const response: AnyMessage = { jsonrpc: "2.0", id: 1, result: {} }; + const otherRequest: AnyMessage = { + jsonrpc: "2.0", + id: 1, + method: AGENT_METHODS.session_new, + params: { cwd: "/tmp", mcpServers: [] }, + }; + + expect(isInitializeRequest(notification)).toBe(false); + expect(isInitializeRequest(response)).toBe(false); + expect(isInitializeRequest(otherRequest)).toBe(false); + }); + + it("normalizes JSON-RPC request IDs for map keys", () => { + expect(messageIdKey("foo")).toBe("string:foo"); + expect(messageIdKey(1)).toBe("number:1"); + expect(messageIdKey(null)).toBeUndefined(); + expect(messageIdKey(undefined)).toBeUndefined(); + }); +}); diff --git a/src/protocol.ts b/src/protocol.ts new file mode 100644 index 0000000..a88cb95 --- /dev/null +++ b/src/protocol.ts @@ -0,0 +1,59 @@ +import { AGENT_METHODS } from "./schema/index.js"; + +import type { AnyMessage } from "./jsonrpc.js"; + +export const HEADER_CONNECTION_ID = "Acp-Connection-Id"; +export const HEADER_SESSION_ID = "Acp-Session-Id"; +export const EVENT_STREAM_MIME_TYPE = "text/event-stream"; +export const JSON_MIME_TYPE = "application/json"; + +const SESSION_SCOPED_METHODS = new Set([ + AGENT_METHODS.session_cancel, + AGENT_METHODS.session_close, + AGENT_METHODS.session_load, + AGENT_METHODS.session_prompt, + AGENT_METHODS.session_resume, + AGENT_METHODS.session_set_config_option, + AGENT_METHODS.session_set_mode, + AGENT_METHODS.session_set_model, +]); + +export function methodRequiresSessionHeader(method: string): boolean { + return SESSION_SCOPED_METHODS.has(method); +} + +export function sessionIdFromParams(params: unknown): string | undefined { + if (!isRecord(params)) { + return undefined; + } + + const sessionId = params["sessionId"]; + return typeof sessionId === "string" ? sessionId : undefined; +} + +export function isInitializeRequest(msg: AnyMessage): boolean { + return ( + msg.jsonrpc === "2.0" && + "id" in msg && + "method" in msg && + msg.method === AGENT_METHODS.initialize + ); +} + +export function messageIdKey( + id: string | number | null | undefined, +): string | undefined { + if (typeof id === "string") { + return `string:${id}`; + } + + if (typeof id === "number") { + return `number:${id}`; + } + + return undefined; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} diff --git a/src/sse.test.ts b/src/sse.test.ts new file mode 100644 index 0000000..6af3758 --- /dev/null +++ b/src/sse.test.ts @@ -0,0 +1,138 @@ +import { describe, expect, it, vi } from "vitest"; + +import { + parseSseStream, + serializeSseEvent, + serializeSseKeepAlive, +} from "./sse.js"; + +import type { AnyMessage } from "./jsonrpc.js"; + +const encoder = new TextEncoder(); + +function streamFromChunks(chunks: string[]): ReadableStream { + return new ReadableStream({ + start(controller) { + for (const chunk of chunks) { + controller.enqueue(encoder.encode(chunk)); + } + controller.close(); + }, + }); +} + +async function collectMessages( + body: ReadableStream, +): Promise { + const messages: AnyMessage[] = []; + for await (const message of parseSseStream(body)) { + messages.push(message); + } + return messages; +} + +describe("SSE transport helpers", () => { + it("serializes a message event", () => { + const message: AnyMessage = { + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { protocolVersion: 1 }, + }; + + expect(serializeSseEvent(message)).toBe( + `data: ${JSON.stringify(message)}\n\n`, + ); + }); + + it("serializes a keepalive comment", () => { + expect(serializeSseKeepAlive()).toBe(":\n\n"); + }); + + it("parses one event", async () => { + const message: AnyMessage = { jsonrpc: "2.0", id: 1, result: { ok: true } }; + + await expect( + collectMessages(streamFromChunks([serializeSseEvent(message)])), + ).resolves.toEqual([message]); + }); + + it("parses multiple events in one chunk", async () => { + const first: AnyMessage = { jsonrpc: "2.0", id: 1, result: { ok: true } }; + const second: AnyMessage = { + jsonrpc: "2.0", + method: "session/update", + params: { sessionId: "s1" }, + }; + + await expect( + collectMessages( + streamFromChunks([ + serializeSseEvent(first) + serializeSseEvent(second), + ]), + ), + ).resolves.toEqual([first, second]); + }); + + it("parses events split across chunk boundaries", async () => { + const message: AnyMessage = { + jsonrpc: "2.0", + id: "abc", + result: { ok: true }, + }; + const serialized = serializeSseEvent(message); + + await expect( + collectMessages( + streamFromChunks([ + serialized.slice(0, 7), + serialized.slice(7, 18), + serialized.slice(18), + ]), + ), + ).resolves.toEqual([message]); + }); + + it("ignores comments, keepalives, and non-data fields", async () => { + const message: AnyMessage = { jsonrpc: "2.0", id: 1, result: { ok: true } }; + + await expect( + collectMessages( + streamFromChunks([ + `:\n\nevent: message\nid: 1\ndata: ${JSON.stringify(message)}\nretry: 1000\n\n`, + ]), + ), + ).resolves.toEqual([message]); + }); + + it("joins multiline data fields", async () => { + const expected: AnyMessage = { + jsonrpc: "2.0", + id: 1, + result: { ok: true }, + }; + const body = [ + 'data: {"jsonrpc":"2.0",\n', + 'data: "id":1,\n', + 'data: "result":{"ok":true}}\n\n', + ]; + + await expect(collectMessages(streamFromChunks(body))).resolves.toEqual([ + expected, + ]); + }); + + it("skips malformed JSON without throwing", async () => { + const warn = vi.spyOn(console, "warn").mockImplementation(() => undefined); + const message: AnyMessage = { jsonrpc: "2.0", id: 1, result: { ok: true } }; + + await expect( + collectMessages( + streamFromChunks(["data: {not-json}\n\n", serializeSseEvent(message)]), + ), + ).resolves.toEqual([message]); + expect(warn).toHaveBeenCalledOnce(); + + warn.mockRestore(); + }); +}); diff --git a/src/sse.ts b/src/sse.ts new file mode 100644 index 0000000..9b7b8cb --- /dev/null +++ b/src/sse.ts @@ -0,0 +1,105 @@ +import type { AnyMessage } from "./jsonrpc.js"; + +export function serializeSseEvent(msg: AnyMessage): string { + return `data: ${JSON.stringify(msg)}\n\n`; +} + +export function serializeSseKeepAlive(): string { + return ":\n\n"; +} + +export async function* parseSseStream( + body: ReadableStream, +): AsyncIterable { + const decoder = new TextDecoder(); + const reader = body.getReader(); + let buffer = ""; + + try { + while (true) { + const chunk = await reader.read(); + + if (chunk.done) { + buffer += decoder.decode(); + yield* parseBufferedEvents(buffer); + return; + } + + buffer += decoder.decode(chunk.value, { stream: true }); + const eventParts = buffer.split(/\r?\n\r?\n/); + buffer = eventParts.pop() ?? ""; + + for (const eventPart of eventParts) { + const msg = parseSseEvent(eventPart); + if (msg) { + yield msg; + } + } + } + } finally { + reader.releaseLock(); + } +} + +function* parseBufferedEvents(buffer: string): Iterable { + if (!buffer.trim()) { + return; + } + + const eventParts = buffer.split(/\r?\n\r?\n/); + + for (const eventPart of eventParts) { + const msg = parseSseEvent(eventPart); + if (msg) { + yield msg; + } + } +} + +function parseSseEvent(eventPart: string): AnyMessage | undefined { + const dataLines = eventPart + .split(/\r?\n/) + .filter((line) => line.startsWith("data:")) + .map((line) => { + const value = line.slice("data:".length); + return value.startsWith(" ") ? value.slice(1) : value; + }); + + if (dataLines.length === 0) { + return undefined; + } + + const data = dataLines.join("\n"); + if (!data.trim()) { + return undefined; + } + + try { + const parsed: unknown = JSON.parse(data); + if (isAnyMessage(parsed)) { + return parsed; + } + + console.warn("Skipping SSE payload that is not a JSON-RPC message"); + return undefined; + } catch (error) { + console.warn("Failed to parse SSE JSON payload:", error); + return undefined; + } +} + +function isAnyMessage(value: unknown): value is AnyMessage { + if (!isRecord(value) || value["jsonrpc"] !== "2.0") { + return false; + } + + if ("method" in value) { + return typeof value["method"] === "string"; + } + + return "id" in value; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} From 8654fe1dca403f732bb0668eb9bb20690a3eeee4 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Mon, 18 May 2026 15:25:41 +1000 Subject: [PATCH 02/19] Add server implementation with Node HTTP adapter, only support POST and initialize --- package.json | 15 ++ src/connection.test.ts | 75 ++++++++ src/connection.ts | 99 ++++++++++ src/node-adapter.test.ts | 167 ++++++++++++++++ src/node-adapter.ts | 138 ++++++++++++++ src/server.test.ts | 274 +++++++++++++++++++++++++++ src/server.ts | 172 +++++++++++++++++ src/test-support/test-agent.ts | 85 +++++++++ src/test-support/test-http-server.ts | 74 ++++++++ 9 files changed, 1099 insertions(+) create mode 100644 src/connection.test.ts create mode 100644 src/connection.ts create mode 100644 src/node-adapter.test.ts create mode 100644 src/node-adapter.ts create mode 100644 src/server.test.ts create mode 100644 src/server.ts create mode 100644 src/test-support/test-agent.ts create mode 100644 src/test-support/test-http-server.ts diff --git a/package.json b/package.json index 528390e..b7b2b7c 100644 --- a/package.json +++ b/package.json @@ -24,6 +24,21 @@ "type": "module", "main": "dist/acp.js", "types": "dist/acp.d.ts", + "exports": { + ".": { + "types": "./dist/acp.d.ts", + "default": "./dist/acp.js" + }, + "./server": { + "types": "./dist/server.d.ts", + "default": "./dist/server.js" + }, + "./node": { + "types": "./dist/node-adapter.d.ts", + "default": "./dist/node-adapter.js" + }, + "./schema/schema.json": "./schema/schema.json" + }, "directories": { "example": "examples" }, diff --git a/src/connection.test.ts b/src/connection.test.ts new file mode 100644 index 0000000..f07bdb0 --- /dev/null +++ b/src/connection.test.ts @@ -0,0 +1,75 @@ +import { describe, expect, it } from "vitest"; +import { ConnectionRegistry } from "./connection.js"; +import { TestAgent } from "./test-support/test-agent.js"; + +import type { AgentSideConnection } from "./acp.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: 1, + clientCapabilities: {}, + }, +} as const; + +describe("ConnectionRegistry", () => { + it("creates retrievable connections with unique UUID connection IDs", () => { + const registry = new ConnectionRegistry(); + const first = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn), + ); + const second = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn), + ); + + expect(first.connectionId).toMatch(/^[0-9a-f-]{36}$/); + expect(second.connectionId).toMatch(/^[0-9a-f-]{36}$/); + expect(first.connectionId).not.toBe(second.connectionId); + expect(registry.get(first.connectionId)).toBe(first); + expect(registry.get(second.connectionId)).toBe(second); + + registry.closeAll(); + }); + + it("removes connections", () => { + const registry = new ConnectionRegistry(); + const connection = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn), + ); + + expect(registry.remove(connection.connectionId)).toBe(connection); + expect(registry.get(connection.connectionId)).toBeUndefined(); + expect(registry.remove(connection.connectionId)).toBeUndefined(); + }); + + it("receives the initialize response directly from the agent", async () => { + const registry = new ConnectionRegistry(); + const connection = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn), + ); + const writer = connection.inboundTx.getWriter(); + + try { + await writer.write(initializeRequest); + } finally { + writer.releaseLock(); + } + + const response = await connection.recvInitial(initializeRequest.id); + + expect(response).toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: 1, + agentCapabilities: { + loadSession: false, + }, + }, + }); + + registry.closeAll(); + }); +}); diff --git a/src/connection.ts b/src/connection.ts new file mode 100644 index 0000000..1bc9bab --- /dev/null +++ b/src/connection.ts @@ -0,0 +1,99 @@ +import { AgentSideConnection } from "./acp.js"; + +import type { Agent } from "./acp.js"; +import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; +import type { Stream } from "./stream.js"; + +export class ConnectionState { + readonly connectionId: string; + readonly inboundTx: WritableStream; + readonly outboundRx: ReadableStream; + readonly agentConnection: AgentSideConnection; + + constructor(agentFactory: (conn: AgentSideConnection) => Agent) { + this.connectionId = globalThis.crypto.randomUUID(); + const inbound = new TransformStream(); + const outbound = new TransformStream(); + + this.inboundTx = inbound.writable; + this.outboundRx = outbound.readable; + + const stream: Stream = { + readable: inbound.readable, + writable: outbound.writable, + }; + + this.agentConnection = new AgentSideConnection(agentFactory, stream); + } + + async recvInitial(initializeId: string | number): Promise { + const reader = this.outboundRx.getReader(); + + try { + const result = await reader.read(); + + if ( + result.done || + !result.value || + !isMatchingResponse(result.value, initializeId) + ) { + await this.shutdown(); + throw new Error("Expected initialize response from agent"); + } + + return result.value; + } finally { + reader.releaseLock(); + } + } + + async shutdown() { + await Promise.allSettled([ + this.inboundTx.close(), + this.outboundRx.cancel(), + ]); + } +} + +export class ConnectionRegistry { + private readonly connections = new Map(); + + createConnection( + agentFactory: (conn: AgentSideConnection) => Agent, + ): ConnectionState { + const connection = new ConnectionState(agentFactory); + this.connections.set(connection.connectionId, connection); + return connection; + } + + get(connectionId: string): ConnectionState | undefined { + return this.connections.get(connectionId); + } + + remove(connectionId: string): ConnectionState | undefined { + const connection = this.get(connectionId); + + if (!connection) { + return undefined; + } + + this.connections.delete(connectionId); + void connection.shutdown(); + return connection; + } + + closeAll(): void { + for (const connection of this.connections.values()) { + void connection.shutdown(); + } + + this.connections.clear(); + } +} + +function isMatchingResponse( + msg: AnyMessage, + id: string | number, +): msg is AnyResponse { + return "id" in msg && !("method" in msg) && msg.id === id; +} diff --git a/src/node-adapter.test.ts b/src/node-adapter.test.ts new file mode 100644 index 0000000..629d758 --- /dev/null +++ b/src/node-adapter.test.ts @@ -0,0 +1,167 @@ +import http from "node:http"; + +import { describe, expect, it } from "vitest"; +import { AcpServer } from "./server.js"; +import { createNodeHttpHandler } from "./node-adapter.js"; +import { TestAgent } from "./test-support/test-agent.js"; + +import type { AgentSideConnection } from "./acp.js"; + +interface RunningServer { + readonly url: string; + readonly close: () => Promise; +} + +describe("createNodeHttpHandler", () => { + it("forwards method, URL, headers, and body to AcpServer.handleRequest", async () => { + const acpServer = new AcpServer({ + createAgent: (conn: AgentSideConnection) => new TestAgent(conn), + }); + const seenRequests: Request[] = []; + const seenBodies: string[] = []; + acpServer.handleRequest = async (req) => { + seenRequests.push(req); + seenBodies.push(await req.text()); + return new Response("created", { + status: 201, + headers: { + "X-Adapter-Test": "ok", + }, + }); + }; + + const server = await startNodeServer(acpServer); + + try { + const response = await fetch(`${server.url}/acp?hello=world`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-Client-Test": "forwarded", + }, + body: JSON.stringify({ ok: true }), + }); + + expect(response.status).toBe(201); + expect(response.headers.get("X-Adapter-Test")).toBe("ok"); + expect(await response.text()).toBe("created"); + expect(seenRequests).toHaveLength(1); + expect(seenRequests[0]?.method).toBe("POST"); + expect(seenRequests[0]?.url).toBe(`${server.url}/acp?hello=world`); + expect(seenRequests[0]?.headers.get("Content-Type")).toBe( + "application/json", + ); + expect(seenRequests[0]?.headers.get("X-Client-Test")).toBe("forwarded"); + expect(seenBodies).toEqual([JSON.stringify({ ok: true })]); + } finally { + await server.close(); + } + }); + + it("streams response bodies to ServerResponse", async () => { + const acpServer = new AcpServer({ + createAgent: (conn: AgentSideConnection) => new TestAgent(conn), + }); + acpServer.handleRequest = () => + Promise.resolve( + new Response( + new ReadableStream({ + start(controller) { + const encoder = new TextEncoder(); + controller.enqueue(encoder.encode("data: one\n\n")); + controller.enqueue(encoder.encode("data: two\n\n")); + controller.close(); + }, + }), + { + status: 200, + headers: { + "Content-Type": "text/event-stream", + }, + }, + ), + ); + + const server = await startNodeServer(acpServer); + + try { + const response = await fetch(server.url, { method: "POST" }); + + expect(response.status).toBe(200); + expect(response.headers.get("Content-Type")).toContain( + "text/event-stream", + ); + expect(await response.text()).toBe("data: one\n\ndata: two\n\n"); + } finally { + await server.close(); + } + }); + + it("handles empty response bodies", async () => { + const acpServer = new AcpServer({ + createAgent: (conn: AgentSideConnection) => new TestAgent(conn), + }); + acpServer.handleRequest = () => + Promise.resolve( + new Response(null, { + status: 202, + headers: { + "X-Empty-Body": "yes", + }, + }), + ); + + const server = await startNodeServer(acpServer); + + try { + const response = await fetch(server.url, { method: "POST" }); + + expect(response.status).toBe(202); + expect(response.headers.get("X-Empty-Body")).toBe("yes"); + expect(await response.text()).toBe(""); + } finally { + await server.close(); + } + }); +}); + +async function startNodeServer(acpServer: AcpServer): Promise { + const server = http.createServer(createNodeHttpHandler(acpServer)); + + await new Promise((resolve, reject) => { + const onError = (error: Error): void => { + server.off("listening", onListening); + reject(error); + }; + + const onListening = (): void => { + server.off("error", onError); + resolve(); + }; + + server.once("error", onError); + server.once("listening", onListening); + server.listen(0, "127.0.0.1"); + }); + + const address = server.address(); + + if (typeof address !== "object" || address === null) { + throw new Error("Node test server did not bind to a TCP port"); + } + + return { + url: `http://127.0.0.1:${address.port}`, + close: () => + new Promise((resolve, reject) => { + server.close((error) => { + if (error) { + reject(error); + return; + } + + resolve(); + }); + }), + }; +} diff --git a/src/node-adapter.ts b/src/node-adapter.ts new file mode 100644 index 0000000..8a3d91d --- /dev/null +++ b/src/node-adapter.ts @@ -0,0 +1,138 @@ +import type { IncomingMessage, ServerResponse } from "node:http"; +import type { AcpServer } from "./server.js"; + +export function createNodeHttpHandler( + server: AcpServer, +): (req: IncomingMessage, res: ServerResponse) => void { + return (req, res) => { + void handleNodeRequest(server, req, res); + }; +} + +async function handleNodeRequest( + server: AcpServer, + req: IncomingMessage, + res: ServerResponse, +): Promise { + try { + await writeNodeResponse( + res, + await server.handleRequest(await toWebRequest(req)), + ); + } catch (error) { + if (!res.headersSent) { + res.statusCode = 500; + res.setHeader("Content-Type", "text/plain"); + } + + res.end(error instanceof Error ? error.message : "Internal Server Error"); + } +} + +async function toWebRequest(req: IncomingMessage): Promise { + return new Request(nodeRequestUrl(req), { + method: req.method ?? "GET", + headers: nodeHeaders(req), + body: hasRequestBody(req) ? await readRequestBody(req) : undefined, + }); +} + +function hasRequestBody(req: IncomingMessage): boolean { + return req.method !== "GET" && req.method !== "HEAD"; +} + +async function readRequestBody(req: IncomingMessage): Promise { + const chunks: string[] = []; + + for await (const chunk of req) { + chunks.push( + typeof chunk === "string" ? chunk : Buffer.from(chunk).toString("utf8"), + ); + } + + return chunks.join(""); +} + +function nodeRequestUrl(req: IncomingMessage): string { + const host = req.headers.host ?? "localhost"; + return `http://${host}${req.url ?? "/"}`; +} + +function nodeHeaders(req: IncomingMessage): Headers { + const headers = new Headers(); + + for (const [name, value] of Object.entries(req.headers)) { + if (Array.isArray(value)) { + for (const item of value) { + headers.append(name, item); + } + + continue; + } + + if (value !== undefined) { + headers.set(name, value); + } + } + + return headers; +} + +async function writeNodeResponse( + res: ServerResponse, + response: Response, +): Promise { + res.statusCode = response.status; + + response.headers.forEach((value, name) => { + res.setHeader(name, value); + }); + + const responseBody = response.body; + + if (!responseBody) { + res.end(); + return; + } + + const reader = responseBody.getReader(); + + try { + while (true) { + const result = await reader.read(); + + if (result.done) { + res.end(); + return; + } + + await writeChunk(res, result.value); + } + } finally { + reader.releaseLock(); + } +} + +function writeChunk(res: ServerResponse, chunk: Uint8Array): Promise { + return new Promise((resolve, reject) => { + const onError = (error: Error): void => { + res.off("drain", onDrain); + reject(error); + }; + + const onDrain = (): void => { + res.off("error", onError); + resolve(); + }; + + res.once("error", onError); + + if (res.write(chunk)) { + res.off("error", onError); + resolve(); + return; + } + + res.once("drain", onDrain); + }); +} diff --git a/src/server.test.ts b/src/server.test.ts new file mode 100644 index 0000000..9268338 --- /dev/null +++ b/src/server.test.ts @@ -0,0 +1,274 @@ +import { describe, expect, it } from "vitest"; +import { HEADER_CONNECTION_ID, JSON_MIME_TYPE } from "./protocol.js"; +import { startTestServer } from "./test-support/test-http-server.js"; +import { TestAgent } from "./test-support/test-agent.js"; + +import type { AgentSideConnection } from "./acp.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: 1, + clientCapabilities: {}, + }, +}; + +describe("AcpServer", () => { + it("handles initialize over HTTP and returns a connection ID", async () => { + const server = await startTestServer(); + + try { + const response = await postJson(server.url, initializeRequest); + const body = await response.json(); + + expect(response.status).toBe(200); + expect(response.headers.get(HEADER_CONNECTION_ID)).toMatch( + /^[0-9a-f-]{36}$/, + ); + expect(body).toMatchObject({ + jsonrpc: "2.0", + id: 1, + result: { + protocolVersion: 1, + agentCapabilities: { + loadSession: false, + }, + }, + }); + } finally { + await server.close(); + } + }); + + it.each(["GET", "PUT", "PATCH", "DELETE"])( + "rejects %s requests in Phase 1", + async (method) => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { method }); + + expect(response.status).toBe(405); + } finally { + await server.close(); + } + }, + ); + + it("rejects POST without application/json Content-Type", async () => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { + method: "POST", + headers: { + "Content-Type": "text/plain", + }, + body: JSON.stringify(initializeRequest), + }); + + expect(response.status).toBe(415); + } finally { + await server.close(); + } + }); + + it("rejects invalid JSON", async () => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { + method: "POST", + headers: { + "Content-Type": JSON_MIME_TYPE, + }, + body: "{ nope", + }); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("rejects JSON-RPC batches", async () => { + const server = await startTestServer(); + + try { + const response = await postJson(server.url, [initializeRequest]); + + expect(response.status).toBe(501); + } finally { + await server.close(); + } + }); + + it.each([ + null, + "initialize", + 1, + {}, + { jsonrpc: "1.0", method: "initialize" }, + ])("rejects invalid JSON-RPC messages", async (body) => { + const server = await startTestServer(); + + try { + const response = await postJson(server.url, body); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("rejects non-initialize requests without a connection ID", async () => { + const server = await startTestServer(); + + try { + const response = await postJson(server.url, { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, + }); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("rejects unknown connection IDs", async () => { + const server = await startTestServer(); + + try { + const response = await postJson( + server.url, + { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, + }, + { + [HEADER_CONNECTION_ID]: globalThis.crypto.randomUUID(), + }, + ); + + expect(response.status).toBe(404); + } finally { + await server.close(); + } + }); + + it("rejects connected POSTs after initialize in Phase 1", async () => { + const server = await startTestServer(); + + try { + const initializeResponse = await postJson(server.url, initializeRequest); + const connectionId = initializeResponse.headers.get(HEADER_CONNECTION_ID); + + expect(connectionId).toBeTruthy(); + + const response = await postJson( + server.url, + { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, + }, + { + [HEADER_CONNECTION_ID]: connectionId ?? "", + }, + ); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("returns an error response when agent creation fails", async () => { + const server = await startTestServer(() => { + throw new Error("agent factory failed"); + }); + + try { + const response = await postJson(server.url, initializeRequest); + const body = await response.json(); + + expect(response.status).toBe(500); + expect(response.headers.get(HEADER_CONNECTION_ID)).toBeNull(); + expect(body).toMatchObject({ + jsonrpc: "2.0", + id: 1, + error: { + code: -32603, + message: "Initialize failed", + data: "agent factory failed", + }, + }); + } finally { + await server.close(); + } + }); + + it("returns JSON-RPC initialize errors as the initialize response", async () => { + class FailingInitializeAgent extends TestAgent { + initialize() { + return Promise.reject(new Error("initialize failed")); + } + } + + const server = await startTestServer( + (conn: AgentSideConnection) => new FailingInitializeAgent(conn), + ); + + try { + const response = await postJson(server.url, initializeRequest); + const body = await response.json(); + + expect(response.status).toBe(200); + expect(response.headers.get(HEADER_CONNECTION_ID)).toMatch( + /^[0-9a-f-]{36}$/, + ); + expect(body).toMatchObject({ + jsonrpc: "2.0", + id: 1, + error: { + code: -32603, + message: "Internal error", + }, + }); + } finally { + await server.close(); + } + }); +}); + +function postJson( + url: string, + body: unknown, + headers: Record = {}, +): Promise { + return fetch(url, { + method: "POST", + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + body: JSON.stringify(body), + }); +} diff --git a/src/server.ts b/src/server.ts new file mode 100644 index 0000000..9d2bee8 --- /dev/null +++ b/src/server.ts @@ -0,0 +1,172 @@ +import { ConnectionRegistry } from "./connection.js"; +import { + HEADER_CONNECTION_ID, + JSON_MIME_TYPE, + isInitializeRequest, +} from "./protocol.js"; + +import type { Agent, AgentSideConnection } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; + +export interface AcpServerOptions { + createAgent: (conn: AgentSideConnection) => Agent; +} + +export class AcpServer { + private readonly createAgent: (conn: AgentSideConnection) => Agent; + private readonly registry = new ConnectionRegistry(); + + constructor(options: AcpServerOptions) { + this.createAgent = options.createAgent; + } + + async handleRequest(req: Request): Promise { + if (req.method !== "POST") { + return textResponse("Method Not Allowed", 405); + } + + const contentType = req.headers.get("Content-Type"); + + if (!contentType?.startsWith(JSON_MIME_TYPE)) { + return textResponse("Unsupported Media Type", 415); + } + + const body = await readJson(req); + + if (!body.ok) { + return textResponse("Invalid JSON", 400); + } + + if (Array.isArray(body.value)) { + return textResponse("Batch JSON-RPC requests are not implemented", 501); + } + + if (!isJsonRpcMessage(body.value)) { + return textResponse("Invalid JSON-RPC message", 400); + } + + const connectionId = req.headers.get(HEADER_CONNECTION_ID); + + if (isInitializeRequest(body.value) && !connectionId) { + return await this.handleInitialize(body.value); + } + + if (!connectionId) { + return textResponse("Missing Acp-Connection-Id", 400); + } + + if (!this.registry.get(connectionId)) { + return textResponse("Unknown Acp-Connection-Id", 404); + } + + return textResponse( + "Connected POST handling is not implemented in Phase 1", + 400, + ); + } + + async close(): Promise { + this.registry.closeAll(); + } + + private async handleInitialize(message: AnyMessage): Promise { + if (!("id" in message) || message.id === null) { + return textResponse("Initialize request must include an ID", 400); + } + + let connection: + | ReturnType + | undefined; + + try { + connection = this.registry.createConnection(this.createAgent); + const writer = connection.inboundTx.getWriter(); + + try { + await writer.write(message); + } finally { + writer.releaseLock(); + } + + const initialResponse = await connection.recvInitial(message.id); + + return jsonResponse(initialResponse, 200, { + [HEADER_CONNECTION_ID]: connection.connectionId, + }); + } catch (error) { + if (connection) { + this.registry.remove(connection.connectionId); + } + + return jsonResponse( + { + jsonrpc: "2.0", + id: message.id, + error: { + code: -32603, + message: "Initialize failed", + data: error instanceof Error ? error.message : undefined, + }, + }, + 500, + ); + } + } +} + +type JsonResult = + | { + ok: true; + value: unknown; + } + | { + ok: false; + }; + +async function readJson(req: Request): Promise { + try { + return { + ok: true, + value: await req.json(), + }; + } catch { + return { + ok: false, + }; + } +} + +function isJsonRpcMessage(value: unknown): value is AnyMessage { + return ( + isRecord(value) && + value.jsonrpc === "2.0" && + ("method" in value || "id" in value) + ); +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function jsonResponse( + value: unknown, + status: number, + headers?: HeadersInit, +): Response { + return new Response(JSON.stringify(value), { + status, + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + }); +} + +function textResponse(body: string, status: number): Response { + return new Response(body, { + status, + headers: { + "Content-Type": "text/plain", + }, + }); +} diff --git a/src/test-support/test-agent.ts b/src/test-support/test-agent.ts new file mode 100644 index 0000000..a1fcc87 --- /dev/null +++ b/src/test-support/test-agent.ts @@ -0,0 +1,85 @@ +import { PROTOCOL_VERSION } from "../schema/index.js"; + +import type { + Agent, + AgentSideConnection, + AuthenticateRequest, + AuthenticateResponse, + CancelNotification, + InitializeRequest, + InitializeResponse, + NewSessionRequest, + NewSessionResponse, + PromptRequest, + PromptResponse, +} from "../acp.js"; + +export interface TestAgentOptions { + readonly chunkCount?: number; + readonly chunkDelayMs?: number; +} + +export class TestAgent implements Agent { + private readonly connection: AgentSideConnection; + private readonly chunkCount: number; + private readonly chunkDelayMs: number; + + constructor(connection: AgentSideConnection, options: TestAgentOptions = {}) { + this.connection = connection; + this.chunkCount = options.chunkCount ?? 1; + this.chunkDelayMs = options.chunkDelayMs ?? 0; + } + + initialize(_params: InitializeRequest): Promise { + return Promise.resolve({ + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { + loadSession: false, + }, + }); + } + + newSession(_params: NewSessionRequest): Promise { + return Promise.resolve({ sessionId: globalThis.crypto.randomUUID() }); + } + + authenticate( + _params: AuthenticateRequest, + ): Promise { + return Promise.resolve(); + } + + async prompt(params: PromptRequest): Promise { + for (const index of Array.from( + { length: this.chunkCount }, + (_, item) => item, + )) { + if (this.chunkDelayMs > 0) { + await delay(this.chunkDelayMs); + } + + await this.connection.sessionUpdate({ + sessionId: params.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + type: "text", + text: `chunk-${index + 1}`, + }, + }, + }); + } + + return { stopReason: "end_turn" }; + } + + cancel(_params: CancelNotification): Promise { + return Promise.resolve(); + } +} + +function delay(ms: number): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} diff --git a/src/test-support/test-http-server.ts b/src/test-support/test-http-server.ts new file mode 100644 index 0000000..de29c5e --- /dev/null +++ b/src/test-support/test-http-server.ts @@ -0,0 +1,74 @@ +import http from "node:http"; + +import { AcpServer } from "../server.js"; +import { createNodeHttpHandler } from "../node-adapter.js"; +import { TestAgent } from "./test-agent.js"; + +import type { AddressInfo } from "node:net"; +import type { Agent, AgentSideConnection } from "../acp.js"; + +export interface TestHttpServer { + readonly url: string; + readonly close: () => Promise; +} + +export async function startTestServer( + agentFactory: (conn: AgentSideConnection) => Agent = (conn) => + new TestAgent(conn), + options: { port?: number } = {}, +): Promise { + const acpServer = new AcpServer({ createAgent: agentFactory }); + const httpServer = http.createServer(createNodeHttpHandler(acpServer)); + + await listen(httpServer, options.port ?? 0); + + const address = httpServer.address(); + + if (!isAddressInfo(address)) { + throw new Error("Test HTTP server did not bind to a TCP port"); + } + + return { + url: `http://127.0.0.1:${address.port}`, + close: async () => { + await Promise.all([acpServer.close(), closeHttpServer(httpServer)]); + }, + }; +} + +function listen(server: http.Server, port: number): Promise { + return new Promise((resolve, reject) => { + const onError = (error: Error): void => { + server.off("listening", onListening); + reject(error); + }; + + const onListening = (): void => { + server.off("error", onError); + resolve(); + }; + + server.once("error", onError); + server.once("listening", onListening); + server.listen(port, "127.0.0.1"); + }); +} + +function closeHttpServer(server: http.Server): Promise { + return new Promise((resolve, reject) => { + server.close((error) => { + if (error) { + reject(error); + return; + } + + resolve(); + }); + }); +} + +function isAddressInfo( + address: ReturnType, +): address is AddressInfo { + return typeof address === "object" && address !== null; +} From 1a779dc1b7a647df56f6336a253bc46db3bb87e3 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Mon, 18 May 2026 16:46:22 +1000 Subject: [PATCH 03/19] Add SSE routing and session/new --- src/connection.test.ts | 196 ++++++++++++++++++++++- src/connection.ts | 240 +++++++++++++++++++++++++++- src/node-adapter.ts | 2 + src/server.test.ts | 348 +++++++++++++++++++++++++++++++++-------- src/server.ts | 219 ++++++++++++++++++++++++-- 5 files changed, 910 insertions(+), 95 deletions(-) diff --git a/src/connection.test.ts b/src/connection.test.ts index f07bdb0..e0f52ec 100644 --- a/src/connection.test.ts +++ b/src/connection.test.ts @@ -1,8 +1,14 @@ -import { describe, expect, it } from "vitest"; -import { ConnectionRegistry } from "./connection.js"; +import { describe, expect, it, vi } from "vitest"; +import { + ConnectionRegistry, + OutboundStream, + type ResponseRoute, +} from "./connection.js"; +import { messageIdKey } from "./protocol.js"; import { TestAgent } from "./test-support/test-agent.js"; import type { AgentSideConnection } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; const initializeRequest = { jsonrpc: "2.0", @@ -14,6 +20,21 @@ const initializeRequest = { }, } as const; +const sessionNewRequest = { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, +} as const; + +const messageOne = { jsonrpc: "2.0", id: 1, result: "one" } as const; +const messageTwo = { jsonrpc: "2.0", id: 2, result: "two" } as const; +const messageThree = { jsonrpc: "2.0", id: 3, result: "three" } as const; +const messageFour = { jsonrpc: "2.0", id: 4, result: "four" } as const; + describe("ConnectionRegistry", () => { it("creates retrievable connections with unique UUID connection IDs", () => { const registry = new ConnectionRegistry(); @@ -49,13 +70,8 @@ describe("ConnectionRegistry", () => { const connection = registry.createConnection( (conn: AgentSideConnection) => new TestAgent(conn), ); - const writer = connection.inboundTx.getWriter(); - try { - await writer.write(initializeRequest); - } finally { - writer.releaseLock(); - } + await writeInbound(connection.inboundTx, initializeRequest); const response = await connection.recvInitial(initializeRequest.id); @@ -72,4 +88,168 @@ describe("ConnectionRegistry", () => { registry.closeAll(); }); + + it("routes pending responses to the connection stream and all outbound stream", async () => { + const registry = new ConnectionRegistry(); + const connection = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn), + ); + + await initializeConnection(connection); + + const connectionSubscription = connection.connectionStream.subscribe(); + const allOutboundSubscription = connection.allOutbound.subscribe(); + const key = messageIdKey(sessionNewRequest.id); + + expect(key).toBe("number:2"); + connection.pendingRoutes.set(key ?? "", "connection"); + + await writeInbound(connection.inboundTx, sessionNewRequest); + + const connectionMessage = await readNext(connectionSubscription.stream); + const allOutboundMessage = await readNext(allOutboundSubscription.stream); + + expect(connectionMessage).toMatchObject({ + jsonrpc: "2.0", + id: sessionNewRequest.id, + result: { + sessionId: expect.stringMatching(/^[0-9a-f-]{36}$/), + }, + }); + expect(allOutboundMessage).toMatchObject(connectionMessage); + expect(connection.pendingRoutes.has(key ?? "")).toBe(false); + + registry.closeAll(); + }); + + it("falls back to the connection stream for responses without a pending route", async () => { + const registry = new ConnectionRegistry(); + const connection = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn), + ); + + await initializeConnection(connection); + + const subscription = connection.connectionStream.subscribe(); + + await writeInbound(connection.inboundTx, sessionNewRequest); + + expect(await readNext(subscription.stream)).toMatchObject({ + jsonrpc: "2.0", + id: sessionNewRequest.id, + result: { + sessionId: expect.stringMatching(/^[0-9a-f-]{36}$/), + }, + }); + + registry.closeAll(); + }); +}); + +describe("OutboundStream", () => { + it("replays buffered messages to the first subscriber", () => { + const stream = new OutboundStream(); + + stream.push(messageOne); + stream.push(messageTwo); + + expect(stream.subscribe().replay).toEqual([messageOne, messageTwo]); + }); + + it("does not replay buffered messages to later subscribers", async () => { + const stream = new OutboundStream(); + + stream.push(messageOne); + + const first = stream.subscribe(); + const second = stream.subscribe(); + + expect(first.replay).toEqual([messageOne]); + expect(second.replay).toEqual([]); + + stream.push(messageTwo); + + expect(await readNext(first.stream)).toEqual(messageTwo); + expect(await readNext(second.stream)).toEqual(messageTwo); + }); + + it("evicts oldest replay messages when capacity is exceeded", () => { + const stream = new OutboundStream(2); + + stream.push(messageOne); + stream.push(messageTwo); + stream.push(messageThree); + + expect(stream.subscribe().replay).toEqual([messageTwo, messageThree]); + }); + + it("drops oldest queued live messages for lagging subscribers", async () => { + const warn = vi.spyOn(console, "warn").mockImplementation(() => undefined); + const stream = new OutboundStream(2); + const subscription = stream.subscribe(); + + stream.push(messageOne); + stream.push(messageTwo); + stream.push(messageThree); + stream.push(messageFour); + + expect(await readNext(subscription.stream)).toEqual(messageOne); + expect(await readNext(subscription.stream)).toEqual(messageThree); + expect(await readNext(subscription.stream)).toEqual(messageFour); + expect(warn).toHaveBeenCalledOnce(); + + warn.mockRestore(); + }); + + it("closes subscriber streams", async () => { + const stream = new OutboundStream(); + const reader = stream.subscribe().stream.getReader(); + + stream.close(); + + expect(await reader.read()).toEqual({ done: true, value: undefined }); + reader.releaseLock(); + }); }); + +type TestConnection = ReturnType; + +async function initializeConnection(connection: TestConnection): Promise { + await writeInbound(connection.inboundTx, initializeRequest); + await connection.recvInitial(initializeRequest.id); + connection.startRouter(); +} + +async function writeInbound( + stream: WritableStream, + message: AnyMessage, +): Promise { + const writer = stream.getWriter(); + + try { + await writer.write(message); + } finally { + writer.releaseLock(); + } +} + +async function readNext( + stream: ReadableStream, +): Promise { + const reader = stream.getReader(); + + try { + const result = await reader.read(); + + if (result.done) { + throw new Error("Expected stream message"); + } + + return result.value; + } finally { + reader.releaseLock(); + } +} + +const routeShapeCheck = "connection" satisfies ResponseRoute; +void routeShapeCheck; diff --git a/src/connection.ts b/src/connection.ts index 1bc9bab..c02994a 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -1,14 +1,93 @@ import { AgentSideConnection } from "./acp.js"; +import { messageIdKey } from "./protocol.js"; import type { Agent } from "./acp.js"; import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; import type { Stream } from "./stream.js"; +export type ResponseRoute = "connection" | { readonly session: string }; + +export interface OutboundSubscription { + readonly replay: readonly AnyMessage[]; + readonly stream: ReadableStream; +} + +export class OutboundStream { + private readonly subscribers = new Set(); + private replayBuffer: AnyMessage[] = []; + private hasSubscriber = false; + private isClosed = false; + + constructor(private readonly capacity = 1024) {} + + push(message: AnyMessage): void { + if (this.isClosed) { + return; + } + + if (!this.hasSubscriber) { + this.replayBuffer.push(message); + + if (this.replayBuffer.length > this.capacity) { + this.replayBuffer.shift(); + } + + return; + } + + for (const subscriber of this.subscribers) { + subscriber.push(message); + } + } + + subscribe(): OutboundSubscription { + const replay = this.hasSubscriber ? [] : [...this.replayBuffer]; + this.replayBuffer = []; + this.hasSubscriber = true; + + const subscriber = new OutboundSubscriber(this.capacity, (item) => { + this.subscribers.delete(item); + }); + + this.subscribers.add(subscriber); + + if (this.isClosed) { + subscriber.close(); + } + + return { + replay, + stream: subscriber.stream, + }; + } + + close(): void { + if (this.isClosed) { + return; + } + + this.isClosed = true; + this.replayBuffer = []; + + for (const subscriber of this.subscribers) { + subscriber.close(); + } + + this.subscribers.clear(); + } +} + export class ConnectionState { readonly connectionId: string; readonly inboundTx: WritableStream; readonly outboundRx: ReadableStream; readonly agentConnection: AgentSideConnection; + readonly connectionStream = new OutboundStream(); + readonly allOutbound = new OutboundStream(); + readonly pendingRoutes = new Map(); + + private hasStartedRouter = false; + private outboundReader: ReadableStreamDefaultReader | undefined; constructor(agentFactory: (conn: AgentSideConnection) => Agent) { this.connectionId = globalThis.crypto.randomUUID(); @@ -47,12 +126,79 @@ export class ConnectionState { } } - async shutdown() { + startRouter(): void { + if (this.hasStartedRouter) { + return; + } + + this.hasStartedRouter = true; + void this.runRouter(); + } + + async shutdown(): Promise { + this.connectionStream.close(); + this.allOutbound.close(); + this.pendingRoutes.clear(); + await Promise.allSettled([ this.inboundTx.close(), - this.outboundRx.cancel(), + this.outboundReader?.cancel() ?? this.outboundRx.cancel(), ]); } + + private async runRouter(): Promise { + const reader = this.outboundRx.getReader(); + this.outboundReader = reader; + + try { + while (true) { + const result = await reader.read(); + + if (result.done) { + return; + } + + this.routeOutbound(result.value); + } + } catch (error) { + console.error("ACP connection router stopped unexpectedly:", error); + } finally { + if (this.outboundReader === reader) { + this.outboundReader = undefined; + } + + reader.releaseLock(); + this.connectionStream.close(); + this.allOutbound.close(); + } + } + + private routeOutbound(message: AnyMessage): void { + this.allOutbound.push(message); + + if (isResponse(message)) { + const key = messageIdKey(message.id); + const route = key ? this.pendingRoutes.get(key) : undefined; + + if (key) { + this.pendingRoutes.delete(key); + } + + this.pushToRoute(route ?? "connection", message); + return; + } + + this.connectionStream.push(message); + } + + private pushToRoute(route: ResponseRoute, message: AnyMessage): void { + if (route === "connection") { + this.connectionStream.push(message); + return; + } + + this.connectionStream.push(message); + } } export class ConnectionRegistry { @@ -91,9 +237,99 @@ export class ConnectionRegistry { } } +class OutboundSubscriber { + readonly stream: ReadableStream; + + private controller: ReadableStreamDefaultController | undefined; + private queue: AnyMessage[] = []; + private isClosed = false; + private hasWarnedAboutOverflow = false; + + constructor( + private readonly capacity: number, + private readonly onCancel: (subscriber: OutboundSubscriber) => void, + ) { + this.stream = new ReadableStream({ + start: (controller) => { + this.controller = controller; + this.flush(); + }, + pull: () => { + this.flush(); + }, + cancel: () => { + this.cancel(); + }, + }); + } + + push(message: AnyMessage): void { + if (this.isClosed) { + return; + } + + this.queue.push(message); + + if (this.queue.length > this.capacity) { + this.queue.shift(); + + if (!this.hasWarnedAboutOverflow) { + console.warn("ACP outbound subscriber lagged; dropping oldest message"); + this.hasWarnedAboutOverflow = true; + } + } + + this.flush(); + } + + close(): void { + if (this.isClosed) { + return; + } + + this.isClosed = true; + this.queue = []; + this.controller?.close(); + } + + private cancel(): void { + this.isClosed = true; + this.queue = []; + this.onCancel(this); + } + + private flush(): void { + if (!this.controller) { + return; + } + + while ( + this.queue.length > 0 && + this.controller.desiredSize !== null && + this.controller.desiredSize > 0 + ) { + const message = this.queue.shift(); + + if (!message) { + return; + } + + this.controller.enqueue(message); + } + + if (this.queue.length === 0) { + this.hasWarnedAboutOverflow = false; + } + } +} + function isMatchingResponse( msg: AnyMessage, id: string | number, ): msg is AnyResponse { return "id" in msg && !("method" in msg) && msg.id === id; } + +function isResponse(msg: AnyMessage): msg is AnyResponse { + return "id" in msg && !("method" in msg); +} diff --git a/src/node-adapter.ts b/src/node-adapter.ts index 8a3d91d..a92d7c0 100644 --- a/src/node-adapter.ts +++ b/src/node-adapter.ts @@ -88,6 +88,8 @@ async function writeNodeResponse( res.setHeader(name, value); }); + res.flushHeaders(); + const responseBody = response.body; if (!responseBody) { diff --git a/src/server.test.ts b/src/server.test.ts index 9268338..50727ee 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -1,9 +1,17 @@ import { describe, expect, it } from "vitest"; -import { HEADER_CONNECTION_ID, JSON_MIME_TYPE } from "./protocol.js"; -import { startTestServer } from "./test-support/test-http-server.js"; +import { + EVENT_STREAM_MIME_TYPE, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + JSON_MIME_TYPE, +} from "./protocol.js"; +import { AcpServer } from "./server.js"; +import { parseSseStream } from "./sse.js"; import { TestAgent } from "./test-support/test-agent.js"; +import { startTestServer } from "./test-support/test-http-server.js"; import type { AgentSideConnection } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; const initializeRequest = { jsonrpc: "2.0", @@ -15,6 +23,16 @@ const initializeRequest = { }, }; +const sessionNewRequest = { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, +}; + describe("AcpServer", () => { it("handles initialize over HTTP and returns a connection ID", async () => { const server = await startTestServer(); @@ -42,20 +60,223 @@ describe("AcpServer", () => { } }); - it.each(["GET", "PUT", "PATCH", "DELETE"])( - "rejects %s requests in Phase 1", - async (method) => { - const server = await startTestServer(); + it("streams session/new responses over the connection SSE stream", async () => { + const server = await startTestServer(); - try { - const response = await fetch(server.url, { method }); + try { + const connectionId = await initialize(server.url); + const sseResponse = await openConnectionSse(server.url, connectionId); - expect(response.status).toBe(405); - } finally { - await server.close(); - } - }, - ); + expect(sseResponse.status).toBe(200); + expect(sseResponse.headers.get("Content-Type")).toContain( + EVENT_STREAM_MIME_TYPE, + ); + + const accepted = await postJson(server.url, sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }); + + expect(accepted.status).toBe(202); + expect(await accepted.text()).toBe(""); + expect(await readFirstSseMessage(sseResponse)).toMatchObject({ + jsonrpc: "2.0", + id: sessionNewRequest.id, + result: { + sessionId: expect.stringMatching(/^[0-9a-f-]{36}$/), + }, + }); + } finally { + await server.close(); + } + }); + + it("replays buffered connection messages when SSE attaches after POST", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const accepted = await postJson(server.url, sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }); + + expect(accepted.status).toBe(202); + + const sseResponse = await openConnectionSse(server.url, connectionId); + + expect(await readFirstSseMessage(sseResponse)).toMatchObject({ + jsonrpc: "2.0", + id: sessionNewRequest.id, + result: { + sessionId: expect.stringMatching(/^[0-9a-f-]{36}$/), + }, + }); + } finally { + await server.close(); + } + }); + + it.each(["PUT", "PATCH"])("rejects %s requests", async (method) => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { method }); + + expect(response.status).toBe(405); + } finally { + await server.close(); + } + }); + + it("rejects GET without Accept: text/event-stream", async () => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { + method: "GET", + headers: { + [HEADER_CONNECTION_ID]: globalThis.crypto.randomUUID(), + }, + }); + + expect(response.status).toBe(406); + } finally { + await server.close(); + } + }); + + it("rejects GET without a connection ID", async () => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + }, + }); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("rejects GET with an unknown connection ID", async () => { + const server = await startTestServer(); + + try { + const response = await openConnectionSse( + server.url, + globalThis.crypto.randomUUID(), + ); + + expect(response.status).toBe(404); + } finally { + await server.close(); + } + }); + + it("rejects session-scoped GETs until session SSE is implemented", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const response = await fetch(server.url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: globalThis.crypto.randomUUID(), + }, + }); + + expect(response.status).toBe(404); + } finally { + await server.close(); + } + }); + + it("returns 426 for WebSocket upgrade GETs", async () => { + const server = new AcpServer({ + createAgent: (conn: AgentSideConnection) => new TestAgent(conn), + }); + + try { + const response = await server.handleRequest( + new Request("http://127.0.0.1/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + Upgrade: "websocket", + }, + }), + ); + + expect(response.status).toBe(426); + } finally { + await server.close(); + } + }); + + it("deletes connections and closes SSE streams", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const sseResponse = await openConnectionSse(server.url, connectionId); + const reader = sseResponse.body?.getReader(); + + expect(reader).toBeDefined(); + + const deleted = await fetch(server.url, { + method: "DELETE", + headers: { + [HEADER_CONNECTION_ID]: connectionId, + }, + }); + + expect(deleted.status).toBe(202); + expect(await reader?.read()).toEqual({ done: true, value: undefined }); + reader?.releaseLock(); + + const postAfterDelete = await postJson(server.url, sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }); + + expect(postAfterDelete.status).toBe(404); + } finally { + await server.close(); + } + }); + + it("rejects DELETE without a connection ID", async () => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { method: "DELETE" }); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("rejects DELETE with an unknown connection ID", async () => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { + method: "DELETE", + headers: { + [HEADER_CONNECTION_ID]: globalThis.crypto.randomUUID(), + }, + }); + + expect(response.status).toBe(404); + } finally { + await server.close(); + } + }); it("rejects POST without application/json Content-Type", async () => { const server = await startTestServer(); @@ -127,15 +348,7 @@ describe("AcpServer", () => { const server = await startTestServer(); try { - const response = await postJson(server.url, { - jsonrpc: "2.0", - id: 2, - method: "session/new", - params: { - cwd: "/tmp", - mcpServers: [], - }, - }); + const response = await postJson(server.url, sessionNewRequest); expect(response.status).toBe(400); } finally { @@ -147,21 +360,9 @@ describe("AcpServer", () => { const server = await startTestServer(); try { - const response = await postJson( - server.url, - { - jsonrpc: "2.0", - id: 2, - method: "session/new", - params: { - cwd: "/tmp", - mcpServers: [], - }, - }, - { - [HEADER_CONNECTION_ID]: globalThis.crypto.randomUUID(), - }, - ); + const response = await postJson(server.url, sessionNewRequest, { + [HEADER_CONNECTION_ID]: globalThis.crypto.randomUUID(), + }); expect(response.status).toBe(404); } finally { @@ -169,37 +370,6 @@ describe("AcpServer", () => { } }); - it("rejects connected POSTs after initialize in Phase 1", async () => { - const server = await startTestServer(); - - try { - const initializeResponse = await postJson(server.url, initializeRequest); - const connectionId = initializeResponse.headers.get(HEADER_CONNECTION_ID); - - expect(connectionId).toBeTruthy(); - - const response = await postJson( - server.url, - { - jsonrpc: "2.0", - id: 2, - method: "session/new", - params: { - cwd: "/tmp", - mcpServers: [], - }, - }, - { - [HEADER_CONNECTION_ID]: connectionId ?? "", - }, - ); - - expect(response.status).toBe(400); - } finally { - await server.close(); - } - }); - it("returns an error response when agent creation fails", async () => { const server = await startTestServer(() => { throw new Error("agent factory failed"); @@ -258,6 +428,46 @@ describe("AcpServer", () => { }); }); +async function initialize(url: string): Promise { + const response = await postJson(url, initializeRequest); + const connectionId = response.headers.get(HEADER_CONNECTION_ID); + + expect(response.status).toBe(200); + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + return connectionId ?? ""; +} + +function openConnectionSse( + url: string, + connectionId: string, +): Promise { + return fetch(url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + }); +} + +async function readFirstSseMessage(response: Response): Promise { + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + const iterator = parseSseStream(response.body)[Symbol.asyncIterator](); + const result = await iterator.next(); + await iterator.return?.(); + await response.body.cancel(); + + if (result.done) { + throw new Error("Expected SSE message"); + } + + return result.value; +} + function postJson( url: string, body: unknown, diff --git a/src/server.ts b/src/server.ts index 9d2bee8..f758736 100644 --- a/src/server.ts +++ b/src/server.ts @@ -1,10 +1,19 @@ import { ConnectionRegistry } from "./connection.js"; import { + EVENT_STREAM_MIME_TYPE, HEADER_CONNECTION_ID, + HEADER_SESSION_ID, JSON_MIME_TYPE, isInitializeRequest, + messageIdKey, } from "./protocol.js"; +import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js"; +import type { + ConnectionState, + OutboundSubscription, + ResponseRoute, +} from "./connection.js"; import type { Agent, AgentSideConnection } from "./acp.js"; import type { AnyMessage } from "./jsonrpc.js"; @@ -21,10 +30,26 @@ export class AcpServer { } async handleRequest(req: Request): Promise { - if (req.method !== "POST") { - return textResponse("Method Not Allowed", 405); + if (req.method === "POST") { + return await this.handlePost(req); } + if (req.method === "GET") { + return this.handleGet(req); + } + + if (req.method === "DELETE") { + return this.handleDelete(req); + } + + return textResponse("Method Not Allowed", 405); + } + + async close(): Promise { + this.registry.closeAll(); + } + + private async handlePost(req: Request): Promise { const contentType = req.headers.get("Content-Type"); if (!contentType?.startsWith(JSON_MIME_TYPE)) { @@ -55,18 +80,58 @@ export class AcpServer { return textResponse("Missing Acp-Connection-Id", 400); } - if (!this.registry.get(connectionId)) { + const connection = this.registry.get(connectionId); + + if (!connection) { return textResponse("Unknown Acp-Connection-Id", 404); } - return textResponse( - "Connected POST handling is not implemented in Phase 1", - 400, - ); + await this.forwardConnectedMessage(connection, body.value); + return emptyResponse(202); } - async close(): Promise { - this.registry.closeAll(); + private handleGet(req: Request): Response { + if (req.headers.get("Upgrade")?.toLowerCase() === "websocket") { + return textResponse("WebSocket upgrade is not implemented", 426); + } + + const accept = req.headers.get("Accept")?.toLowerCase(); + + if (!accept?.includes(EVENT_STREAM_MIME_TYPE)) { + return textResponse("Not Acceptable", 406); + } + + const connectionId = req.headers.get(HEADER_CONNECTION_ID); + + if (!connectionId) { + return textResponse("Missing Acp-Connection-Id", 400); + } + + const connection = this.registry.get(connectionId); + + if (!connection) { + return textResponse("Unknown Acp-Connection-Id", 404); + } + + if (req.headers.get(HEADER_SESSION_ID)) { + return textResponse("Unknown Acp-Session-Id", 404); + } + + return sseResponse(connection.connectionStream.subscribe()); + } + + private handleDelete(req: Request): Response { + const connectionId = req.headers.get(HEADER_CONNECTION_ID); + + if (!connectionId) { + return textResponse("Missing Acp-Connection-Id", 400); + } + + if (!this.registry.remove(connectionId)) { + return textResponse("Unknown Acp-Connection-Id", 404); + } + + return emptyResponse(202); } private async handleInitialize(message: AnyMessage): Promise { @@ -80,15 +145,10 @@ export class AcpServer { try { connection = this.registry.createConnection(this.createAgent); - const writer = connection.inboundTx.getWriter(); - - try { - await writer.write(message); - } finally { - writer.releaseLock(); - } + await writeInbound(connection, message); const initialResponse = await connection.recvInitial(message.id); + connection.startRouter(); return jsonResponse(initialResponse, 200, { [HEADER_CONNECTION_ID]: connection.connectionId, @@ -112,6 +172,21 @@ export class AcpServer { ); } } + + private async forwardConnectedMessage( + connection: ConnectionState, + message: AnyMessage, + ): Promise { + if (isRequestMessage(message)) { + const key = messageIdKey(message.id); + + if (key) { + connection.pendingRoutes.set(key, determineRoute()); + } + } + + await writeInbound(connection, message); + } } type JsonResult = @@ -136,6 +211,23 @@ async function readJson(req: Request): Promise { } } +async function writeInbound( + connection: ConnectionState, + message: AnyMessage, +): Promise { + const writer = connection.inboundTx.getWriter(); + + try { + await writer.write(message); + } finally { + writer.releaseLock(); + } +} + +function determineRoute(): ResponseRoute { + return "connection"; +} + function isJsonRpcMessage(value: unknown): value is AnyMessage { return ( isRecord(value) && @@ -144,10 +236,101 @@ function isJsonRpcMessage(value: unknown): value is AnyMessage { ); } +function isRequestMessage( + message: AnyMessage, +): message is AnyMessage & { readonly id: string | number | null } { + return "method" in message && "id" in message; +} + function isRecord(value: unknown): value is Record { return typeof value === "object" && value !== null; } +function sseResponse(subscription: OutboundSubscription): Response { + return new Response(createSseBody(subscription), { + status: 200, + headers: { + "Content-Type": EVENT_STREAM_MIME_TYPE, + "Cache-Control": "no-cache", + Connection: "keep-alive", + }, + }); +} + +function createSseBody( + subscription: OutboundSubscription, +): ReadableStream { + const encoder = new TextEncoder(); + let keepAliveTimer: ReturnType | undefined; + let reader: ReadableStreamDefaultReader | undefined; + + const clearKeepAlive = (): void => { + if (keepAliveTimer) { + clearInterval(keepAliveTimer); + keepAliveTimer = undefined; + } + }; + + const enqueueText = ( + controller: ReadableStreamDefaultController, + text: string, + ): boolean => { + try { + controller.enqueue(encoder.encode(text)); + return true; + } catch { + return false; + } + }; + + return new ReadableStream({ + async start(controller) { + for (const message of subscription.replay) { + if (!enqueueText(controller, serializeSseEvent(message))) { + return; + } + } + + reader = subscription.stream.getReader(); + + keepAliveTimer = setInterval(() => { + if (!enqueueText(controller, serializeSseKeepAlive())) { + clearKeepAlive(); + } + }, 15_000); + + try { + while (true) { + const result = await reader.read(); + + if (result.done) { + return; + } + + if (!enqueueText(controller, serializeSseEvent(result.value))) { + return; + } + } + } catch (error) { + controller.error(error); + } finally { + clearKeepAlive(); + reader.releaseLock(); + + try { + controller.close(); + } catch { + // Stream may already be cancelled by the consumer. + } + } + }, + cancel() { + clearKeepAlive(); + void reader?.cancel(); + }, + }); +} + function jsonResponse( value: unknown, status: number, @@ -170,3 +353,7 @@ function textResponse(body: string, status: number): Response { }, }); } + +function emptyResponse(status: number): Response { + return new Response(null, { status }); +} From a274c49b9d4569c58d6d3c0d644e633660ac7abf Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Mon, 18 May 2026 18:57:02 +1000 Subject: [PATCH 04/19] Add session SSE and prompt streaming --- src/connection.test.ts | 97 +++++++ src/connection.ts | 55 +++- src/server-session-sse.test.ts | 455 +++++++++++++++++++++++++++++++++ src/server.test.ts | 30 ++- src/server.ts | 107 +++++++- 5 files changed, 723 insertions(+), 21 deletions(-) create mode 100644 src/server-session-sse.test.ts diff --git a/src/connection.test.ts b/src/connection.test.ts index e0f52ec..60ea0e5 100644 --- a/src/connection.test.ts +++ b/src/connection.test.ts @@ -30,6 +30,18 @@ const sessionNewRequest = { }, } as const; +function createPromptRequest(id: number, sessionId: string) { + return { + jsonrpc: "2.0", + id, + method: "session/prompt", + params: { + sessionId, + prompt: [{ type: "text", text: "Hello" }], + }, + } as const; +} + const messageOne = { jsonrpc: "2.0", id: 1, result: "one" } as const; const messageTwo = { jsonrpc: "2.0", id: 2, result: "two" } as const; const messageThree = { jsonrpc: "2.0", id: 3, result: "three" } as const; @@ -144,6 +156,70 @@ describe("ConnectionRegistry", () => { registry.closeAll(); }); + + it("returns the same session stream for repeated ensureSession calls", () => { + const registry = new ConnectionRegistry(); + const connection = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn), + ); + const sessionId = globalThis.crypto.randomUUID(); + + expect(connection.ensureSession(sessionId)).toBe( + connection.ensureSession(sessionId), + ); + expect(connection.sessionStreams.get(sessionId)).toBe( + connection.ensureSession(sessionId), + ); + + registry.closeAll(); + }); + + it("routes session responses and notifications to the session stream", async () => { + const registry = new ConnectionRegistry(); + const connection = registry.createConnection( + (conn: AgentSideConnection) => new TestAgent(conn, { chunkCount: 1 }), + ); + const sessionId = globalThis.crypto.randomUUID(); + const promptRequest = createPromptRequest(3, sessionId); + + await initializeConnection(connection); + + const sessionSubscription = connection.ensureSession(sessionId).subscribe(); + const connectionSubscription = connection.connectionStream.subscribe(); + const key = messageIdKey(promptRequest.id); + + expect(key).toBe("number:3"); + connection.pendingRoutes.set(key ?? "", { session: sessionId }); + + await writeInbound(connection.inboundTx, promptRequest); + + expect(await readNext(sessionSubscription.stream)).toMatchObject({ + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + text: "chunk-1", + }, + }, + }, + }); + expect(await readNext(sessionSubscription.stream)).toMatchObject({ + jsonrpc: "2.0", + id: promptRequest.id, + result: { + stopReason: "end_turn", + }, + }); + expect(connection.pendingRoutes.has(key ?? "")).toBe(false); + expect( + await readNextOrUndefined(connectionSubscription.stream), + ).toBeUndefined(); + + registry.closeAll(); + }); }); describe("OutboundStream", () => { @@ -251,5 +327,26 @@ async function readNext( } } +async function readNextOrUndefined( + stream: ReadableStream, +): Promise { + const reader = stream.getReader(); + + try { + return await Promise.race([ + reader.read().then((result) => (result.done ? undefined : result.value)), + delay(50).then(() => undefined), + ]); + } finally { + reader.releaseLock(); + } +} + +function delay(ms: number): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} + const routeShapeCheck = "connection" satisfies ResponseRoute; void routeShapeCheck; diff --git a/src/connection.ts b/src/connection.ts index c02994a..098faaa 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -1,5 +1,5 @@ import { AgentSideConnection } from "./acp.js"; -import { messageIdKey } from "./protocol.js"; +import { messageIdKey, sessionIdFromParams } from "./protocol.js"; import type { Agent } from "./acp.js"; import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; @@ -84,6 +84,7 @@ export class ConnectionState { readonly agentConnection: AgentSideConnection; readonly connectionStream = new OutboundStream(); readonly allOutbound = new OutboundStream(); + readonly sessionStreams = new Map(); readonly pendingRoutes = new Map(); private hasStartedRouter = false; @@ -135,9 +136,27 @@ export class ConnectionState { void this.runRouter(); } + ensureSession(sessionId: string): OutboundStream { + const existing = this.sessionStreams.get(sessionId); + if (existing) { + return existing; + } + + const stream = new OutboundStream(); + this.sessionStreams.set(sessionId, stream); + + return stream; + } + async shutdown(): Promise { this.connectionStream.close(); this.allOutbound.close(); + + for (const stream of this.sessionStreams.values()) { + stream.close(); + } + + this.sessionStreams.clear(); this.pendingRoutes.clear(); await Promise.allSettled([ @@ -170,6 +189,10 @@ export class ConnectionState { reader.releaseLock(); this.connectionStream.close(); this.allOutbound.close(); + + for (const stream of this.sessionStreams.values()) { + stream.close(); + } } } @@ -179,6 +202,13 @@ export class ConnectionState { if (isResponse(message)) { const key = messageIdKey(message.id); const route = key ? this.pendingRoutes.get(key) : undefined; + const sessionId = sessionIdFromResult( + "result" in message ? message.result : undefined, + ); + + if (sessionId) { + this.ensureSession(sessionId); + } if (key) { this.pendingRoutes.delete(key); @@ -188,6 +218,14 @@ export class ConnectionState { return; } + if ("method" in message) { + const sessionId = sessionIdFromParams(message.params); + if (sessionId) { + this.ensureSession(sessionId).push(message); + return; + } + } + this.connectionStream.push(message); } @@ -197,7 +235,7 @@ export class ConnectionState { return; } - this.connectionStream.push(message); + this.ensureSession(route.session).push(message); } } @@ -333,3 +371,16 @@ function isMatchingResponse( function isResponse(msg: AnyMessage): msg is AnyResponse { return "id" in msg && !("method" in msg); } + +function sessionIdFromResult(result: unknown): string | undefined { + if (!isRecord(result)) { + return undefined; + } + + const sessionId = result["sessionId"]; + return typeof sessionId === "string" ? sessionId : undefined; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} diff --git a/src/server-session-sse.test.ts b/src/server-session-sse.test.ts new file mode 100644 index 0000000..c30c775 --- /dev/null +++ b/src/server-session-sse.test.ts @@ -0,0 +1,455 @@ +import { describe, expect, it } from "vitest"; +import { + EVENT_STREAM_MIME_TYPE, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + JSON_MIME_TYPE, +} from "./protocol.js"; +import { parseSseStream } from "./sse.js"; +import { TestAgent } from "./test-support/test-agent.js"; +import { startTestServer } from "./test-support/test-http-server.js"; + +import type { AgentSideConnection } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: 1, + clientCapabilities: {}, + }, +}; + +const sessionNewRequest = { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, +}; + +function createPromptRequest(id: number, sessionId?: string) { + return { + jsonrpc: "2.0", + id, + method: "session/prompt", + params: { + ...(sessionId === undefined ? {} : { sessionId }), + prompt: [{ type: "text", text: "Hello" }], + }, + }; +} + +describe("AcpServer session SSE", () => { + it("streams prompt updates and responses on the session SSE stream", async () => { + const server = await startTestServer( + (conn: AgentSideConnection) => new TestAgent(conn, { chunkCount: 2 }), + ); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + + expect(sessionSse.status).toBe(200); + + const accepted = await postJson( + server.url, + createPromptRequest(3, sessionId), + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ); + + expect(accepted.status).toBe(202); + expect(await readSseMessages(sessionSse, 3)).toMatchObject([ + { + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + text: "chunk-1", + }, + }, + }, + }, + { + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + text: "chunk-2", + }, + }, + }, + }, + { + jsonrpc: "2.0", + id: 3, + result: { + stopReason: "end_turn", + }, + }, + ]); + expect( + await readNextConnectionSseMessage(server.url, connectionId), + ).toBeUndefined(); + } finally { + await server.close(); + } + }); + + it("routes session prompts using params.sessionId when the session header is absent", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + const accepted = await postJson( + server.url, + createPromptRequest(3, sessionId), + { + [HEADER_CONNECTION_ID]: connectionId, + }, + ); + + expect(accepted.status).toBe(202); + expect(await readSseMessages(sessionSse, 2)).toMatchObject([ + { + jsonrpc: "2.0", + method: "session/update", + params: { sessionId }, + }, + { + jsonrpc: "2.0", + id: 3, + result: { stopReason: "end_turn" }, + }, + ]); + } finally { + await server.close(); + } + }); + + it("rejects session-scoped requests without a session identifier", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const response = await postJson(server.url, createPromptRequest(3), { + [HEADER_CONNECTION_ID]: connectionId, + }); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("replays buffered session messages when session SSE attaches after prompt", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const accepted = await postJson( + server.url, + createPromptRequest(3, sessionId), + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + + expect(accepted.status).toBe(202); + expect(await readSseMessages(sessionSse, 2)).toMatchObject([ + { + jsonrpc: "2.0", + method: "session/update", + params: { sessionId }, + }, + { + jsonrpc: "2.0", + id: 3, + result: { stopReason: "end_turn" }, + }, + ]); + } finally { + await server.close(); + } + }); + + it("isolates prompt events for multiple sessions on the same connection", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const connectionSse = await openConnectionSse(server.url, connectionId); + const connectionEvents = createSseMessageIterator(connectionSse); + const firstSessionId = await createSessionFromConnectionEvents( + server.url, + connectionId, + connectionEvents, + ); + const secondSessionId = await createSessionFromConnectionEvents( + server.url, + connectionId, + connectionEvents, + ); + const firstSse = await openSessionSse( + server.url, + connectionId, + firstSessionId, + ); + const secondSse = await openSessionSse( + server.url, + connectionId, + secondSessionId, + ); + + expect( + await postJson(server.url, createPromptRequest(3, firstSessionId), { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: firstSessionId, + }), + ).toMatchObject({ status: 202 }); + expect( + await postJson(server.url, createPromptRequest(4, secondSessionId), { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: secondSessionId, + }), + ).toMatchObject({ status: 202 }); + + expect(await readSseMessages(firstSse, 2)).toMatchObject([ + { method: "session/update", params: { sessionId: firstSessionId } }, + { id: 3, result: { stopReason: "end_turn" } }, + ]); + expect(await readSseMessages(secondSse, 2)).toMatchObject([ + { method: "session/update", params: { sessionId: secondSessionId } }, + { id: 4, result: { stopReason: "end_turn" } }, + ]); + } finally { + await server.close(); + } + }); +}); + +async function initialize(url: string): Promise { + const response = await postJson(url, initializeRequest); + const connectionId = response.headers.get(HEADER_CONNECTION_ID); + + expect(response.status).toBe(200); + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + return connectionId ?? ""; +} + +async function createSession( + url: string, + connectionId: string, +): Promise { + return createSessionFromConnectionSse( + url, + connectionId, + await openConnectionSse(url, connectionId), + ); +} + +async function createSessionFromConnectionSse( + url: string, + connectionId: string, + response: Response, +): Promise { + return createSessionFromConnectionEvents( + url, + connectionId, + createSseMessageIterator(response), + ); +} + +async function createSessionFromConnectionEvents( + url: string, + connectionId: string, + events: AsyncIterator, +): Promise { + const accepted = await postJson(url, sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }); + + expect(accepted.status).toBe(202); + + return readSessionId(await readNextSseMessage(events)); +} + +function openConnectionSse( + url: string, + connectionId: string, + signal?: AbortSignal, +): Promise { + return fetch(url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + signal, + }); +} + +function openSessionSse( + url: string, + connectionId: string, + sessionId: string, +): Promise { + return fetch(url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + }); +} + +function createSseMessageIterator( + response: Response, +): AsyncIterator { + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + return parseSseStream(response.body)[Symbol.asyncIterator](); +} + +async function readNextSseMessage( + iterator: AsyncIterator, +): Promise { + const result = await iterator.next(); + + if (result.done) { + throw new Error("Expected SSE message"); + } + + return result.value; +} + +async function readSseMessages( + response: Response, + count: number, +): Promise { + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + const iterator = parseSseStream(response.body)[Symbol.asyncIterator](); + + try { + const messages: AnyMessage[] = []; + + for (const __unused of Array.from({ length: count })) { + void __unused; + const result = await iterator.next(); + + if (result.done) { + throw new Error("Expected SSE message"); + } + + messages.push(result.value); + } + + return messages; + } finally { + await iterator.return?.(); + await response.body.cancel(); + } +} + +async function readNextConnectionSseMessage( + url: string, + connectionId: string, +): Promise { + const abort = new AbortController(); + const response = await openConnectionSse(url, connectionId, abort.signal); + + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + const iterator = parseSseStream(response.body)[Symbol.asyncIterator](); + + try { + const result = await Promise.race([ + iterator.next(), + delay(50).then(() => ({ done: true, value: undefined })), + ]); + + return result.done ? undefined : result.value; + } finally { + abort.abort(); + await iterator.return?.(); + } +} + +function readSessionId(message: AnyMessage): string { + if (!("result" in message) || !isRecord(message.result)) { + throw new Error("Expected session/new response result"); + } + + const sessionId = message.result["sessionId"]; + + if (typeof sessionId !== "string") { + throw new Error("Expected session ID"); + } + + return sessionId; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function delay(ms: number): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} + +function postJson( + url: string, + body: unknown, + headers: Record = {}, +): Promise { + return fetch(url, { + method: "POST", + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + body: JSON.stringify(body), + }); +} diff --git a/src/server.test.ts b/src/server.test.ts index 50727ee..e2b37bc 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -176,19 +176,16 @@ describe("AcpServer", () => { } }); - it("rejects session-scoped GETs until session SSE is implemented", async () => { + it("rejects session-scoped GETs for unknown sessions", async () => { const server = await startTestServer(); try { const connectionId = await initialize(server.url); - const response = await fetch(server.url, { - method: "GET", - headers: { - Accept: EVENT_STREAM_MIME_TYPE, - [HEADER_CONNECTION_ID]: connectionId, - [HEADER_SESSION_ID]: globalThis.crypto.randomUUID(), - }, - }); + const response = await openSessionSse( + server.url, + connectionId, + globalThis.crypto.randomUUID(), + ); expect(response.status).toBe(404); } finally { @@ -451,6 +448,21 @@ function openConnectionSse( }); } +function openSessionSse( + url: string, + connectionId: string, + sessionId: string, +): Promise { + return fetch(url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + }); +} + async function readFirstSseMessage(response: Response): Promise { if (!response.body) { throw new Error("Expected SSE response body"); diff --git a/src/server.ts b/src/server.ts index f758736..5a74208 100644 --- a/src/server.ts +++ b/src/server.ts @@ -6,7 +6,10 @@ import { JSON_MIME_TYPE, isInitializeRequest, messageIdKey, + methodRequiresSessionHeader, + sessionIdFromParams, } from "./protocol.js"; + import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js"; import type { @@ -86,7 +89,15 @@ export class AcpServer { return textResponse("Unknown Acp-Connection-Id", 404); } - await this.forwardConnectedMessage(connection, body.value); + const forwarded = await this.forwardConnectedMessage( + connection, + body.value, + req.headers, + ); + if (!forwarded.ok) { + return textResponse(forwarded.message, forwarded.status); + } + return emptyResponse(202); } @@ -113,8 +124,14 @@ export class AcpServer { return textResponse("Unknown Acp-Connection-Id", 404); } - if (req.headers.get(HEADER_SESSION_ID)) { - return textResponse("Unknown Acp-Session-Id", 404); + const sessionId = req.headers.get(HEADER_SESSION_ID); + if (sessionId) { + const sessionStream = connection.sessionStreams.get(sessionId); + if (!sessionStream) { + return textResponse("Unknown Acp-Session-Id", 404); + } + + return sseResponse(sessionStream.subscribe()); } return sseResponse(connection.connectionStream.subscribe()); @@ -176,19 +193,41 @@ export class AcpServer { private async forwardConnectedMessage( connection: ConnectionState, message: AnyMessage, - ): Promise { + headers: Headers, + ): Promise { if (isRequestMessage(message)) { + const route = determineRoute(message, headers); + + if (!route.ok) { + return route; + } + + if (route.value !== "connection") { + connection.ensureSession(route.value.session); + } + const key = messageIdKey(message.id); if (key) { - connection.pendingRoutes.set(key, determineRoute()); + connection.pendingRoutes.set(key, route.value); } } await writeInbound(connection, message); + return { ok: true }; } } +type ForwardResult = + | { + ok: true; + } + | { + ok: false; + status: number; + message: string; + }; + type JsonResult = | { ok: true; @@ -198,6 +237,17 @@ type JsonResult = ok: false; }; +type RouteResult = + | { + ok: true; + value: ResponseRoute; + } + | { + ok: false; + status: number; + message: string; + }; + async function readJson(req: Request): Promise { try { return { @@ -224,8 +274,43 @@ async function writeInbound( } } -function determineRoute(): ResponseRoute { - return "connection"; +function determineRoute( + message: AnyMessage & { + readonly method: string; + readonly params?: unknown; + }, + headers: Headers, +): RouteResult { + const headerSessionId = headers.get(HEADER_SESSION_ID); + + if (headerSessionId) { + return { + ok: true, + value: { session: headerSessionId }, + }; + } + + const paramsSessionId = sessionIdFromParams(message.params); + + if (paramsSessionId) { + return { + ok: true, + value: { session: paramsSessionId }, + }; + } + + if (methodRequiresSessionHeader(message.method)) { + return { + ok: false, + status: 400, + message: "Missing Acp-Session-Id", + }; + } + + return { + ok: true, + value: "connection", + }; } function isJsonRpcMessage(value: unknown): value is AnyMessage { @@ -236,9 +321,11 @@ function isJsonRpcMessage(value: unknown): value is AnyMessage { ); } -function isRequestMessage( - message: AnyMessage, -): message is AnyMessage & { readonly id: string | number | null } { +function isRequestMessage(message: AnyMessage): message is AnyMessage & { + readonly id: string | number | null; + readonly method: string; + readonly params?: unknown; +} { return "method" in message && "id" in message; } From 7e770a30cdb3bd972b20dc3da15aa6e692516290 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Mon, 18 May 2026 20:40:47 +1000 Subject: [PATCH 05/19] Add tool permission request support --- src/connection.ts | 34 ++-- src/server-permission.test.ts | 305 +++++++++++++++++++++++++++++++++ src/server.ts | 87 +++++++--- src/test-support/test-agent.ts | 39 +++++ 4 files changed, 426 insertions(+), 39 deletions(-) create mode 100644 src/server-permission.test.ts diff --git a/src/connection.ts b/src/connection.ts index 098faaa..f67df33 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -200,24 +200,32 @@ export class ConnectionState { this.allOutbound.push(message); if (isResponse(message)) { - const key = messageIdKey(message.id); - const route = key ? this.pendingRoutes.get(key) : undefined; - const sessionId = sessionIdFromResult( - "result" in message ? message.result : undefined, - ); + this.routeOutboundResponse(message); + return; + } - if (sessionId) { - this.ensureSession(sessionId); - } + this.routeOutboundRequestOrNotification(message); + } - if (key) { - this.pendingRoutes.delete(key); - } + private routeOutboundResponse(message: AnyResponse): void { + const key = messageIdKey(message.id); + const route = key ? this.pendingRoutes.get(key) : undefined; + const sessionId = sessionIdFromResult( + "result" in message ? message.result : undefined, + ); - this.pushToRoute(route ?? "connection", message); - return; + if (sessionId) { + this.ensureSession(sessionId); + } + + if (key) { + this.pendingRoutes.delete(key); } + this.pushToRoute(route ?? "connection", message); + } + + private routeOutboundRequestOrNotification(message: AnyMessage): void { if ("method" in message) { const sessionId = sessionIdFromParams(message.params); if (sessionId) { diff --git a/src/server-permission.test.ts b/src/server-permission.test.ts new file mode 100644 index 0000000..0ec3a9f --- /dev/null +++ b/src/server-permission.test.ts @@ -0,0 +1,305 @@ +import { describe, expect, it } from "vitest"; +import { + EVENT_STREAM_MIME_TYPE, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + JSON_MIME_TYPE, +} from "./protocol.js"; +import { parseSseStream } from "./sse.js"; +import { TestAgent } from "./test-support/test-agent.js"; +import { startTestServer } from "./test-support/test-http-server.js"; + +import type { AgentSideConnection } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: 1, + clientCapabilities: {}, + }, +}; + +const sessionNewRequest = { + jsonrpc: "2.0", + id: 2, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, +}; + +function createPromptRequest(id: number, sessionId: string) { + return { + jsonrpc: "2.0", + id, + method: "session/prompt", + params: { + sessionId, + prompt: [{ type: "text", text: "Hello" }], + }, + }; +} + +describe("AcpServer permission requests over HTTP", () => { + it("routes permission requests over session SSE and accepts client responses", async () => { + const server = await startTestServer( + (conn: AgentSideConnection) => + new TestAgent(conn, { enablePermission: true }), + ); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const connectionAbort = new AbortController(); + const connectionSse = await openConnectionSse( + server.url, + connectionId, + connectionAbort.signal, + ); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + const sessionEvents = createSseMessageIterator(sessionSse); + + expect( + await postJson(server.url, createPromptRequest(3, sessionId), { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }), + ).toMatchObject({ status: 202 }); + + expect(await readNextSseMessage(sessionEvents)).toMatchObject({ + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "chunk-1" }, + }, + }, + }); + + const permissionRequest = await readNextSseMessage(sessionEvents); + expect(permissionRequest).toMatchObject({ + jsonrpc: "2.0", + id: expect.any(Number), + method: "session/request_permission", + params: { + sessionId, + toolCall: { + toolCallId: "permission-tool", + title: "Permission tool", + }, + options: expect.arrayContaining([ + expect.objectContaining({ + kind: "allow_once", + optionId: "allow", + }), + ]), + }, + }); + expect( + await readNextMessageOrUndefined(connectionSse, connectionAbort), + ).toBeUndefined(); + + expect( + await postJson( + server.url, + { + jsonrpc: "2.0", + id: readMessageId(permissionRequest), + result: { + outcome: { + outcome: "selected", + optionId: "allow", + }, + }, + }, + { [HEADER_CONNECTION_ID]: connectionId }, + ), + ).toMatchObject({ status: 202 }); + + expect(await readNextSseMessage(sessionEvents)).toMatchObject({ + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "permission-selected-allow" }, + }, + }, + }); + expect(await readNextSseMessage(sessionEvents)).toMatchObject({ + jsonrpc: "2.0", + id: 3, + result: { stopReason: "end_turn" }, + }); + + await sessionEvents.return?.(); + await sessionSse.body?.cancel(); + } finally { + await server.close(); + } + }, 10_000); +}); + +async function initialize(url: string): Promise { + const response = await postJson(url, initializeRequest); + const connectionId = response.headers.get(HEADER_CONNECTION_ID); + + expect(response.status).toBe(200); + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + return connectionId ?? ""; +} + +async function createSession( + url: string, + connectionId: string, +): Promise { + const response = await openConnectionSse(url, connectionId); + const events = createSseMessageIterator(response); + + try { + expect( + await postJson(url, sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }), + ).toMatchObject({ status: 202 }); + + return readSessionId(await readNextSseMessage(events)); + } finally { + await events.return?.(); + await response.body?.cancel(); + } +} + +function openConnectionSse( + url: string, + connectionId: string, + signal?: AbortSignal, +): Promise { + return fetch(url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + signal, + }); +} + +function openSessionSse( + url: string, + connectionId: string, + sessionId: string, +): Promise { + return fetch(url, { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + }); +} + +function createSseMessageIterator( + response: Response, +): AsyncIterator { + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + return parseSseStream(response.body)[Symbol.asyncIterator](); +} + +async function readNextSseMessage( + iterator: AsyncIterator, +): Promise { + const result = await iterator.next(); + + if (result.done) { + throw new Error("Expected SSE message"); + } + + return result.value; +} + +async function readNextMessageOrUndefined( + response: Response, + abort: AbortController, +): Promise { + if (!response.body) { + throw new Error("Expected SSE response body"); + } + + const iterator = parseSseStream(response.body)[Symbol.asyncIterator](); + + try { + const result = await Promise.race([ + iterator.next(), + delay(50).then(() => ({ done: true, value: undefined })), + ]); + + return result.done ? undefined : result.value; + } finally { + abort.abort(); + await iterator.return?.(); + } +} + +function readMessageId(message: AnyMessage): string | number | null { + if (!("id" in message)) { + throw new Error("Expected message ID"); + } + + return message.id; +} + +function readSessionId(message: AnyMessage): string { + if (!("result" in message) || !isRecord(message.result)) { + throw new Error("Expected session/new response result"); + } + + const sessionId = message.result["sessionId"]; + + if (typeof sessionId !== "string") { + throw new Error("Expected session ID"); + } + + return sessionId; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function delay(ms: number): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} + +function postJson( + url: string, + body: unknown, + headers: Record = {}, +): Promise { + return fetch(url, { + method: "POST", + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + body: JSON.stringify(body), + }); +} diff --git a/src/server.ts b/src/server.ts index 5a74208..d6ae5d8 100644 --- a/src/server.ts +++ b/src/server.ts @@ -18,7 +18,7 @@ import type { ResponseRoute, } from "./connection.js"; import type { Agent, AgentSideConnection } from "./acp.js"; -import type { AnyMessage } from "./jsonrpc.js"; +import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; export interface AcpServerOptions { createAgent: (conn: AgentSideConnection) => Agent; @@ -196,25 +196,14 @@ export class AcpServer { headers: Headers, ): Promise { if (isRequestMessage(message)) { - const route = determineRoute(message, headers); - - if (!route.ok) { - return route; - } - - if (route.value !== "connection") { - connection.ensureSession(route.value.session); - } - - const key = messageIdKey(message.id); + return await forwardClientRequest(connection, message, headers); + } - if (key) { - connection.pendingRoutes.set(key, route.value); - } + if (isResponseMessage(message)) { + return await forwardClientResponse(connection, message); } - await writeInbound(connection, message); - return { ok: true }; + return await forwardClientNotification(connection, message); } } @@ -248,6 +237,12 @@ type RouteResult = message: string; }; +type ClientRequestMessage = AnyMessage & { + readonly id: string | number | null; + readonly method: string; + readonly params?: unknown; +}; + async function readJson(req: Request): Promise { try { return { @@ -274,11 +269,49 @@ async function writeInbound( } } +async function forwardClientRequest( + connection: ConnectionState, + message: ClientRequestMessage, + headers: Headers, +): Promise { + const route = determineRoute(message, headers); + + if (!route.ok) { + return route; + } + + if (route.value !== "connection") { + connection.ensureSession(route.value.session); + } + + const key = messageIdKey(message.id); + + if (key) { + connection.pendingRoutes.set(key, route.value); + } + + await writeInbound(connection, message); + return { ok: true }; +} + +async function forwardClientResponse( + connection: ConnectionState, + message: AnyResponse, +): Promise { + await writeInbound(connection, message); + return { ok: true }; +} + +async function forwardClientNotification( + connection: ConnectionState, + message: AnyMessage, +): Promise { + await writeInbound(connection, message); + return { ok: true }; +} + function determineRoute( - message: AnyMessage & { - readonly method: string; - readonly params?: unknown; - }, + message: ClientRequestMessage, headers: Headers, ): RouteResult { const headerSessionId = headers.get(HEADER_SESSION_ID); @@ -321,14 +354,16 @@ function isJsonRpcMessage(value: unknown): value is AnyMessage { ); } -function isRequestMessage(message: AnyMessage): message is AnyMessage & { - readonly id: string | number | null; - readonly method: string; - readonly params?: unknown; -} { +function isRequestMessage( + message: AnyMessage, +): message is ClientRequestMessage { return "method" in message && "id" in message; } +function isResponseMessage(message: AnyMessage): message is AnyResponse { + return "id" in message && !("method" in message); +} + function isRecord(value: unknown): value is Record { return typeof value === "object" && value !== null; } diff --git a/src/test-support/test-agent.ts b/src/test-support/test-agent.ts index a1fcc87..265e783 100644 --- a/src/test-support/test-agent.ts +++ b/src/test-support/test-agent.ts @@ -17,17 +17,20 @@ import type { export interface TestAgentOptions { readonly chunkCount?: number; readonly chunkDelayMs?: number; + readonly enablePermission?: boolean; } export class TestAgent implements Agent { private readonly connection: AgentSideConnection; private readonly chunkCount: number; private readonly chunkDelayMs: number; + private readonly enablePermission: boolean; constructor(connection: AgentSideConnection, options: TestAgentOptions = {}) { this.connection = connection; this.chunkCount = options.chunkCount ?? 1; this.chunkDelayMs = options.chunkDelayMs ?? 0; + this.enablePermission = options.enablePermission ?? false; } initialize(_params: InitializeRequest): Promise { @@ -70,6 +73,42 @@ export class TestAgent implements Agent { }); } + if (this.enablePermission) { + const permission = await this.connection.requestPermission({ + sessionId: params.sessionId, + toolCall: { + toolCallId: "permission-tool", + title: "Permission tool", + }, + options: [ + { + kind: "allow_once", + name: "Allow once", + optionId: "allow", + }, + { + kind: "reject_once", + name: "Reject once", + optionId: "reject", + }, + ], + }); + + await this.connection.sessionUpdate({ + sessionId: params.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + type: "text", + text: + permission.outcome.outcome === "selected" + ? `permission-selected-${permission.outcome.optionId}` + : "permission-cancelled", + }, + }, + }); + } + return { stopReason: "end_turn" }; } From 03c26e40b1cbde366b9aae98d66f8a7a1794aa4e Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Mon, 18 May 2026 21:40:22 +1000 Subject: [PATCH 06/19] Add ACP HTTP client transport --- package.json | 4 + src/connection.ts | 40 +--- src/http-stream.test.ts | 487 ++++++++++++++++++++++++++++++++++++++++ src/http-stream.ts | 303 +++++++++++++++++++++++++ src/jsonrpc.ts | 30 +++ src/protocol.ts | 27 ++- src/server.ts | 35 +-- src/sse.ts | 19 +- 8 files changed, 867 insertions(+), 78 deletions(-) create mode 100644 src/http-stream.test.ts create mode 100644 src/http-stream.ts diff --git a/package.json b/package.json index b7b2b7c..c93aff2 100644 --- a/package.json +++ b/package.json @@ -29,6 +29,10 @@ "types": "./dist/acp.d.ts", "default": "./dist/acp.js" }, + "./http-client": { + "types": "./dist/http-stream.d.ts", + "default": "./dist/http-stream.js" + }, "./server": { "types": "./dist/server.d.ts", "default": "./dist/server.js" diff --git a/src/connection.ts b/src/connection.ts index f67df33..cee4b77 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -1,5 +1,10 @@ import { AgentSideConnection } from "./acp.js"; -import { messageIdKey, sessionIdFromParams } from "./protocol.js"; +import { isResponseMessage } from "./jsonrpc.js"; +import { + messageIdKey, + sessionIdFromMessageParams, + sessionIdFromResponseResult, +} from "./protocol.js"; import type { Agent } from "./acp.js"; import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; @@ -199,7 +204,7 @@ export class ConnectionState { private routeOutbound(message: AnyMessage): void { this.allOutbound.push(message); - if (isResponse(message)) { + if (isResponseMessage(message)) { this.routeOutboundResponse(message); return; } @@ -210,9 +215,7 @@ export class ConnectionState { private routeOutboundResponse(message: AnyResponse): void { const key = messageIdKey(message.id); const route = key ? this.pendingRoutes.get(key) : undefined; - const sessionId = sessionIdFromResult( - "result" in message ? message.result : undefined, - ); + const sessionId = sessionIdFromResponseResult(message); if (sessionId) { this.ensureSession(sessionId); @@ -226,12 +229,10 @@ export class ConnectionState { } private routeOutboundRequestOrNotification(message: AnyMessage): void { - if ("method" in message) { - const sessionId = sessionIdFromParams(message.params); - if (sessionId) { - this.ensureSession(sessionId).push(message); - return; - } + const sessionId = sessionIdFromMessageParams(message); + if (sessionId) { + this.ensureSession(sessionId).push(message); + return; } this.connectionStream.push(message); @@ -375,20 +376,3 @@ function isMatchingResponse( ): msg is AnyResponse { return "id" in msg && !("method" in msg) && msg.id === id; } - -function isResponse(msg: AnyMessage): msg is AnyResponse { - return "id" in msg && !("method" in msg); -} - -function sessionIdFromResult(result: unknown): string | undefined { - if (!isRecord(result)) { - return undefined; - } - - const sessionId = result["sessionId"]; - return typeof sessionId === "string" ? sessionId : undefined; -} - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} diff --git a/src/http-stream.test.ts b/src/http-stream.test.ts new file mode 100644 index 0000000..d74c5bb --- /dev/null +++ b/src/http-stream.test.ts @@ -0,0 +1,487 @@ +import { describe, expect, it } from "vitest"; +import { ClientSideConnection, PROTOCOL_VERSION } from "./acp.js"; +import { createHttpStream } from "./http-stream.js"; +import { + EVENT_STREAM_MIME_TYPE, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + JSON_MIME_TYPE, +} from "./protocol.js"; +import { serializeSseEvent } from "./sse.js"; +import { TestAgent } from "./test-support/test-agent.js"; +import { startTestServer } from "./test-support/test-http-server.js"; + +import type { + AgentSideConnection, + Client, + RequestPermissionRequest, + RequestPermissionResponse, + SessionNotification, +} from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 0, + method: "initialize", + params: { + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }, +} satisfies AnyMessage; + +const initializeResponse = { + jsonrpc: "2.0", + id: 0, + result: { + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { + loadSession: false, + }, + }, +} satisfies AnyMessage; + +const sessionNewResponse = { + jsonrpc: "2.0", + id: 1, + result: { + sessionId: "session-1", + }, +} satisfies AnyMessage; + +const promptRequest = { + jsonrpc: "2.0", + id: 2, + method: "session/prompt", + params: { + sessionId: "session-1", + prompt: [{ type: "text", text: "Hello" }], + }, +} satisfies AnyMessage; + +describe("createHttpStream", () => { + it("posts initialize with custom headers, opens connection SSE, and emits the initialize response", async () => { + const controlledFetch = createControlledFetch(); + const stream = createHttpStream("https://agent.example/acp", { + fetch: controlledFetch.fetch, + headers: { + Authorization: "Bearer token", + "X-Test-Header": "phase-5", + }, + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + await writer.write(initializeRequest); + + expect(await readMessage(reader)).toEqual(initializeResponse); + expect(controlledFetch.requests).toHaveLength(2); + + const initializePost = requestAt(controlledFetch.requests, 0); + expect(initializePost.url).toBe("https://agent.example/acp"); + expect(initializePost.method).toBe("POST"); + expect(initializePost.headers.get("Authorization")).toBe("Bearer token"); + expect(initializePost.headers.get("X-Test-Header")).toBe("phase-5"); + expect(initializePost.headers.get("Content-Type")).toBe(JSON_MIME_TYPE); + expect(initializePost.headers.get(HEADER_CONNECTION_ID)).toBeNull(); + expect(JSON.parse(initializePost.body)).toEqual(initializeRequest); + + const connectionGet = requestAt(controlledFetch.requests, 1); + expect(connectionGet.method).toBe("GET"); + expect(connectionGet.headers.get("Authorization")).toBe("Bearer token"); + expect(connectionGet.headers.get("Accept")).toBe(EVENT_STREAM_MIME_TYPE); + expect(connectionGet.headers.get(HEADER_CONNECTION_ID)).toBe( + "connection-1", + ); + expect(connectionGet.headers.get(HEADER_SESSION_ID)).toBeNull(); + } finally { + reader.releaseLock(); + writer.releaseLock(); + await stream.writable.close(); + } + }); + + it("opens session SSE after session creation and includes the session header on session-scoped POSTs", async () => { + const controlledFetch = createControlledFetch(); + const stream = createHttpStream("https://agent.example/acp", { + fetch: controlledFetch.fetch, + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + await writer.write(initializeRequest); + await readMessage(reader); + await controlledFetch.sendSse(0, sessionNewResponse); + + expect(await readMessage(reader)).toEqual(sessionNewResponse); + expect(controlledFetch.requests).toHaveLength(3); + + const sessionGet = requestAt(controlledFetch.requests, 2); + expect(sessionGet.method).toBe("GET"); + expect(sessionGet.headers.get(HEADER_CONNECTION_ID)).toBe("connection-1"); + expect(sessionGet.headers.get(HEADER_SESSION_ID)).toBe("session-1"); + + await writer.write(promptRequest); + + const promptPost = requestAt(controlledFetch.requests, 3); + expect(promptPost.method).toBe("POST"); + expect(promptPost.headers.get(HEADER_CONNECTION_ID)).toBe("connection-1"); + expect(promptPost.headers.get(HEADER_SESSION_ID)).toBe("session-1"); + expect(JSON.parse(promptPost.body)).toEqual(promptRequest); + } finally { + reader.releaseLock(); + writer.releaseLock(); + await stream.writable.close(); + } + }); + + it("sends DELETE and aborts SSE requests when closed", async () => { + const controlledFetch = createControlledFetch(); + const stream = createHttpStream("https://agent.example/acp", { + fetch: controlledFetch.fetch, + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + await writer.write(initializeRequest); + await readMessage(reader); + await writer.close(); + + const deleteRequest = requestAt(controlledFetch.requests, 2); + expect(deleteRequest.method).toBe("DELETE"); + expect(deleteRequest.headers.get(HEADER_CONNECTION_ID)).toBe( + "connection-1", + ); + expect(sseAt(controlledFetch.sseRequests, 0).signal.aborted).toBe(true); + } finally { + reader.releaseLock(); + writer.releaseLock(); + } + }); + + it("runs initialize, newSession, and prompt through ClientSideConnection", async () => { + const updates: SessionNotification[] = []; + const server = await startTestServer( + (conn: AgentSideConnection) => new TestAgent(conn, { chunkCount: 2 }), + ); + const stream = createHttpStream(server.url); + const conn = new ClientSideConnection( + () => createTestClient({ updates }), + stream, + ); + + try { + expect( + await conn.initialize({ + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }), + ).toMatchObject({ + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { loadSession: false }, + }); + + const session = await conn.newSession({ cwd: "/tmp", mcpServers: [] }); + expect(session.sessionId).toMatch(/^[0-9a-f-]{36}$/); + + await expect( + conn.prompt({ + sessionId: session.sessionId, + prompt: [{ type: "text", text: "Hello" }], + }), + ).resolves.toEqual({ stopReason: "end_turn" }); + expect(updates).toHaveLength(2); + expect(updates).toMatchObject([ + { + sessionId: session.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "chunk-1" }, + }, + }, + { + sessionId: session.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "chunk-2" }, + }, + }, + ]); + } finally { + await closeStream(stream); + await server.close(); + } + }); + + it("round-trips permission requests through ClientSideConnection", async () => { + const updates: SessionNotification[] = []; + const permissionRequests: RequestPermissionRequest[] = []; + const server = await startTestServer( + (conn: AgentSideConnection) => + new TestAgent(conn, { enablePermission: true }), + ); + const stream = createHttpStream(server.url); + const conn = new ClientSideConnection( + () => createTestClient({ updates, permissionRequests }), + stream, + ); + + try { + await conn.initialize({ + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }); + const session = await conn.newSession({ cwd: "/tmp", mcpServers: [] }); + + await expect( + conn.prompt({ + sessionId: session.sessionId, + prompt: [{ type: "text", text: "Hello" }], + }), + ).resolves.toEqual({ stopReason: "end_turn" }); + + expect(permissionRequests).toHaveLength(1); + expect(permissionRequests[0]).toMatchObject({ + sessionId: session.sessionId, + toolCall: { + toolCallId: "permission-tool", + title: "Permission tool", + }, + }); + expect(updates).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + sessionId: session.sessionId, + update: expect.objectContaining({ + sessionUpdate: "agent_message_chunk", + content: expect.objectContaining({ + text: "permission-selected-allow", + }), + }), + }), + ]), + ); + } finally { + await closeStream(stream); + await server.close(); + } + }); + + it("keeps multiple sessions isolated through the SDK client abstraction", async () => { + const updates: SessionNotification[] = []; + const server = await startTestServer(); + const stream = createHttpStream(server.url); + const conn = new ClientSideConnection( + () => createTestClient({ updates }), + stream, + ); + + try { + await conn.initialize({ + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }); + const firstSession = await conn.newSession({ + cwd: "/tmp", + mcpServers: [], + }); + const secondSession = await conn.newSession({ + cwd: "/tmp/other", + mcpServers: [], + }); + + await Promise.all([ + conn.prompt({ + sessionId: firstSession.sessionId, + prompt: [{ type: "text", text: "First" }], + }), + conn.prompt({ + sessionId: secondSession.sessionId, + prompt: [{ type: "text", text: "Second" }], + }), + ]); + + expect(updates).toEqual( + expect.arrayContaining([ + expect.objectContaining({ sessionId: firstSession.sessionId }), + expect.objectContaining({ sessionId: secondSession.sessionId }), + ]), + ); + expect( + updates.filter((update) => update.sessionId === firstSession.sessionId), + ).toHaveLength(1); + expect( + updates.filter( + (update) => update.sessionId === secondSession.sessionId, + ), + ).toHaveLength(1); + } finally { + await closeStream(stream); + await server.close(); + } + }); +}); + +interface RecordedRequest { + readonly url: string; + readonly method: string; + readonly headers: Headers; + readonly body: string; +} + +interface RecordedSseRequest { + readonly signal: AbortSignal; + readonly writer: WritableStreamDefaultWriter; +} + +interface ControlledFetch { + readonly fetch: typeof globalThis.fetch; + readonly requests: RecordedRequest[]; + readonly sseRequests: RecordedSseRequest[]; + readonly sendSse: (index: number, message: AnyMessage) => Promise; +} + +interface TestClientState { + readonly updates: SessionNotification[]; + readonly permissionRequests?: RequestPermissionRequest[]; +} + +function createControlledFetch(): ControlledFetch { + const requests: RecordedRequest[] = []; + const sseRequests: RecordedSseRequest[] = []; + const encoder = new TextEncoder(); + + return { + requests, + sseRequests, + fetch: async (input, init) => { + const method = init?.method ?? "GET"; + const headers = new Headers(init?.headers); + requests.push({ + url: String(input), + method, + headers, + body: bodyToString(init?.body), + }); + + if (method === "POST" && !headers.has(HEADER_CONNECTION_ID)) { + return jsonResponse(initializeResponse, 200, { + [HEADER_CONNECTION_ID]: "connection-1", + }); + } + + if (method === "POST" || method === "DELETE") { + return new Response(null, { status: 202 }); + } + + if (method === "GET") { + const stream = new TransformStream(); + const writer = stream.writable.getWriter(); + const signal = init?.signal; + + if (signal) { + signal.addEventListener("abort", () => { + void writer.close(); + }); + } + + sseRequests.push({ + signal: signal ?? new AbortController().signal, + writer, + }); + + return new Response(stream.readable, { + status: 200, + headers: { "Content-Type": EVENT_STREAM_MIME_TYPE }, + }); + } + + return new Response("Unexpected method", { status: 405 }); + }, + sendSse: async (index, message) => { + await sseAt(sseRequests, index).writer.write( + encoder.encode(serializeSseEvent(message)), + ); + }, + }; +} + +function createTestClient(state: TestClientState): Client { + return { + requestPermission: (params): Promise => { + state.permissionRequests?.push(params); + return Promise.resolve({ + outcome: { + outcome: "selected", + optionId: "allow", + }, + }); + }, + sessionUpdate: (params): Promise => { + state.updates.push(params); + return Promise.resolve(); + }, + }; +} + +async function closeStream(stream: { + writable: WritableStream; +}): Promise { + await stream.writable.close().catch(() => undefined); +} + +async function readMessage( + reader: ReadableStreamDefaultReader, +): Promise { + const result = await reader.read(); + if (result.done) { + throw new Error("Expected a message"); + } + + return result.value; +} + +function requestAt( + requests: readonly RecordedRequest[], + index: number, +): RecordedRequest { + const request = requests[index]; + if (!request) { + throw new Error(`Expected request at index ${index}`); + } + + return request; +} + +function sseAt( + requests: readonly RecordedSseRequest[], + index: number, +): RecordedSseRequest { + const request = requests[index]; + if (!request) { + throw new Error(`Expected SSE request at index ${index}`); + } + + return request; +} + +function bodyToString(body: BodyInit | null | undefined): string { + return typeof body === "string" ? body : ""; +} + +function jsonResponse( + body: AnyMessage, + status: number, + headers: Record, +): Response { + return new Response(JSON.stringify(body), { + status, + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + }); +} diff --git a/src/http-stream.ts b/src/http-stream.ts new file mode 100644 index 0000000..f473735 --- /dev/null +++ b/src/http-stream.ts @@ -0,0 +1,303 @@ +import { isJsonRpcMessage } from "./jsonrpc.js"; +import { + EVENT_STREAM_MIME_TYPE, + HEADER_CONNECTION_ID, + HEADER_SESSION_ID, + JSON_MIME_TYPE, + isInitializeRequest, + sessionIdFromMessageParams, + sessionIdFromResponseResult, +} from "./protocol.js"; +import { parseSseStream } from "./sse.js"; + +import type { AnyMessage } from "./jsonrpc.js"; +import type { Stream } from "./stream.js"; + +export interface HttpStreamOptions { + readonly fetch?: typeof globalThis.fetch; + readonly headers?: Record; +} + +/** + * Creates an ACP Stream that speaks the Streamable HTTP transport. + * + * The transport uses HTTP POST for client-to-agent messages and SSE GET streams for agent-to-client messages. + * Cookie management is intentionally not built in; pass a cookie-aware fetch implementation when needed. + */ +export function createHttpStream( + serverUrl: string, + options: HttpStreamOptions = {}, +): Stream { + return new HttpStreamTransport(serverUrl, options).stream; +} + +class HttpStreamTransport { + readonly stream: Stream; + + private readonly fetchImpl: typeof globalThis.fetch; + private readonly headers: Record; + private readonly abortController = new AbortController(); + private readonly knownSessions = new Set(); + + private readableController: + | ReadableStreamDefaultController + | undefined; + private connectionId: string | undefined; + private isClosed = false; + private writeChain: Promise = Promise.resolve(); + + constructor( + private readonly serverUrl: string, + options: HttpStreamOptions, + ) { + this.fetchImpl = resolveFetch(options.fetch); + this.headers = options.headers ?? {}; + + this.stream = { + readable: new ReadableStream({ + start: (controller) => { + this.readableController = controller; + }, + cancel: () => { + void this.close(); + }, + }), + writable: new WritableStream({ + write: (message) => { + this.writeChain = this.writeChain.then(() => + this.writeMessage(message), + ); + return this.writeChain; + }, + close: () => this.close(), + abort: () => this.close(), + }), + }; + } + + private async writeMessage(message: AnyMessage): Promise { + if (this.isClosed) { + throw new Error("ACP HTTP stream is closed"); + } + + if (!this.connectionId) { + await this.postInitialize(message); + return; + } + + await this.postConnectedMessage(message); + } + + private async postInitialize(message: AnyMessage): Promise { + if (!isInitializeRequest(message)) { + throw new Error("ACP HTTP stream first message must be initialize"); + } + + const response = await this.fetchImpl(this.serverUrl, { + method: "POST", + headers: { + ...this.headers, + "Content-Type": JSON_MIME_TYPE, + }, + body: JSON.stringify(message), + }); + + if (!response.ok) { + throw await httpError("ACP initialize failed", response); + } + + const connectionId = response.headers.get(HEADER_CONNECTION_ID); + if (!connectionId) { + throw new Error("ACP initialize response missing Acp-Connection-Id"); + } + + const body: unknown = await response.json(); + if (!isJsonRpcMessage(body)) { + throw new Error("ACP initialize response was not a JSON-RPC message"); + } + + this.connectionId = connectionId; + this.openConnectionSse(); + this.enqueue(body); + } + + private async postConnectedMessage(message: AnyMessage): Promise { + const connectionId = this.connectionId; + if (!connectionId) { + throw new Error("ACP HTTP stream is not initialized"); + } + + const sessionId = sessionIdFromMessageParams(message); + const response = await this.fetchImpl(this.serverUrl, { + method: "POST", + headers: { + ...this.headers, + "Content-Type": JSON_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + ...(sessionId ? { [HEADER_SESSION_ID]: sessionId } : {}), + }, + body: JSON.stringify(message), + }); + + if (!response.ok) { + throw await httpError("ACP POST failed", response); + } + } + + private openConnectionSse(): void { + const connectionId = this.connectionId; + if (!connectionId) { + return; + } + + void this.openSse({ + [HEADER_CONNECTION_ID]: connectionId, + }); + } + + private openSessionSse(sessionId: string): void { + if (this.knownSessions.has(sessionId)) { + return; + } + + const connectionId = this.connectionId; + if (!connectionId) { + return; + } + + this.knownSessions.add(sessionId); + + void this.openSse({ + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }); + } + + private async openSse(headers: Record): Promise { + try { + const response = await this.fetchImpl(this.serverUrl, { + method: "GET", + headers: { + ...this.headers, + Accept: EVENT_STREAM_MIME_TYPE, + ...headers, + }, + signal: this.abortController.signal, + }); + + if (!response.ok) { + throw await httpError("ACP SSE connection failed", response); + } + + if (!response.body) { + throw new Error("ACP SSE response missing body"); + } + + for await (const message of parseSseStream(response.body)) { + if (this.isClosed) { + return; + } + + const sessionId = sessionIdFromResponseResult(message); + if (sessionId) { + this.openSessionSse(sessionId); + } + + this.enqueue(message); + } + } catch (error) { + if (this.isClosed || this.abortController.signal.aborted) { + return; + } + + this.errorReadable(error); + } + } + + private async close(): Promise { + if (this.isClosed) { + return; + } + + this.isClosed = true; + + const connectionId = this.connectionId; + if (connectionId) { + const response = await this.fetchImpl(this.serverUrl, { + method: "DELETE", + headers: { + ...this.headers, + [HEADER_CONNECTION_ID]: connectionId, + }, + }); + + if (!response.ok) { + this.abortController.abort(); + this.closeReadable(); + throw await httpError("ACP DELETE failed", response); + } + } + + this.abortController.abort(); + this.closeReadable(); + } + + private enqueue(message: AnyMessage): void { + try { + this.readableController?.enqueue(message); + } catch (error) { + this.errorReadable(error); + } + } + + private errorReadable(error: unknown): void { + if (this.isClosed) { + return; + } + + this.isClosed = true; + this.abortController.abort(); + + try { + this.readableController?.error(error); + } catch { + // The readable side may already be closed or cancelled. + } + } + + private closeReadable(): void { + try { + this.readableController?.close(); + } catch { + // The readable side may already be closed, cancelled, or errored. + } + } +} + +function resolveFetch( + fetchImpl: typeof globalThis.fetch | undefined, +): typeof globalThis.fetch { + if (fetchImpl) { + return fetchImpl; + } + + if (typeof globalThis.fetch === "function") { + return (input, init) => globalThis.fetch(input, init); + } + + throw new Error( + "createHttpStream requires globalThis.fetch or options.fetch", + ); +} + +async function httpError(prefix: string, response: Response): Promise { + const text = await response.text().catch(() => ""); + + if (text) { + return new Error( + `${prefix}: ${response.status} ${response.statusText}: ${text}`, + ); + } + + return new Error(`${prefix}: ${response.status} ${response.statusText}`); +} diff --git a/src/jsonrpc.ts b/src/jsonrpc.ts index 6f556d6..31ee0ee 100644 --- a/src/jsonrpc.ts +++ b/src/jsonrpc.ts @@ -44,3 +44,33 @@ export type NotificationHandler = ( method: string, params: unknown, ) => Promise; + +export function isJsonRpcMessage(value: unknown): value is AnyMessage { + if (!isRecord(value) || value["jsonrpc"] !== "2.0") { + return false; + } + + if ("method" in value) { + return typeof value["method"] === "string"; + } + + return "id" in value; +} + +export function isRequestMessage(message: AnyMessage): message is AnyRequest { + return "id" in message && "method" in message; +} + +export function isResponseMessage(message: AnyMessage): message is AnyResponse { + return "id" in message && !("method" in message); +} + +export function isNotificationMessage( + message: AnyMessage, +): message is AnyNotification { + return "method" in message && !("id" in message); +} + +export function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} diff --git a/src/protocol.ts b/src/protocol.ts index a88cb95..61c99e7 100644 --- a/src/protocol.ts +++ b/src/protocol.ts @@ -1,5 +1,5 @@ import { AGENT_METHODS } from "./schema/index.js"; - +import { isRecord, isResponseMessage } from "./jsonrpc.js"; import type { AnyMessage } from "./jsonrpc.js"; export const HEADER_CONNECTION_ID = "Acp-Connection-Id"; @@ -31,6 +31,27 @@ export function sessionIdFromParams(params: unknown): string | undefined { return typeof sessionId === "string" ? sessionId : undefined; } +export function sessionIdFromMessageParams( + message: AnyMessage, +): string | undefined { + return "method" in message ? sessionIdFromParams(message.params) : undefined; +} + +export function sessionIdFromResponseResult( + message: AnyMessage, +): string | undefined { + if (!isResponseMessage(message) || !("result" in message)) { + return undefined; + } + + if (!isRecord(message.result)) { + return undefined; + } + + const sessionId = message.result["sessionId"]; + return typeof sessionId === "string" ? sessionId : undefined; +} + export function isInitializeRequest(msg: AnyMessage): boolean { return ( msg.jsonrpc === "2.0" && @@ -53,7 +74,3 @@ export function messageIdKey( return undefined; } - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} diff --git a/src/server.ts b/src/server.ts index d6ae5d8..a8748c7 100644 --- a/src/server.ts +++ b/src/server.ts @@ -9,6 +9,11 @@ import { methodRequiresSessionHeader, sessionIdFromParams, } from "./protocol.js"; +import { + isJsonRpcMessage, + isRequestMessage, + isResponseMessage, +} from "./jsonrpc.js"; import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js"; @@ -18,7 +23,7 @@ import type { ResponseRoute, } from "./connection.js"; import type { Agent, AgentSideConnection } from "./acp.js"; -import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; +import type { AnyMessage, AnyRequest, AnyResponse } from "./jsonrpc.js"; export interface AcpServerOptions { createAgent: (conn: AgentSideConnection) => Agent; @@ -237,11 +242,7 @@ type RouteResult = message: string; }; -type ClientRequestMessage = AnyMessage & { - readonly id: string | number | null; - readonly method: string; - readonly params?: unknown; -}; +type ClientRequestMessage = AnyRequest; async function readJson(req: Request): Promise { try { @@ -346,28 +347,6 @@ function determineRoute( }; } -function isJsonRpcMessage(value: unknown): value is AnyMessage { - return ( - isRecord(value) && - value.jsonrpc === "2.0" && - ("method" in value || "id" in value) - ); -} - -function isRequestMessage( - message: AnyMessage, -): message is ClientRequestMessage { - return "method" in message && "id" in message; -} - -function isResponseMessage(message: AnyMessage): message is AnyResponse { - return "id" in message && !("method" in message); -} - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} - function sseResponse(subscription: OutboundSubscription): Response { return new Response(createSseBody(subscription), { status: 200, diff --git a/src/sse.ts b/src/sse.ts index 9b7b8cb..7f52388 100644 --- a/src/sse.ts +++ b/src/sse.ts @@ -1,4 +1,5 @@ import type { AnyMessage } from "./jsonrpc.js"; +import { isJsonRpcMessage } from "./jsonrpc.js"; export function serializeSseEvent(msg: AnyMessage): string { return `data: ${JSON.stringify(msg)}\n\n`; @@ -76,7 +77,7 @@ function parseSseEvent(eventPart: string): AnyMessage | undefined { try { const parsed: unknown = JSON.parse(data); - if (isAnyMessage(parsed)) { + if (isJsonRpcMessage(parsed)) { return parsed; } @@ -87,19 +88,3 @@ function parseSseEvent(eventPart: string): AnyMessage | undefined { return undefined; } } - -function isAnyMessage(value: unknown): value is AnyMessage { - if (!isRecord(value) || value["jsonrpc"] !== "2.0") { - return false; - } - - if ("method" in value) { - return typeof value["method"] === "string"; - } - - return "id" in value; -} - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} From 8e7e0c3b6c2f16e05424345226fe6020195f3e0d Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 09:50:32 +1000 Subject: [PATCH 07/19] Add WebSocket server impl --- src/server.ts | 10 +- src/ws-server.ts | 425 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 434 insertions(+), 1 deletion(-) create mode 100644 src/ws-server.ts diff --git a/src/server.ts b/src/server.ts index a8748c7..85353fe 100644 --- a/src/server.ts +++ b/src/server.ts @@ -14,8 +14,9 @@ import { isRequestMessage, isResponseMessage, } from "./jsonrpc.js"; - import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js"; +import { handleWebSocketConnection } from "./ws-server.js"; +import type { WebSocketServerSocket } from "./ws-server.js"; import type { ConnectionState, @@ -53,6 +54,13 @@ export class AcpServer { return textResponse("Method Not Allowed", 405); } + handleWebSocket(socket: WebSocketServerSocket): void { + handleWebSocketConnection(socket, { + registry: this.registry, + createAgent: this.createAgent, + }); + } + async close(): Promise { this.registry.closeAll(); } diff --git a/src/ws-server.ts b/src/ws-server.ts new file mode 100644 index 0000000..2dc0757 --- /dev/null +++ b/src/ws-server.ts @@ -0,0 +1,425 @@ +import { + isJsonRpcMessage, + isRequestMessage, + isResponseMessage, +} from "./jsonrpc.js"; +import { + isInitializeRequest, + messageIdKey, + sessionIdFromParams, +} from "./protocol.js"; + +import type { Agent, AgentSideConnection } from "./acp.js"; +import type { + ConnectionRegistry, + ConnectionState, + ResponseRoute, +} from "./connection.js"; +import type { AnyMessage, AnyRequest } from "./jsonrpc.js"; + +type ForwardResult = + | { + ok: true; + } + | { + ok: false; + message: string; + }; + +export interface WebSocketServerSocket { + readonly readyState?: number; + send(data: string): void; + close(code?: number, reason?: string): void; + addEventListener?(type: string, listener: (event: unknown) => void): void; + removeEventListener?(type: string, listener: (event: unknown) => void): void; + on?(type: string, listener: (...args: unknown[]) => void): unknown; + off?(type: string, listener: (...args: unknown[]) => void): unknown; + removeListener?( + type: string, + listener: (...args: unknown[]) => void, + ): unknown; +} + +export interface WebSocketConnectionOptions { + readonly registry: ConnectionRegistry; + readonly createAgent: (conn: AgentSideConnection) => Agent; +} + +export function handleWebSocketConnection( + socket: WebSocketServerSocket, + options: WebSocketConnectionOptions, +): void { + const session = new WebSocketServerSession(socket, options); + session.start(); +} + +class WebSocketServerSession { + private connection: ConnectionState | undefined; + private outboundReader: ReadableStreamDefaultReader | undefined; + private isClosed = false; + private readonly detachListeners: Array<() => void> = []; + + constructor( + private readonly socket: WebSocketServerSocket, + private readonly options: WebSocketConnectionOptions, + ) {} + + start(): void { + this.detachListeners.push( + onSocket(this.socket, "message", (...args) => { + void this.handleSocketMessage(args); + }), + ); + + this.detachListeners.push( + onSocket(this.socket, "close", () => { + void this.closeSession(); + }), + ); + + this.detachListeners.push( + onSocket(this.socket, "error", () => { + void this.shutdown(1011, "WebSocket error"); + }), + ); + } + + private async handleSocketMessage(args: unknown[]): Promise { + if (this.isClosed) { + return; + } + + const text = socketMessageToString(args); + if (text === undefined) { + console.warn("Ignoring non-text ACP WebSocket frame"); + return; + } + + let value: unknown; + try { + value = JSON.parse(text); + } catch (error) { + console.warn("Ignoring malformed ACP WebSocket JSON message:", error); + await this.shutdownIfUninitialized(1007, "Malformed JSON"); + + return; + } + + if (Array.isArray(value)) { + console.warn("Ignoring ACP WebSocket JSON-RPC batch message"); + await this.shutdownIfUninitialized( + 1002, + "JSON-RPC batch messages are not supported", + ); + return; + } + + if (!isJsonRpcMessage(value)) { + console.warn("Ignoring non-JSON-RPC ACP WebSocket message:", value); + await this.shutdownIfUninitialized(1002, "Invalid JSON-RPC message"); + return; + } + + if (!this.connection) { + await this.handleInitialize(value); + return; + } + + const forwarded = await this.forwardMessage(value); + if (!forwarded.ok) { + console.warn("Ignoring ACP WebSocket message:", forwarded.message); + } + } + + private async handleInitialize(message: AnyMessage): Promise { + if (!isInitializeRequest(message)) { + console.warn("First ACP WebSocket message must be initialize"); + await this.shutdown(1002, "First message must be initialize"); + return; + } + + if (!("id" in message) || message.id === null) { + console.warn("ACP WebSocket initialize request must include an ID"); + await this.shutdown(1002, "Initialize request must include an ID"); + return; + } + + let connection: ConnectionState | undefined; + + try { + connection = this.options.registry.createConnection( + this.options.createAgent, + ); + + await writeInbound(connection, message); + + const initialResponse = await connection.recvInitial(message.id); + + this.connection = connection; + connection.startRouter(); + + this.send(initialResponse); + this.startOutboundPump(connection); + } catch (error) { + if (connection) { + this.options.registry.remove(connection.connectionId); + } + + this.send({ + jsonrpc: "2.0", + id: message.id, + error: { + code: -32603, + message: "Initialize failed", + data: error instanceof Error ? error.message : undefined, + }, + }); + + await this.shutdown(1011, "Initialize failed"); + } + } + + private async forwardMessage(message: AnyMessage): Promise { + const connection = this.connection; + + if (!connection) { + return { + ok: false, + message: "ACP WebSocket connection is not initialized", + }; + } + + if (isRequestMessage(message)) { + const route = determineWebSocketRoute(message); + + if (route !== "connection") { + connection.ensureSession(route.session); + } + + const key = messageIdKey(message.id); + + if (key) { + connection.pendingRoutes.set(key, route); + } + + await writeInbound(connection, message); + return { ok: true }; + } + + if (isResponseMessage(message)) { + await writeInbound(connection, message); + return { ok: true }; + } + + await writeInbound(connection, message); + return { ok: true }; + } + + private startOutboundPump(connection: ConnectionState): void { + const subscription = connection.allOutbound.subscribe(); + const reader = subscription.stream.getReader(); + this.outboundReader = reader; + + void (async () => { + try { + for (const message of subscription.replay) { + if (!this.send(message)) { + return; + } + } + + while (!this.isClosed) { + const result = await reader.read(); + + if (result.done) { + return; + } + + if (!this.send(result.value)) { + return; + } + } + } catch (error) { + if (!this.isClosed) { + console.error("ACP WebSocket outbound pump failed:", error); + } + } finally { + if (this.outboundReader === reader) { + this.outboundReader = undefined; + } + + reader.releaseLock(); + + if (!this.isClosed) { + void this.shutdown(); + } + } + })(); + } + + private send(message: AnyMessage): boolean { + if (this.isClosed) { + return false; + } + + try { + this.socket.send(JSON.stringify(message)); + return true; + } catch (error) { + console.warn("Failed to send ACP WebSocket message:", error); + void this.shutdown(1011, "Failed to send message"); + return false; + } + } + + private async shutdownIfUninitialized( + code?: number, + reason?: string, + ): Promise { + if (this.connection) { + return; + } + + await this.shutdown(code, reason); + } + + private async shutdown(code?: number, reason?: string): Promise { + this.closeSocket(code, reason); + await this.closeSession(); + } + + private closeSocket(code?: number, reason?: string): void { + try { + this.socket.close(code, reason); + } catch (error) { + console.warn("Failed to close ACP WebSocket:", error); + } + } + + private async closeSession(): Promise { + if (this.isClosed) { + return; + } + + this.isClosed = true; + + for (const detach of this.detachListeners.splice(0)) { + detach(); + } + + const outboundReader = this.outboundReader; + this.outboundReader = undefined; + + if (outboundReader) { + await outboundReader.cancel(); + } + + if (this.connection) { + this.options.registry.remove(this.connection.connectionId); + this.connection = undefined; + } + } +} + +async function writeInbound( + connection: ConnectionState, + message: AnyMessage, +): Promise { + const writer = connection.inboundTx.getWriter(); + + try { + await writer.write(message); + } finally { + writer.releaseLock(); + } +} + +function determineWebSocketRoute(message: AnyRequest): ResponseRoute { + const sessionId = sessionIdFromParams(message.params); + + if (sessionId) { + return { + session: sessionId, + }; + } + + return "connection"; +} + +function onSocket( + socket: WebSocketServerSocket, + type: string, + listener: (...args: unknown[]) => void, +): () => void { + if (socket.addEventListener) { + const eventListener = (event: unknown) => listener(event); + socket.addEventListener(type, eventListener); + + return () => { + socket.removeEventListener?.(type, eventListener); + }; + } + + if (socket.on) { + socket.on(type, listener); + + return () => { + if (socket.off) { + socket.off(type, listener); + return; + } + + socket.removeListener?.(type, listener); + }; + } + + throw new Error("WebSocket object does not support event listeners"); +} + +function socketMessageToString(args: unknown[]): string | undefined { + const data = extractMessageData(args); + + if (typeof data === "string") { + return data; + } + + if (data instanceof ArrayBuffer || ArrayBuffer.isView(data)) { + return new TextDecoder().decode(data); + } + + if (Array.isArray(data) && data.every(ArrayBuffer.isView)) { + return decodeArrayBufferViews(data); + } + + return undefined; +} + +function extractMessageData(args: unknown[]): unknown { + const [first] = args; + + if (isMessageEventLike(first)) { + return first.data; + } + + return first; +} + +function isMessageEventLike(value: unknown): value is { data: unknown } { + return typeof value === "object" && value !== null && "data" in value; +} + +function decodeArrayBufferViews(views: ArrayBufferView[]): string { + const totalLength = views.reduce((sum, view) => sum + view.byteLength, 0); + const combined = new Uint8Array(totalLength); + let offset = 0; + + for (const view of views) { + combined.set( + new Uint8Array(view.buffer, view.byteOffset, view.byteLength), + offset, + ); + offset += view.byteLength; + } + + return new TextDecoder().decode(combined); +} From 2822ae77be66e9d2720c352249f155acfa921587 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 10:27:08 +1000 Subject: [PATCH 08/19] Add WebSocket client SDK and split out shared methods into ws-utils, add tests --- package-lock.json | 34 +++ package.json | 6 + src/test-support/test-http-server.ts | 36 ++- src/ws-server.ts | 130 ++------- src/ws-stream.test.ts | 411 +++++++++++++++++++++++++++ src/ws-stream.ts | 239 ++++++++++++++++ src/ws-utils.ts | 108 +++++++ 7 files changed, 861 insertions(+), 103 deletions(-) create mode 100644 src/ws-stream.test.ts create mode 100644 src/ws-stream.ts create mode 100644 src/ws-utils.ts diff --git a/package-lock.json b/package-lock.json index 8dee173..e72ef9d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "@eslint/js": "^10.0.1", "@hey-api/openapi-ts": "^0.97.0", "@types/node": "^25.5.0", + "@types/ws": "^8.5.13", "@typescript-eslint/eslint-plugin": "^8.57.1", "@typescript-eslint/parser": "^8.57.1", "concurrently": "^9.2.1", @@ -25,6 +26,7 @@ "typedoc-github-theme": "^0.4.0", "typescript": "^6.0.2", "vitest": "^4.1.0", + "ws": "^8.18.0", "zod": "^3.25.0 || ^4.0.0" }, "peerDependencies": { @@ -1263,6 +1265,16 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://packages.atlassian.com/api/npm/npm-remote/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.59.3", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.59.3.tgz", @@ -4513,6 +4525,28 @@ "url": "https://github.com/chalk/wrap-ansi?sponsor=1" } }, + "node_modules/ws": { + "version": "8.20.1", + "resolved": "https://packages.atlassian.com/api/npm/npm-remote/ws/-/ws-8.20.1.tgz", + "integrity": "sha512-It4dO0K5v//JtTXuPkfEOaI3uUN87iYPnqo/ZzqCoG3g8uhA66QUMs/SrM0YK7/NAu+r4LMh/9dq2A7k+rHs+w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, "node_modules/wsl-utils": { "version": "0.3.1", "resolved": "https://registry.npmjs.org/wsl-utils/-/wsl-utils-0.3.1.tgz", diff --git a/package.json b/package.json index c93aff2..e56ebcd 100644 --- a/package.json +++ b/package.json @@ -33,6 +33,10 @@ "types": "./dist/http-stream.d.ts", "default": "./dist/http-stream.js" }, + "./ws-client": { + "types": "./dist/ws-stream.d.ts", + "default": "./dist/ws-stream.js" + }, "./server": { "types": "./dist/server.d.ts", "default": "./dist/server.js" @@ -69,6 +73,7 @@ "@eslint/js": "^10.0.1", "@hey-api/openapi-ts": "^0.97.0", "@types/node": "^25.5.0", + "@types/ws": "^8.5.13", "@typescript-eslint/eslint-plugin": "^8.57.1", "@typescript-eslint/parser": "^8.57.1", "concurrently": "^9.2.1", @@ -82,6 +87,7 @@ "typedoc-github-theme": "^0.4.0", "typescript": "^6.0.2", "vitest": "^4.1.0", + "ws": "^8.18.0", "zod": "^3.25.0 || ^4.0.0" } } diff --git a/src/test-support/test-http-server.ts b/src/test-support/test-http-server.ts index de29c5e..96b746b 100644 --- a/src/test-support/test-http-server.ts +++ b/src/test-support/test-http-server.ts @@ -1,4 +1,5 @@ import http from "node:http"; +import { WebSocketServer } from "ws"; import { AcpServer } from "../server.js"; import { createNodeHttpHandler } from "../node-adapter.js"; @@ -9,6 +10,7 @@ import type { Agent, AgentSideConnection } from "../acp.js"; export interface TestHttpServer { readonly url: string; + readonly wsUrl: string; readonly close: () => Promise; } @@ -19,6 +21,13 @@ export async function startTestServer( ): Promise { const acpServer = new AcpServer({ createAgent: agentFactory }); const httpServer = http.createServer(createNodeHttpHandler(acpServer)); + const webSocketServer = new WebSocketServer({ noServer: true }); + + httpServer.on("upgrade", (req, socket, head) => { + webSocketServer.handleUpgrade(req, socket, head, (webSocket) => { + acpServer.handleWebSocket(webSocket); + }); + }); await listen(httpServer, options.port ?? 0); @@ -30,8 +39,14 @@ export async function startTestServer( return { url: `http://127.0.0.1:${address.port}`, + wsUrl: `ws://127.0.0.1:${address.port}`, close: async () => { - await Promise.all([acpServer.close(), closeHttpServer(httpServer)]); + terminateWebSockets(webSocketServer); + await Promise.all([ + acpServer.close(), + closeWebSocketServer(webSocketServer), + closeHttpServer(httpServer), + ]); }, }; } @@ -54,6 +69,25 @@ function listen(server: http.Server, port: number): Promise { }); } +function terminateWebSockets(server: WebSocketServer): void { + for (const client of server.clients) { + client.terminate(); + } +} + +function closeWebSocketServer(server: WebSocketServer): Promise { + return new Promise((resolve, reject) => { + server.close((error) => { + if (error) { + reject(error); + return; + } + + resolve(); + }); + }); +} + function closeHttpServer(server: http.Server): Promise { return new Promise((resolve, reject) => { server.close((error) => { diff --git a/src/ws-server.ts b/src/ws-server.ts index 2dc0757..fe7cc92 100644 --- a/src/ws-server.ts +++ b/src/ws-server.ts @@ -8,7 +8,7 @@ import { messageIdKey, sessionIdFromParams, } from "./protocol.js"; - +import { onWebSocket, webSocketMessageToString } from "./ws-utils.js"; import type { Agent, AgentSideConnection } from "./acp.js"; import type { ConnectionRegistry, @@ -16,6 +16,9 @@ import type { ResponseRoute, } from "./connection.js"; import type { AnyMessage, AnyRequest } from "./jsonrpc.js"; +import type { WebSocketLike } from "./ws-utils.js"; + +export type WebSocketServerSocket = WebSocketLike; type ForwardResult = | { @@ -26,27 +29,13 @@ type ForwardResult = message: string; }; -export interface WebSocketServerSocket { - readonly readyState?: number; - send(data: string): void; - close(code?: number, reason?: string): void; - addEventListener?(type: string, listener: (event: unknown) => void): void; - removeEventListener?(type: string, listener: (event: unknown) => void): void; - on?(type: string, listener: (...args: unknown[]) => void): unknown; - off?(type: string, listener: (...args: unknown[]) => void): unknown; - removeListener?( - type: string, - listener: (...args: unknown[]) => void, - ): unknown; -} - export interface WebSocketConnectionOptions { readonly registry: ConnectionRegistry; readonly createAgent: (conn: AgentSideConnection) => Agent; } export function handleWebSocketConnection( - socket: WebSocketServerSocket, + socket: WebSocketLike, options: WebSocketConnectionOptions, ): void { const session = new WebSocketServerSession(socket, options); @@ -56,29 +45,30 @@ export function handleWebSocketConnection( class WebSocketServerSession { private connection: ConnectionState | undefined; private outboundReader: ReadableStreamDefaultReader | undefined; + private inboundWriteChain: Promise = Promise.resolve(); private isClosed = false; private readonly detachListeners: Array<() => void> = []; constructor( - private readonly socket: WebSocketServerSocket, + private readonly socket: WebSocketLike, private readonly options: WebSocketConnectionOptions, ) {} start(): void { this.detachListeners.push( - onSocket(this.socket, "message", (...args) => { + onWebSocket(this.socket, "message", (...args) => { void this.handleSocketMessage(args); }), ); this.detachListeners.push( - onSocket(this.socket, "close", () => { + onWebSocket(this.socket, "close", () => { void this.closeSession(); }), ); this.detachListeners.push( - onSocket(this.socket, "error", () => { + onWebSocket(this.socket, "error", () => { void this.shutdown(1011, "WebSocket error"); }), ); @@ -89,7 +79,7 @@ class WebSocketServerSession { return; } - const text = socketMessageToString(args); + const text = webSocketMessageToString(args); if (text === undefined) { console.warn("Ignoring non-text ACP WebSocket frame"); return; @@ -202,19 +192,33 @@ class WebSocketServerSession { connection.pendingRoutes.set(key, route); } - await writeInbound(connection, message); + await this.writeInbound(message); return { ok: true }; } if (isResponseMessage(message)) { - await writeInbound(connection, message); + await this.writeInbound(message); return { ok: true }; } - await writeInbound(connection, message); + await this.writeInbound(message); return { ok: true }; } + private async writeInbound(message: AnyMessage): Promise { + const connection = this.connection; + + if (!connection) { + throw new Error("ACP WebSocket connection is not initialized"); + } + + const write = this.inboundWriteChain.then(() => + writeInbound(connection, message), + ); + this.inboundWriteChain = write.catch(() => undefined); + await write; + } + private startOutboundPump(connection: ConnectionState): void { const subscription = connection.allOutbound.subscribe(); const reader = subscription.stream.getReader(); @@ -345,81 +349,3 @@ function determineWebSocketRoute(message: AnyRequest): ResponseRoute { return "connection"; } - -function onSocket( - socket: WebSocketServerSocket, - type: string, - listener: (...args: unknown[]) => void, -): () => void { - if (socket.addEventListener) { - const eventListener = (event: unknown) => listener(event); - socket.addEventListener(type, eventListener); - - return () => { - socket.removeEventListener?.(type, eventListener); - }; - } - - if (socket.on) { - socket.on(type, listener); - - return () => { - if (socket.off) { - socket.off(type, listener); - return; - } - - socket.removeListener?.(type, listener); - }; - } - - throw new Error("WebSocket object does not support event listeners"); -} - -function socketMessageToString(args: unknown[]): string | undefined { - const data = extractMessageData(args); - - if (typeof data === "string") { - return data; - } - - if (data instanceof ArrayBuffer || ArrayBuffer.isView(data)) { - return new TextDecoder().decode(data); - } - - if (Array.isArray(data) && data.every(ArrayBuffer.isView)) { - return decodeArrayBufferViews(data); - } - - return undefined; -} - -function extractMessageData(args: unknown[]): unknown { - const [first] = args; - - if (isMessageEventLike(first)) { - return first.data; - } - - return first; -} - -function isMessageEventLike(value: unknown): value is { data: unknown } { - return typeof value === "object" && value !== null && "data" in value; -} - -function decodeArrayBufferViews(views: ArrayBufferView[]): string { - const totalLength = views.reduce((sum, view) => sum + view.byteLength, 0); - const combined = new Uint8Array(totalLength); - let offset = 0; - - for (const view of views) { - combined.set( - new Uint8Array(view.buffer, view.byteOffset, view.byteLength), - offset, - ); - offset += view.byteLength; - } - - return new TextDecoder().decode(combined); -} diff --git a/src/ws-stream.test.ts b/src/ws-stream.test.ts new file mode 100644 index 0000000..44a6b7d --- /dev/null +++ b/src/ws-stream.test.ts @@ -0,0 +1,411 @@ +import { describe, expect, it } from "vitest"; +import { WebSocket } from "ws"; + +import { ClientSideConnection, PROTOCOL_VERSION } from "./acp.js"; +import { createWebSocketStream } from "./ws-stream.js"; +import { TestAgent } from "./test-support/test-agent.js"; +import { startTestServer } from "./test-support/test-http-server.js"; + +import type { + AgentSideConnection, + Client, + RequestPermissionRequest, + RequestPermissionResponse, + SessionNotification, +} from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; +import type { Stream } from "./stream.js"; +import type { WebSocketConstructor } from "./ws-stream.js"; + +const nodeWebSocket = WebSocket as unknown as WebSocketConstructor; + +const initializeRequest = { + jsonrpc: "2.0", + id: 0, + method: "initialize", + params: { + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }, +} satisfies AnyMessage; + +const initializeResponse = { + jsonrpc: "2.0", + id: 0, + result: { + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { + loadSession: false, + }, + }, +} satisfies AnyMessage; + +describe("createWebSocketStream", () => { + it("uses the custom WebSocket constructor and queues writes until the socket opens", async () => { + const instances: FakeWebSocket[] = []; + const stream = createWebSocketStream("ws://agent.example/acp", { + WebSocket: createFakeWebSocketConstructor(instances), + protocols: ["acp"], + headers: { Authorization: "Bearer token" }, + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + const socket = fakeSocketAt(instances, 0); + expect(socket.url).toBe("ws://agent.example/acp"); + expect(socket.protocols).toEqual(["acp"]); + expect(socket.options).toEqual({ + headers: { Authorization: "Bearer token" }, + }); + + const write = writer.write(initializeRequest); + await Promise.resolve(); + expect(socket.sent).toEqual([]); + + socket.open(); + await write; + expect(socket.sent).toEqual([JSON.stringify(initializeRequest)]); + + socket.receive(JSON.stringify(initializeResponse)); + expect(await readMessage(reader)).toEqual(initializeResponse); + } finally { + reader.releaseLock(); + await writer.close().catch(() => undefined); + writer.releaseLock(); + } + }); + + it("ignores binary, malformed JSON, and non-JSON-RPC messages", async () => { + const instances: FakeWebSocket[] = []; + const stream = createWebSocketStream("ws://agent.example/acp", { + WebSocket: createFakeWebSocketConstructor(instances), + }); + const reader = stream.readable.getReader(); + + try { + const socket = fakeSocketAt(instances, 0); + socket.open(); + socket.receive(new Uint8Array([1, 2, 3]), true); + socket.receive("not json"); + socket.receive(JSON.stringify({ hello: "world" })); + socket.receive(JSON.stringify(initializeResponse)); + + expect(await readMessage(reader)).toEqual(initializeResponse); + } finally { + reader.releaseLock(); + await closeStream(stream); + } + }); + + it("closes the readable stream when the socket closes", async () => { + const instances: FakeWebSocket[] = []; + const stream = createWebSocketStream("ws://agent.example/acp", { + WebSocket: createFakeWebSocketConstructor(instances), + }); + const reader = stream.readable.getReader(); + + try { + const socket = fakeSocketAt(instances, 0); + socket.open(); + socket.close(); + + expect(await reader.read()).toEqual({ done: true, value: undefined }); + } finally { + reader.releaseLock(); + } + }); + + it("runs initialize, newSession, and prompt through ClientSideConnection", async () => { + const updates: SessionNotification[] = []; + const server = await startTestServer( + (conn: AgentSideConnection) => new TestAgent(conn, { chunkCount: 2 }), + ); + const stream = createWebSocketStream(server.wsUrl, { + WebSocket: nodeWebSocket, + }); + const conn = new ClientSideConnection( + () => createTestClient({ updates }), + stream, + ); + + try { + expect( + await conn.initialize({ + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }), + ).toMatchObject({ + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { loadSession: false }, + }); + + const session = await conn.newSession({ cwd: "/tmp", mcpServers: [] }); + expect(session.sessionId).toMatch(/^[0-9a-f-]{36}$/); + + await expect( + conn.prompt({ + sessionId: session.sessionId, + prompt: [{ type: "text", text: "Hello" }], + }), + ).resolves.toEqual({ stopReason: "end_turn" }); + expect(updates).toHaveLength(2); + expect(updates).toMatchObject([ + { + sessionId: session.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "chunk-1" }, + }, + }, + { + sessionId: session.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { text: "chunk-2" }, + }, + }, + ]); + } finally { + await closeStream(stream); + await server.close(); + } + }); + + it("round-trips permission requests through ClientSideConnection", async () => { + const updates: SessionNotification[] = []; + const permissionRequests: RequestPermissionRequest[] = []; + const server = await startTestServer( + (conn: AgentSideConnection) => + new TestAgent(conn, { enablePermission: true }), + ); + const stream = createWebSocketStream(server.wsUrl, { + WebSocket: nodeWebSocket, + }); + const conn = new ClientSideConnection( + () => createTestClient({ updates, permissionRequests }), + stream, + ); + + try { + await conn.initialize({ + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }); + const session = await conn.newSession({ cwd: "/tmp", mcpServers: [] }); + + await expect( + conn.prompt({ + sessionId: session.sessionId, + prompt: [{ type: "text", text: "Hello" }], + }), + ).resolves.toEqual({ stopReason: "end_turn" }); + + expect(permissionRequests).toHaveLength(1); + expect(permissionRequests[0]).toMatchObject({ + sessionId: session.sessionId, + toolCall: { + toolCallId: "permission-tool", + title: "Permission tool", + }, + }); + expect(updates).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + sessionId: session.sessionId, + update: expect.objectContaining({ + sessionUpdate: "agent_message_chunk", + content: expect.objectContaining({ + text: "permission-selected-allow", + }), + }), + }), + ]), + ); + } finally { + await closeStream(stream); + await server.close(); + } + }); + + it("keeps multiple sessions isolated through the SDK client abstraction", async () => { + const updates: SessionNotification[] = []; + const server = await startTestServer(); + const stream = createWebSocketStream(server.wsUrl, { + WebSocket: nodeWebSocket, + }); + const conn = new ClientSideConnection( + () => createTestClient({ updates }), + stream, + ); + + try { + await conn.initialize({ + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }); + const firstSession = await conn.newSession({ + cwd: "/tmp", + mcpServers: [], + }); + const secondSession = await conn.newSession({ + cwd: "/tmp/other", + mcpServers: [], + }); + + await Promise.all([ + conn.prompt({ + sessionId: firstSession.sessionId, + prompt: [{ type: "text", text: "First" }], + }), + conn.prompt({ + sessionId: secondSession.sessionId, + prompt: [{ type: "text", text: "Second" }], + }), + ]); + + expect(updates).toEqual( + expect.arrayContaining([ + expect.objectContaining({ sessionId: firstSession.sessionId }), + expect.objectContaining({ sessionId: secondSession.sessionId }), + ]), + ); + expect( + updates.filter((update) => update.sessionId === firstSession.sessionId), + ).toHaveLength(1); + expect( + updates.filter( + (update) => update.sessionId === secondSession.sessionId, + ), + ).toHaveLength(1); + } finally { + await closeStream(stream); + await server.close(); + } + }); +}); + +interface TestClientState { + readonly updates: SessionNotification[]; + readonly permissionRequests?: RequestPermissionRequest[]; +} + +function createTestClient(state: TestClientState): Client { + return { + requestPermission: (params): Promise => { + state.permissionRequests?.push(params); + return Promise.resolve({ + outcome: { + outcome: "selected", + optionId: "allow", + }, + }); + }, + sessionUpdate: (params): Promise => { + state.updates.push(params); + return Promise.resolve(); + }, + }; +} + +async function closeStream(stream: Stream): Promise { + await stream.writable.close().catch(() => undefined); +} + +async function readMessage( + reader: ReadableStreamDefaultReader, +): Promise { + const result = await reader.read(); + if (result.done) { + throw new Error("Expected a message"); + } + + return result.value; +} + +function createFakeWebSocketConstructor( + instances: FakeWebSocket[], +): WebSocketConstructor { + return class extends FakeWebSocket { + constructor( + url: string, + protocols?: string | string[], + options?: { headers?: Record }, + ) { + super(url, protocols, options); + instances.push(this); + } + }; +} + +function fakeSocketAt( + instances: readonly FakeWebSocket[], + index: number, +): FakeWebSocket { + const socket = instances[index]; + + if (!socket) { + throw new Error(`Expected fake WebSocket at index ${index}`); + } + + return socket; +} + +class FakeWebSocket { + readonly sent: string[] = []; + readonly listeners = new Map void>>(); + readyState = 0; + + constructor( + readonly url: string, + readonly protocols?: string | string[], + readonly options?: { headers?: Record }, + ) {} + + send(data: string): void { + if (this.readyState !== 1) { + throw new Error("Fake WebSocket is not open"); + } + + this.sent.push(data); + } + + close(): void { + if (this.readyState === 3) { + return; + } + + this.readyState = 3; + this.emit("close", {}); + } + + addEventListener(type: string, listener: (event: unknown) => void): void { + let listeners = this.listeners.get(type); + + if (!listeners) { + listeners = new Set(); + this.listeners.set(type, listeners); + } + + listeners.add(listener); + } + + removeEventListener(type: string, listener: (event: unknown) => void): void { + this.listeners.get(type)?.delete(listener); + } + + open(): void { + this.readyState = 1; + this.emit("open", {}); + } + + receive(data: unknown, isBinary = false): void { + this.emit("message", { data, isBinary }); + } + + private emit(type: string, event: unknown): void { + for (const listener of this.listeners.get(type) ?? []) { + listener(event); + } + } +} diff --git a/src/ws-stream.ts b/src/ws-stream.ts new file mode 100644 index 0000000..9d65012 --- /dev/null +++ b/src/ws-stream.ts @@ -0,0 +1,239 @@ +import { isJsonRpcMessage } from "./jsonrpc.js"; +import { onWebSocket, webSocketMessageToString } from "./ws-utils.js"; +import type { WebSocketLike } from "./ws-utils.js"; +import type { AnyMessage } from "./jsonrpc.js"; +import type { Stream } from "./stream.js"; + +export interface WebSocketStreamOptions { + /** WebSocket subprotocols. */ + readonly protocols?: string[]; + /** + * Custom headers for runtimes/constructors that support them, such as Node.js + * `ws`. Browsers ignore custom headers because the browser WebSocket API does + * not expose a headers option. + */ + readonly headers?: Record; + /** Custom WebSocket constructor, for example `ws.WebSocket` in Node.js. */ + readonly WebSocket?: WebSocketConstructor; +} + +export interface WebSocketConstructor { + new ( + url: string, + protocols?: string | string[], + options?: { headers?: Record }, + ): WebSocketLike; +} + +const SOCKET_OPEN = 1; + +/** + * Creates an ACP Stream that speaks JSON-RPC over WebSocket text frames. + * + * Browser WebSocket constructors do not support custom headers. The `headers` + * option is passed as a best-effort third constructor argument for runtimes such + * as Node.js `ws` that accept it. + */ +export function createWebSocketStream( + serverUrl: string, + options: WebSocketStreamOptions = {}, +): Stream { + return new WebSocketStreamTransport(serverUrl, options).stream; +} + +class WebSocketStreamTransport { + readonly stream: Stream; + + private readonly socket: WebSocketLike; + private readableController: + | ReadableStreamDefaultController + | undefined; + private isClosed = false; + private openPromise: Promise | undefined; + private resolveOpen: (() => void) | undefined; + private rejectOpen: ((error: unknown) => void) | undefined; + private readonly detachListeners: Array<() => void> = []; + + constructor(serverUrl: string, options: WebSocketStreamOptions) { + const WebSocketCtor = resolveWebSocket(options.WebSocket); + this.socket = new WebSocketCtor(serverUrl, options.protocols, { + headers: options.headers, + }); + + this.openPromise = new Promise((resolve, reject) => { + this.resolveOpen = resolve; + this.rejectOpen = reject; + }); + + this.detachListeners.push( + onWebSocket(this.socket, "open", () => { + this.resolveOpen?.(); + this.resolveOpen = undefined; + this.rejectOpen = undefined; + this.openPromise = undefined; + }), + ); + + this.detachListeners.push( + onWebSocket(this.socket, "message", (...args) => { + this.handleSocketMessage(args); + }), + ); + + this.detachListeners.push( + onWebSocket(this.socket, "close", () => { + this.closeReadable(); + }), + ); + + this.detachListeners.push( + onWebSocket(this.socket, "error", (error) => { + this.errorReadable(error); + }), + ); + + this.stream = { + readable: new ReadableStream({ + start: (controller) => { + this.readableController = controller; + }, + cancel: () => { + this.close(); + }, + }), + writable: new WritableStream({ + write: async (message) => { + await this.sendMessage(message); + }, + close: () => { + this.close(); + }, + abort: () => { + this.close(); + }, + }), + }; + } + + private async sendMessage(message: AnyMessage): Promise { + if (this.isClosed) { + throw new Error("ACP WebSocket stream is closed"); + } + + await this.waitForOpen(); + + if (this.isClosed) { + throw new Error("ACP WebSocket stream is closed"); + } + + this.socket.send(JSON.stringify(message)); + } + + private async waitForOpen(): Promise { + if ( + this.socket.readyState === undefined || + this.socket.readyState === SOCKET_OPEN + ) { + return; + } + + await this.openPromise; + } + + private handleSocketMessage(args: unknown[]): void { + if (this.isClosed) { + return; + } + + const text = webSocketMessageToString(args); + if (text === undefined) { + return; + } + + let value: unknown; + try { + value = JSON.parse(text); + } catch (error) { + console.warn("Ignoring malformed ACP WebSocket JSON message:", error); + return; + } + + if (!isJsonRpcMessage(value)) { + console.warn("Ignoring non-JSON-RPC ACP WebSocket message:", value); + return; + } + + this.readableController?.enqueue(value); + } + + private close(): void { + this.closeSocket(); + this.closeReadable(); + } + + private closeSocket(): void { + try { + this.socket.close(); + } catch (error) { + console.warn("Failed to close ACP WebSocket:", error); + } + } + + private closeReadable(): void { + if (this.isClosed) { + return; + } + + this.isClosed = true; + + for (const detach of this.detachListeners.splice(0)) { + detach(); + } + + this.rejectOpen?.(new Error("ACP WebSocket stream closed before open")); + this.rejectOpen = undefined; + this.resolveOpen = undefined; + this.openPromise = undefined; + + try { + this.readableController?.close(); + } catch { + // Stream may already be closed/cancelled. + } + } + + private errorReadable(error: unknown): void { + if (this.isClosed) { + return; + } + + this.isClosed = true; + + for (const detach of this.detachListeners.splice(0)) { + detach(); + } + + this.rejectOpen?.(error); + this.rejectOpen = undefined; + this.resolveOpen = undefined; + this.openPromise = undefined; + + this.readableController?.error(error); + } +} + +function resolveWebSocket( + WebSocketCtor: WebSocketConstructor | undefined, +): WebSocketConstructor { + if (WebSocketCtor) { + return WebSocketCtor; + } + + if (typeof globalThis.WebSocket === "function") { + return globalThis.WebSocket as unknown as WebSocketConstructor; + } + + throw new Error( + "createWebSocketStream requires globalThis.WebSocket or options.WebSocket", + ); +} diff --git a/src/ws-utils.ts b/src/ws-utils.ts new file mode 100644 index 0000000..31256a6 --- /dev/null +++ b/src/ws-utils.ts @@ -0,0 +1,108 @@ +export interface WebSocketLike { + readonly readyState?: number; + send(data: string): void; + close(code?: number, reason?: string): void; + addEventListener?(type: string, listener: (event: unknown) => void): void; + removeEventListener?(type: string, listener: (event: unknown) => void): void; + on?(type: string, listener: (...args: unknown[]) => void): unknown; + off?(type: string, listener: (...args: unknown[]) => void): unknown; + removeListener?( + type: string, + listener: (...args: unknown[]) => void, + ): unknown; +} + +export function onWebSocket( + socket: WebSocketLike, + type: string, + listener: (...args: unknown[]) => void, +): () => void { + if (socket.addEventListener) { + const eventListener = (event: unknown): void => listener(event); + socket.addEventListener(type, eventListener); + + return () => { + socket.removeEventListener?.(type, eventListener); + }; + } + + if (socket.on) { + socket.on(type, listener); + + return () => { + if (socket.off) { + socket.off(type, listener); + return; + } + + socket.removeListener?.(type, listener); + }; + } + + throw new Error("WebSocket object does not support event listeners"); +} + +export function webSocketMessageToString(args: unknown[]): string | undefined { + if (args[1] === true || isBinaryMessageEvent(args[0])) { + return undefined; + } + + const data = extractMessageData(args); + + if (typeof data === "string") { + return data; + } + + if (data instanceof ArrayBuffer) { + return new TextDecoder().decode(data); + } + + if (ArrayBuffer.isView(data)) { + return new TextDecoder().decode(data); + } + + if (Array.isArray(data) && data.every(ArrayBuffer.isView)) { + return decodeArrayBufferViews(data); + } + + return undefined; +} + +function extractMessageData(args: unknown[]): unknown { + const [first] = args; + + if (isMessageEventLike(first)) { + return first.data; + } + + return first; +} + +function isMessageEventLike(value: unknown): value is { data: unknown } { + return typeof value === "object" && value !== null && "data" in value; +} + +function isBinaryMessageEvent(value: unknown): boolean { + return ( + typeof value === "object" && + value !== null && + "isBinary" in value && + value.isBinary === true + ); +} + +function decodeArrayBufferViews(views: ArrayBufferView[]): string { + const totalLength = views.reduce((sum, view) => sum + view.byteLength, 0); + const combined = new Uint8Array(totalLength); + let offset = 0; + + for (const view of views) { + combined.set( + new Uint8Array(view.buffer, view.byteOffset, view.byteLength), + offset, + ); + offset += view.byteLength; + } + + return new TextDecoder().decode(combined); +} From a311d3490f781b377fe680affe6278c7dd278a83 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 10:32:57 +1000 Subject: [PATCH 09/19] Add docstrings --- src/http-stream.ts | 8 +++++--- src/server.ts | 12 ++++++++++++ src/ws-server.ts | 1 + src/ws-stream.ts | 19 ++++++++++--------- src/ws-utils.ts | 1 + 5 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/http-stream.ts b/src/http-stream.ts index f473735..ab432c6 100644 --- a/src/http-stream.ts +++ b/src/http-stream.ts @@ -14,15 +14,17 @@ import type { AnyMessage } from "./jsonrpc.js"; import type { Stream } from "./stream.js"; export interface HttpStreamOptions { + /** Fetch implementation to use. Defaults to `globalThis.fetch`. */ readonly fetch?: typeof globalThis.fetch; + /** Headers to include on every HTTP/SSE request. */ readonly headers?: Record; } /** - * Creates an ACP Stream that speaks the Streamable HTTP transport. + * Creates an ACP Stream over Streamable HTTP. * - * The transport uses HTTP POST for client-to-agent messages and SSE GET streams for agent-to-client messages. - * Cookie management is intentionally not built in; pass a cookie-aware fetch implementation when needed. + * Uses POST for client messages and SSE GET streams for server messages. + * Pass a custom `fetch` for cookies, auth, proxies, or non-browser runtimes. */ export function createHttpStream( serverUrl: string, diff --git a/src/server.ts b/src/server.ts index 85353fe..daf3602 100644 --- a/src/server.ts +++ b/src/server.ts @@ -26,10 +26,19 @@ import type { import type { Agent, AgentSideConnection } from "./acp.js"; import type { AnyMessage, AnyRequest, AnyResponse } from "./jsonrpc.js"; +/** Options for creating an ACP server transport. */ export interface AcpServerOptions { + /** Creates the agent implementation for each accepted ACP connection. */ createAgent: (conn: AgentSideConnection) => Agent; } +/** + * ACP server transport for Streamable HTTP and WebSocket connections. + * + * Route HTTP requests to {@link handleRequest}. For WebSocket upgrades, let your + * framework perform the upgrade and pass the accepted socket to + * {@link handleWebSocket}. + */ export class AcpServer { private readonly createAgent: (conn: AgentSideConnection) => Agent; private readonly registry = new ConnectionRegistry(); @@ -38,6 +47,7 @@ export class AcpServer { this.createAgent = options.createAgent; } + /** Handles one Streamable HTTP ACP request. */ async handleRequest(req: Request): Promise { if (req.method === "POST") { return await this.handlePost(req); @@ -54,6 +64,7 @@ export class AcpServer { return textResponse("Method Not Allowed", 405); } + /** Handles one accepted ACP WebSocket connection. */ handleWebSocket(socket: WebSocketServerSocket): void { handleWebSocketConnection(socket, { registry: this.registry, @@ -61,6 +72,7 @@ export class AcpServer { }); } + /** Closes all active ACP connections owned by this server. */ async close(): Promise { this.registry.closeAll(); } diff --git a/src/ws-server.ts b/src/ws-server.ts index fe7cc92..1490f28 100644 --- a/src/ws-server.ts +++ b/src/ws-server.ts @@ -18,6 +18,7 @@ import type { import type { AnyMessage, AnyRequest } from "./jsonrpc.js"; import type { WebSocketLike } from "./ws-utils.js"; +/** WebSocket shape accepted by `AcpServer.handleWebSocket`. */ export type WebSocketServerSocket = WebSocketLike; type ForwardResult = diff --git a/src/ws-stream.ts b/src/ws-stream.ts index 9d65012..0ebea76 100644 --- a/src/ws-stream.ts +++ b/src/ws-stream.ts @@ -5,18 +5,18 @@ import type { AnyMessage } from "./jsonrpc.js"; import type { Stream } from "./stream.js"; export interface WebSocketStreamOptions { - /** WebSocket subprotocols. */ + /** WebSocket subprotocols to request. */ readonly protocols?: string[]; /** - * Custom headers for runtimes/constructors that support them, such as Node.js - * `ws`. Browsers ignore custom headers because the browser WebSocket API does - * not expose a headers option. + * Headers for WebSocket constructors that support them, such as Node `ws`. + * Browser WebSocket constructors ignore custom headers. */ readonly headers?: Record; - /** Custom WebSocket constructor, for example `ws.WebSocket` in Node.js. */ + /** WebSocket constructor to use. Defaults to `globalThis.WebSocket`. */ readonly WebSocket?: WebSocketConstructor; } +/** Constructor shape used by `createWebSocketStream`. */ export interface WebSocketConstructor { new ( url: string, @@ -25,14 +25,15 @@ export interface WebSocketConstructor { ): WebSocketLike; } +export type { WebSocketLike }; + const SOCKET_OPEN = 1; /** - * Creates an ACP Stream that speaks JSON-RPC over WebSocket text frames. + * Creates an ACP Stream over WebSocket. * - * Browser WebSocket constructors do not support custom headers. The `headers` - * option is passed as a best-effort third constructor argument for runtimes such - * as Node.js `ws` that accept it. + * Sends and receives ACP JSON-RPC messages as WebSocket text frames. In Node, + * pass a WebSocket constructor such as `ws.WebSocket` via `options.WebSocket`. */ export function createWebSocketStream( serverUrl: string, diff --git a/src/ws-utils.ts b/src/ws-utils.ts index 31256a6..947586e 100644 --- a/src/ws-utils.ts +++ b/src/ws-utils.ts @@ -1,3 +1,4 @@ +/** Minimal browser/Node-compatible WebSocket shape used by ACP transports. */ export interface WebSocketLike { readonly readyState?: number; send(data: string): void; From 5203ef1c0272da6398d8b78d71d4e0a7067c5ce6 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 19:20:21 +1000 Subject: [PATCH 10/19] Add examples and update imports --- package.json | 11 ++++ src/examples/README.md | 25 ++++++++ src/examples/http-client.ts | 69 ++++++++++++++++++++++ src/examples/http-server.ts | 115 ++++++++++++++++++++++++++++++++++++ src/examples/ws-client.ts | 72 ++++++++++++++++++++++ 5 files changed, 292 insertions(+) create mode 100644 src/examples/http-client.ts create mode 100644 src/examples/http-server.ts create mode 100644 src/examples/ws-client.ts diff --git a/package.json b/package.json index e56ebcd..474d83a 100644 --- a/package.json +++ b/package.json @@ -27,22 +27,27 @@ "exports": { ".": { "types": "./dist/acp.d.ts", + "import": "./dist/acp.js", "default": "./dist/acp.js" }, "./http-client": { "types": "./dist/http-stream.d.ts", + "import": "./dist/http-stream.js", "default": "./dist/http-stream.js" }, "./ws-client": { "types": "./dist/ws-stream.d.ts", + "import": "./dist/ws-stream.js", "default": "./dist/ws-stream.js" }, "./server": { "types": "./dist/server.d.ts", + "import": "./dist/server.js", "default": "./dist/server.js" }, "./node": { "types": "./dist/node-adapter.d.ts", + "import": "./dist/node-adapter.js", "default": "./dist/node-adapter.js" }, "./schema/schema.json": "./schema/schema.json" @@ -67,8 +72,14 @@ "docs:ts:verify": "cd src && typedoc --emit none && echo 'TypeDoc verification passed'" }, "peerDependencies": { + "ws": ">=8.0.0", "zod": "^3.25.0 || ^4.0.0" }, + "peerDependenciesMeta": { + "ws": { + "optional": true + } + }, "devDependencies": { "@eslint/js": "^10.0.1", "@hey-api/openapi-ts": "^0.97.0", diff --git a/src/examples/README.md b/src/examples/README.md index 2fe4f8a..d1c19a3 100644 --- a/src/examples/README.md +++ b/src/examples/README.md @@ -4,6 +4,9 @@ This directory contains examples using the [ACP](https://agentclientprotocol.com - [`agent.ts`](./agent.ts) - A minimal agent implementation that simulates LLM interaction - [`client.ts`](./client.ts) - A minimal client implementation that spawns the [`agent.ts`](./agent.ts) as a subprocess +- [`http-server.ts`](./http-server.ts) - A minimal ACP Streamable HTTP server with WebSocket upgrade support +- [`http-client.ts`](./http-client.ts) - A minimal client using `createHttpStream` +- [`ws-client.ts`](./ws-client.ts) - A minimal client using `createWebSocketStream` ## Running the Agent @@ -75,3 +78,25 @@ npx tsx src/examples/client.ts ``` This client will spawn the example agent as a subprocess, send a message, and print the content it receives from it. + +## Running the HTTP and WebSocket Examples + +Start the Streamable HTTP server with WebSocket upgrade support: + +```bash +npx tsx src/examples/http-server.ts +``` + +In another terminal, run the HTTP client: + +```bash +npx tsx src/examples/http-client.ts +``` + +Or run the WebSocket client: + +```bash +npx tsx src/examples/ws-client.ts +``` + +The HTTP example sends a bearer token through custom request headers. The WebSocket example passes the Node `ws` constructor so custom headers can be sent during the WebSocket handshake. Browser WebSocket clients can use `createWebSocketStream` too, but browsers do not allow custom WebSocket headers. diff --git a/src/examples/http-client.ts b/src/examples/http-client.ts new file mode 100644 index 0000000..514a276 --- /dev/null +++ b/src/examples/http-client.ts @@ -0,0 +1,69 @@ +#!/usr/bin/env node + +import * as acp from "@agentclientprotocol/sdk"; +import { createHttpStream } from "@agentclientprotocol/sdk/http-client"; + +class HttpExampleClient implements acp.Client { + async requestPermission( + params: acp.RequestPermissionRequest, + ): Promise { + return { + outcome: { + outcome: "selected", + optionId: params.options[0]?.optionId ?? "allow", + }, + }; + } + + async sessionUpdate(params: acp.SessionNotification): Promise { + const update = params.update; + + if (update.sessionUpdate === "agent_message_chunk") { + process.stdout.write( + update.content.type === "text" ? update.content.text : "", + ); + return; + } + + console.log(`[${update.sessionUpdate}]`); + } +} + +const serverUrl = process.env.ACP_HTTP_URL ?? "http://127.0.0.1:7331/acp"; +const stream = createHttpStream(serverUrl, { + headers: { + Authorization: "Bearer example-token", + }, + // To use cookies, pass a cookie-aware fetch implementation here instead of relying on a built-in cookie jar. + // fetch: cookieAwareFetch, +}); +const connection = new acp.ClientSideConnection( + (_agent) => new HttpExampleClient(), + stream, +); + +try { + await connection.initialize({ + protocolVersion: acp.PROTOCOL_VERSION, + clientCapabilities: {}, + }); + + const session = await connection.newSession({ + cwd: process.cwd(), + mcpServers: [], + }); + + const result = await connection.prompt({ + sessionId: session.sessionId, + prompt: [ + { + type: "text", + text: "Hello over Streamable HTTP", + }, + ], + }); + + console.log(`\nDone: ${result.stopReason}`); +} finally { + await stream.writable.close(); +} diff --git a/src/examples/http-server.ts b/src/examples/http-server.ts new file mode 100644 index 0000000..ad72c92 --- /dev/null +++ b/src/examples/http-server.ts @@ -0,0 +1,115 @@ +#!/usr/bin/env node + +import { createServer } from "node:http"; + +import { WebSocketServer } from "ws"; + +import * as acp from "@agentclientprotocol/sdk"; +import { createNodeHttpHandler } from "@agentclientprotocol/sdk/node"; +import { AcpServer } from "@agentclientprotocol/sdk/server"; + +class HttpExampleAgent implements acp.Agent { + private readonly connection: acp.AgentSideConnection; + private readonly sessions = new Set(); + + constructor(connection: acp.AgentSideConnection) { + this.connection = connection; + } + + async initialize( + _params: acp.InitializeRequest, + ): Promise { + return { + protocolVersion: acp.PROTOCOL_VERSION, + agentCapabilities: { + loadSession: false, + }, + }; + } + + async newSession( + _params: acp.NewSessionRequest, + ): Promise { + const sessionId = crypto.randomUUID(); + this.sessions.add(sessionId); + return { sessionId }; + } + + async authenticate( + _params: acp.AuthenticateRequest, + ): Promise { + return {}; + } + + async prompt(params: acp.PromptRequest): Promise { + if (!this.sessions.has(params.sessionId)) { + throw new Error(`Session ${params.sessionId} not found`); + } + + await this.connection.sessionUpdate({ + sessionId: params.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + type: "text", + text: "Hello from the ACP HTTP/WebSocket example server.", + }, + }, + }); + + return { stopReason: "end_turn" }; + } + + async cancel(_params: acp.CancelNotification): Promise {} +} + +const acpServer = new AcpServer({ + createAgent: (connection) => new HttpExampleAgent(connection), +}); +const acpHttpHandler = createNodeHttpHandler(acpServer); +const webSocketServer = new WebSocketServer({ noServer: true }); +const port = Number.parseInt(process.env.PORT ?? "7331", 10); + +const httpServer = createServer((req, res) => { + if (!isAcpPath(req.url)) { + res.writeHead(404, { "Content-Type": "text/plain" }); + res.end("Not Found"); + return; + } + + // Put authentication or tenant-selection middleware here before routing to AcpServer. + // For example, validate `req.headers.authorization` and reject unauthorized requests. + if (!isAuthorized(req.headers.authorization)) { + res.writeHead(401, { "Content-Type": "text/plain" }); + res.end("Unauthorized"); + return; + } + + acpHttpHandler(req, res); +}); + +httpServer.on("upgrade", (req, socket, head) => { + if (!isAcpPath(req.url) || !isAuthorized(req.headers.authorization)) { + socket.destroy(); + return; + } + + webSocketServer.handleUpgrade(req, socket, head, (ws) => { + acpServer.handleWebSocket(ws); + }); +}); + +httpServer.listen(port, () => { + console.log(`ACP HTTP endpoint listening at http://127.0.0.1:${port}/acp`); + console.log(`ACP WebSocket endpoint listening at ws://127.0.0.1:${port}/acp`); +}); + +function isAcpPath(url: string | undefined): boolean { + return new URL(url ?? "/", "http://127.0.0.1").pathname === "/acp"; +} + +function isAuthorized(authorization: string | undefined): boolean { + return ( + authorization === undefined || authorization === "Bearer example-token" + ); +} diff --git a/src/examples/ws-client.ts b/src/examples/ws-client.ts new file mode 100644 index 0000000..ddddd0e --- /dev/null +++ b/src/examples/ws-client.ts @@ -0,0 +1,72 @@ +#!/usr/bin/env node + +import { WebSocket } from "ws"; + +import * as acp from "@agentclientprotocol/sdk"; +import { createWebSocketStream } from "@agentclientprotocol/sdk/ws-client"; +import type { WebSocketConstructor } from "@agentclientprotocol/sdk/ws-client"; + +class WebSocketExampleClient implements acp.Client { + async requestPermission( + params: acp.RequestPermissionRequest, + ): Promise { + return { + outcome: { + outcome: "selected", + optionId: params.options[0]?.optionId ?? "allow", + }, + }; + } + + async sessionUpdate(params: acp.SessionNotification): Promise { + const update = params.update; + + if (update.sessionUpdate === "agent_message_chunk") { + process.stdout.write( + update.content.type === "text" ? update.content.text : "", + ); + return; + } + + console.log(`[${update.sessionUpdate}]`); + } +} + +const serverUrl = process.env.ACP_WS_URL ?? "ws://127.0.0.1:7331/acp"; +const stream = createWebSocketStream(serverUrl, { + WebSocket: WebSocket satisfies WebSocketConstructor, + // Custom headers work with Node's `ws` constructor. Browser WebSocket does not support custom headers. + headers: { + Authorization: "Bearer example-token", + }, +}); +const connection = new acp.ClientSideConnection( + (_agent) => new WebSocketExampleClient(), + stream, +); + +try { + await connection.initialize({ + protocolVersion: acp.PROTOCOL_VERSION, + clientCapabilities: {}, + }); + + const session = await connection.newSession({ + cwd: process.cwd(), + mcpServers: [], + }); + + const result = await connection.prompt({ + sessionId: session.sessionId, + prompt: [ + { + type: "text", + text: "Hello over WebSocket", + }, + ], + }); + + console.log(`\nDone: ${result.stopReason}`); +} finally { + await stream.writable.close(); +} From 1fae58f36f2604b61f314eb15e9fae0eea3b11b5 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 20:04:46 +1000 Subject: [PATCH 11/19] Add RFD-compliant websocket upgrade handling --- src/examples/http-server.ts | 13 +++-- src/node-adapter.ts | 52 ++++++++++++++++++++ src/server.ts | 46 ++++++++++++++---- src/test-support/test-http-server.ts | 14 +++--- src/ws-server.ts | 33 +++++++++---- src/ws-stream.test.ts | 71 ++++++++++++++++++++++++++++ 6 files changed, 200 insertions(+), 29 deletions(-) diff --git a/src/examples/http-server.ts b/src/examples/http-server.ts index ad72c92..6fb85b0 100644 --- a/src/examples/http-server.ts +++ b/src/examples/http-server.ts @@ -5,7 +5,10 @@ import { createServer } from "node:http"; import { WebSocketServer } from "ws"; import * as acp from "@agentclientprotocol/sdk"; -import { createNodeHttpHandler } from "@agentclientprotocol/sdk/node"; +import { + createNodeHttpHandler, + createNodeWebSocketUpgradeHandler, +} from "@agentclientprotocol/sdk/node"; import { AcpServer } from "@agentclientprotocol/sdk/server"; class HttpExampleAgent implements acp.Agent { @@ -68,6 +71,10 @@ const acpServer = new AcpServer({ }); const acpHttpHandler = createNodeHttpHandler(acpServer); const webSocketServer = new WebSocketServer({ noServer: true }); +const acpWebSocketUpgradeHandler = createNodeWebSocketUpgradeHandler( + acpServer, + webSocketServer, +); const port = Number.parseInt(process.env.PORT ?? "7331", 10); const httpServer = createServer((req, res) => { @@ -94,9 +101,7 @@ httpServer.on("upgrade", (req, socket, head) => { return; } - webSocketServer.handleUpgrade(req, socket, head, (ws) => { - acpServer.handleWebSocket(ws); - }); + acpWebSocketUpgradeHandler(req, socket, head); }); httpServer.listen(port, () => { diff --git a/src/node-adapter.ts b/src/node-adapter.ts index a92d7c0..f4a6801 100644 --- a/src/node-adapter.ts +++ b/src/node-adapter.ts @@ -1,5 +1,8 @@ +import { HEADER_CONNECTION_ID } from "./protocol.js"; import type { IncomingMessage, ServerResponse } from "node:http"; +import type { Duplex } from "node:stream"; import type { AcpServer } from "./server.js"; +import type { WebSocketServer } from "ws"; export function createNodeHttpHandler( server: AcpServer, @@ -9,6 +12,55 @@ export function createNodeHttpHandler( }; } +export function createNodeWebSocketUpgradeHandler( + server: AcpServer, + webSocketServer: WebSocketServer, +): (req: IncomingMessage, socket: Duplex, head: Buffer) => void { + return (req, socket, head) => { + const upgrade = server.prepareWebSocketUpgrade(); + let hasAccepted = false; + + const cleanup = (): void => { + webSocketServer.off("headers", onHeaders); + socket.off("close", onUpgradeFailed); + socket.off("error", onUpgradeFailed); + }; + + const onHeaders = (headers: string[], request: IncomingMessage): void => { + if (request !== req) { + return; + } + + headers.push(`${HEADER_CONNECTION_ID}: ${upgrade.connectionId}`); + }; + + const onUpgradeFailed = (): void => { + if (hasAccepted) { + return; + } + + cleanup(); + upgrade.reject(); + }; + + webSocketServer.on("headers", onHeaders); + socket.once("close", onUpgradeFailed); + socket.once("error", onUpgradeFailed); + + try { + webSocketServer.handleUpgrade(req, socket, head, (webSocket) => { + hasAccepted = true; + cleanup(); + upgrade.accept(webSocket); + }); + } catch (error) { + cleanup(); + upgrade.reject(); + throw error; + } + }; +} + async function handleNodeRequest( server: AcpServer, req: IncomingMessage, diff --git a/src/server.ts b/src/server.ts index daf3602..cd327a5 100644 --- a/src/server.ts +++ b/src/server.ts @@ -32,12 +32,18 @@ export interface AcpServerOptions { createAgent: (conn: AgentSideConnection) => Agent; } +export interface PreparedWebSocketUpgrade { + readonly connectionId: string; + accept(socket: WebSocketServerSocket): void; + reject(): void; +} + /** * ACP server transport for Streamable HTTP and WebSocket connections. * - * Route HTTP requests to {@link handleRequest}. For WebSocket upgrades, let your - * framework perform the upgrade and pass the accepted socket to - * {@link handleWebSocket}. + * Route HTTP requests to {@link handleRequest}. For WebSocket upgrades, use + * {@link prepareWebSocketUpgrade} so adapters can attach `Acp-Connection-Id` to + * the `101 Switching Protocols` response. */ export class AcpServer { private readonly createAgent: (conn: AgentSideConnection) => Agent; @@ -64,12 +70,34 @@ export class AcpServer { return textResponse("Method Not Allowed", 405); } - /** Handles one accepted ACP WebSocket connection. */ - handleWebSocket(socket: WebSocketServerSocket): void { - handleWebSocketConnection(socket, { - registry: this.registry, - createAgent: this.createAgent, - }); + /** Creates a WebSocket connection before accepting the HTTP upgrade. */ + prepareWebSocketUpgrade(): PreparedWebSocketUpgrade { + const connection = this.registry.createConnection(this.createAgent); + let isSettled = false; + + return { + connectionId: connection.connectionId, + accept: (socket) => { + if (isSettled) { + throw new Error("ACP WebSocket upgrade has already been settled"); + } + + isSettled = true; + handleWebSocketConnection(socket, { + registry: this.registry, + createAgent: this.createAgent, + connection, + }); + }, + reject: () => { + if (isSettled) { + return; + } + + isSettled = true; + this.registry.remove(connection.connectionId); + }, + }; } /** Closes all active ACP connections owned by this server. */ diff --git a/src/test-support/test-http-server.ts b/src/test-support/test-http-server.ts index 96b746b..ccb3e42 100644 --- a/src/test-support/test-http-server.ts +++ b/src/test-support/test-http-server.ts @@ -2,7 +2,10 @@ import http from "node:http"; import { WebSocketServer } from "ws"; import { AcpServer } from "../server.js"; -import { createNodeHttpHandler } from "../node-adapter.js"; +import { + createNodeHttpHandler, + createNodeWebSocketUpgradeHandler, +} from "../node-adapter.js"; import { TestAgent } from "./test-agent.js"; import type { AddressInfo } from "node:net"; @@ -23,11 +26,10 @@ export async function startTestServer( const httpServer = http.createServer(createNodeHttpHandler(acpServer)); const webSocketServer = new WebSocketServer({ noServer: true }); - httpServer.on("upgrade", (req, socket, head) => { - webSocketServer.handleUpgrade(req, socket, head, (webSocket) => { - acpServer.handleWebSocket(webSocket); - }); - }); + httpServer.on( + "upgrade", + createNodeWebSocketUpgradeHandler(acpServer, webSocketServer), + ); await listen(httpServer, options.port ?? 0); diff --git a/src/ws-server.ts b/src/ws-server.ts index 1490f28..1cc0a30 100644 --- a/src/ws-server.ts +++ b/src/ws-server.ts @@ -18,7 +18,7 @@ import type { import type { AnyMessage, AnyRequest } from "./jsonrpc.js"; import type { WebSocketLike } from "./ws-utils.js"; -/** WebSocket shape accepted by `AcpServer.handleWebSocket`. */ +/** WebSocket shape accepted by prepared ACP WebSocket upgrades. */ export type WebSocketServerSocket = WebSocketLike; type ForwardResult = @@ -33,6 +33,7 @@ type ForwardResult = export interface WebSocketConnectionOptions { readonly registry: ConnectionRegistry; readonly createAgent: (conn: AgentSideConnection) => Agent; + readonly connection?: ConnectionState; } export function handleWebSocketConnection( @@ -45,6 +46,7 @@ export function handleWebSocketConnection( class WebSocketServerSession { private connection: ConnectionState | undefined; + private preparedConnection: ConnectionState | undefined; private outboundReader: ReadableStreamDefaultReader | undefined; private inboundWriteChain: Promise = Promise.resolve(); private isClosed = false; @@ -53,7 +55,9 @@ class WebSocketServerSession { constructor( private readonly socket: WebSocketLike, private readonly options: WebSocketConnectionOptions, - ) {} + ) { + this.preparedConnection = options.connection; + } start(): void { this.detachListeners.push( @@ -135,26 +139,30 @@ class WebSocketServerSession { return; } - let connection: ConnectionState | undefined; + const connection = + this.preparedConnection ?? + this.options.registry.createConnection(this.options.createAgent); + this.preparedConnection = connection; try { - connection = this.options.registry.createConnection( - this.options.createAgent, - ); - await writeInbound(connection, message); const initialResponse = await connection.recvInitial(message.id); + if (this.isClosed) { + this.options.registry.remove(connection.connectionId); + return; + } + + this.preparedConnection = undefined; this.connection = connection; connection.startRouter(); this.send(initialResponse); this.startOutboundPump(connection); } catch (error) { - if (connection) { - this.options.registry.remove(connection.connectionId); - } + this.preparedConnection = undefined; + this.options.registry.remove(connection.connectionId); this.send({ jsonrpc: "2.0", @@ -323,6 +331,11 @@ class WebSocketServerSession { this.options.registry.remove(this.connection.connectionId); this.connection = undefined; } + + if (this.preparedConnection) { + this.options.registry.remove(this.preparedConnection.connectionId); + this.preparedConnection = undefined; + } } } diff --git a/src/ws-stream.test.ts b/src/ws-stream.test.ts index 44a6b7d..01acbd6 100644 --- a/src/ws-stream.test.ts +++ b/src/ws-stream.test.ts @@ -2,10 +2,12 @@ import { describe, expect, it } from "vitest"; import { WebSocket } from "ws"; import { ClientSideConnection, PROTOCOL_VERSION } from "./acp.js"; +import { HEADER_CONNECTION_ID } from "./protocol.js"; import { createWebSocketStream } from "./ws-stream.js"; import { TestAgent } from "./test-support/test-agent.js"; import { startTestServer } from "./test-support/test-http-server.js"; +import type { IncomingMessage } from "node:http"; import type { AgentSideConnection, Client, @@ -40,7 +42,76 @@ const initializeResponse = { }, } satisfies AnyMessage; +const sessionNewRequest = { + jsonrpc: "2.0", + id: 1, + method: "session/new", + params: { + cwd: "/tmp", + mcpServers: [], + }, +} satisfies AnyMessage; + describe("createWebSocketStream", () => { + it("exposes the ACP connection ID during the WebSocket handshake", async () => { + const server = await startTestServer(); + const socket = new WebSocket(server.wsUrl); + const upgrade = new Promise((resolve, reject) => { + socket.once("upgrade", resolve); + socket.once("error", reject); + }); + + try { + const request = await upgrade; + expect(request.headers[HEADER_CONNECTION_ID.toLowerCase()]).toMatch( + /^[0-9a-f-]{36}$/, + ); + } finally { + socket.close(); + await server.close(); + } + }); + + it("closes pre-created WebSocket connections when the first frame is not initialize", async () => { + const server = await startTestServer(); + const socket = new WebSocket(server.wsUrl); + const upgrade = new Promise((resolve, reject) => { + socket.once("upgrade", resolve); + socket.once("error", reject); + }); + const close = new Promise<{ code: number; reason: string }>((resolve) => { + socket.once("close", (code: number, reason: Buffer) => { + resolve({ code, reason: reason.toString("utf8") }); + }); + }); + + try { + const request = await upgrade; + const connectionId = request.headers[HEADER_CONNECTION_ID.toLowerCase()]; + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + socket.send(JSON.stringify(sessionNewRequest)); + + await expect(close).resolves.toEqual({ + code: 1002, + reason: "First message must be initialize", + }); + + const response = await fetch(server.url, { + method: "GET", + headers: { + Accept: "text/event-stream", + [HEADER_CONNECTION_ID]: String(connectionId), + }, + }); + + expect(response.status).toBe(404); + } finally { + socket.close(); + await server.close(); + } + }); + it("uses the custom WebSocket constructor and queues writes until the socket opens", async () => { const instances: FakeWebSocket[] = []; const stream = createWebSocketStream("ws://agent.example/acp", { From a75be7671c01faf00749fe531ed574bf32adc65e Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 20:47:51 +1000 Subject: [PATCH 12/19] Only accept text frames or Node-style buffers with isBinary --- src/ws-stream.test.ts | 49 +++++++++++++++++++++++++++++++ src/ws-utils.test.ts | 67 +++++++++++++++++++++++++++++++++++++++++++ src/ws-utils.ts | 61 +++++++++++++++++++++++++-------------- 3 files changed, 155 insertions(+), 22 deletions(-) create mode 100644 src/ws-utils.test.ts diff --git a/src/ws-stream.test.ts b/src/ws-stream.test.ts index 01acbd6..cf71e26 100644 --- a/src/ws-stream.test.ts +++ b/src/ws-stream.test.ts @@ -112,6 +112,49 @@ describe("createWebSocketStream", () => { } }); + it("ignores binary WebSocket initialize frames", async () => { + const server = await startTestServer(); + const socket = new WebSocket(server.wsUrl); + const upgrade = new Promise((resolve, reject) => { + socket.once("upgrade", resolve); + socket.once("error", reject); + }); + const close = new Promise<{ code: number; reason: string }>((resolve) => { + socket.once("close", (code: number, reason: Buffer) => { + resolve({ code, reason: reason.toString("utf8") }); + }); + }); + + try { + const request = await upgrade; + const connectionId = request.headers[HEADER_CONNECTION_ID.toLowerCase()]; + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + socket.send(Buffer.from(JSON.stringify(initializeRequest)), { + binary: true, + }); + socket.send(JSON.stringify(sessionNewRequest)); + + await expect(close).resolves.toEqual({ + code: 1002, + reason: "First message must be initialize", + }); + + const response = await fetch(server.url, { + method: "GET", + headers: { + Accept: "text/event-stream", + [HEADER_CONNECTION_ID]: String(connectionId), + }, + }); + + expect(response.status).toBe(404); + } finally { + socket.close(); + await server.close(); + } + }); + it("uses the custom WebSocket constructor and queues writes until the socket opens", async () => { const instances: FakeWebSocket[] = []; const stream = createWebSocketStream("ws://agent.example/acp", { @@ -158,6 +201,12 @@ describe("createWebSocketStream", () => { const socket = fakeSocketAt(instances, 0); socket.open(); socket.receive(new Uint8Array([1, 2, 3]), true); + socket.receive( + new TextEncoder().encode(JSON.stringify(initializeResponse)), + ); + socket.receive( + new TextEncoder().encode(JSON.stringify(initializeResponse)).buffer, + ); socket.receive("not json"); socket.receive(JSON.stringify({ hello: "world" })); socket.receive(JSON.stringify(initializeResponse)); diff --git a/src/ws-utils.test.ts b/src/ws-utils.test.ts new file mode 100644 index 0000000..af40e7f --- /dev/null +++ b/src/ws-utils.test.ts @@ -0,0 +1,67 @@ +import { describe, expect, it } from "vitest"; + +import { onWebSocket, webSocketMessageToString } from "./ws-utils.js"; + +describe("webSocketMessageToString", () => { + it("accepts only WebSocket text message payloads", () => { + expect(webSocketMessageToString(["text"])).toBe("text"); + expect(webSocketMessageToString([{ data: "event text" }])).toBe( + "event text", + ); + expect( + webSocketMessageToString([new TextEncoder().encode("binary text")]), + ).toBe(undefined); + expect( + webSocketMessageToString([ + new TextEncoder().encode("binary text").buffer, + ]), + ).toBe(undefined); + expect( + webSocketMessageToString([[new TextEncoder().encode("binary text")]]), + ).toBe(undefined); + }); +}); + +describe("onWebSocket", () => { + it("normalizes Node ws text frames before shared message parsing", () => { + const socket = new EventEmitterWebSocket(); + const messages: Array = []; + + onWebSocket(socket, "message", (...args) => { + messages.push(webSocketMessageToString(args)); + }); + + socket.emit("message", new TextEncoder().encode("text frame"), false); + socket.emit("message", new TextEncoder().encode("binary frame"), true); + + expect(messages).toEqual(["text frame", undefined]); + }); +}); + +class EventEmitterWebSocket { + private readonly listeners = new Map< + string, + Set<(...args: unknown[]) => void> + >(); + + send(): void {} + + close(): void {} + + on(type: string, listener: (...args: unknown[]) => void): void { + this.listeners.set( + type, + (this.listeners.get(type) ?? new Set()).add(listener), + ); + } + + off(type: string, listener: (...args: unknown[]) => void): void { + this.listeners.get(type)?.delete(listener); + } + + emit(type: string, ...args: unknown[]): void { + this.listeners.get(type)?.forEach((listener) => { + listener(...args); + }); + } +} diff --git a/src/ws-utils.ts b/src/ws-utils.ts index 947586e..a77082b 100644 --- a/src/ws-utils.ts +++ b/src/ws-utils.ts @@ -18,6 +18,22 @@ export function onWebSocket( type: string, listener: (...args: unknown[]) => void, ): () => void { + if (socket.on) { + const eventListener = (...args: unknown[]): void => { + listener(...normalizeEventEmitterMessageArgs(type, args)); + }; + socket.on(type, eventListener); + + return () => { + if (socket.off) { + socket.off(type, eventListener); + return; + } + + socket.removeListener?.(type, eventListener); + }; + } + if (socket.addEventListener) { const eventListener = (event: unknown): void => listener(event); socket.addEventListener(type, eventListener); @@ -27,29 +43,35 @@ export function onWebSocket( }; } - if (socket.on) { - socket.on(type, listener); + throw new Error("WebSocket object does not support event listeners"); +} - return () => { - if (socket.off) { - socket.off(type, listener); - return; - } +export function webSocketMessageToString(args: unknown[]): string | undefined { + const data = extractMessageData(args); - socket.removeListener?.(type, listener); - }; + if (typeof data === "string") { + return data; } - throw new Error("WebSocket object does not support event listeners"); + return undefined; } -export function webSocketMessageToString(args: unknown[]): string | undefined { - if (args[1] === true || isBinaryMessageEvent(args[0])) { - return undefined; +function normalizeEventEmitterMessageArgs( + type: string, + args: unknown[], +): unknown[] { + if (type !== "message" || typeof args[1] !== "boolean") { + return args; } - const data = extractMessageData(args); + if (args[1]) { + return [undefined]; + } + + return [decodeWebSocketTextData(args[0])]; +} +function decodeWebSocketTextData(data: unknown): string | undefined { if (typeof data === "string") { return data; } @@ -62,7 +84,7 @@ export function webSocketMessageToString(args: unknown[]): string | undefined { return new TextDecoder().decode(data); } - if (Array.isArray(data) && data.every(ArrayBuffer.isView)) { + if (isArrayBufferViewArray(data)) { return decodeArrayBufferViews(data); } @@ -83,13 +105,8 @@ function isMessageEventLike(value: unknown): value is { data: unknown } { return typeof value === "object" && value !== null && "data" in value; } -function isBinaryMessageEvent(value: unknown): boolean { - return ( - typeof value === "object" && - value !== null && - "isBinary" in value && - value.isBinary === true - ); +function isArrayBufferViewArray(value: unknown): value is ArrayBufferView[] { + return Array.isArray(value) && value.every(ArrayBuffer.isView); } function decodeArrayBufferViews(views: ArrayBufferView[]): string { From af3f7cc2ac7d2d7a57eda43fe86a8abe5744e935 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 20:56:37 +1000 Subject: [PATCH 13/19] Tighten up POST request validation and JSON content-type validation --- src/server-session-sse.test.ts | 75 +++++++++++++++++++++++++--------- src/server.test.ts | 36 +++++++++++++++- src/server.ts | 37 ++++++++++++----- 3 files changed, 116 insertions(+), 32 deletions(-) diff --git a/src/server-session-sse.test.ts b/src/server-session-sse.test.ts index c30c775..37141d0 100644 --- a/src/server-session-sse.test.ts +++ b/src/server-session-sse.test.ts @@ -44,6 +44,19 @@ function createPromptRequest(id: number, sessionId?: string) { }; } +function createForkRequest(id: number, sessionId: string) { + return { + jsonrpc: "2.0", + id, + method: "session/fork", + params: { + cwd: "/tmp", + mcpServers: [], + sessionId, + }, + }; +} + describe("AcpServer session SSE", () => { it("streams prompt updates and responses on the session SSE stream", async () => { const server = await startTestServer( @@ -114,18 +127,13 @@ describe("AcpServer session SSE", () => { } }); - it("routes session prompts using params.sessionId when the session header is absent", async () => { + it("rejects session-scoped requests without a session header", async () => { const server = await startTestServer(); try { const connectionId = await initialize(server.url); const sessionId = await createSession(server.url, connectionId); - const sessionSse = await openSessionSse( - server.url, - connectionId, - sessionId, - ); - const accepted = await postJson( + const response = await postJson( server.url, createPromptRequest(3, sessionId), { @@ -133,25 +141,34 @@ describe("AcpServer session SSE", () => { }, ); - expect(accepted.status).toBe(202); - expect(await readSseMessages(sessionSse, 2)).toMatchObject([ - { - jsonrpc: "2.0", - method: "session/update", - params: { sessionId }, - }, + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("rejects session-scoped requests with mismatched session header and params", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const response = await postJson( + server.url, + createPromptRequest(3, "other-session"), { - jsonrpc: "2.0", - id: 3, - result: { stopReason: "end_turn" }, + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, }, - ]); + ); + + expect(response.status).toBe(400); } finally { await server.close(); } }); - it("rejects session-scoped requests without a session identifier", async () => { + it("rejects session-scoped requests without any session identifier", async () => { const server = await startTestServer(); try { @@ -166,6 +183,26 @@ describe("AcpServer session SSE", () => { } }); + it("routes non-required session methods using params.sessionId when the session header is absent", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const response = await postJson( + server.url, + createForkRequest(3, sessionId), + { + [HEADER_CONNECTION_ID]: connectionId, + }, + ); + + expect(response.status).toBe(202); + } finally { + await server.close(); + } + }); + it("replays buffered session messages when session SSE attaches after prompt", async () => { const server = await startTestServer(); diff --git a/src/server.test.ts b/src/server.test.ts index e2b37bc..14458f9 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -275,18 +275,50 @@ describe("AcpServer", () => { } }); - it("rejects POST without application/json Content-Type", async () => { + it("accepts POST with application/json Content-Type parameters", async () => { const server = await startTestServer(); try { const response = await fetch(server.url, { method: "POST", headers: { - "Content-Type": "text/plain", + "Content-Type": "application/json; charset=utf-8", }, body: JSON.stringify(initializeRequest), }); + expect(response.status).toBe(200); + } finally { + await server.close(); + } + }); + + it("rejects POST without Content-Type", async () => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { method: "POST" }); + + expect(response.status).toBe(415); + } finally { + await server.close(); + } + }); + + it.each([ + "text/plain", + "application/jsonfoobar", + "application/json-patch+json", + ])("rejects POST with %s Content-Type", async (contentType) => { + const server = await startTestServer(); + + try { + const response = await fetch(server.url, { + method: "POST", + headers: { "Content-Type": contentType }, + body: JSON.stringify(initializeRequest), + }); + expect(response.status).toBe(415); } finally { await server.close(); diff --git a/src/server.ts b/src/server.ts index cd327a5..5277f50 100644 --- a/src/server.ts +++ b/src/server.ts @@ -108,7 +108,7 @@ export class AcpServer { private async handlePost(req: Request): Promise { const contentType = req.headers.get("Content-Type"); - if (!contentType?.startsWith(JSON_MIME_TYPE)) { + if (!isJsonContentType(contentType)) { return textResponse("Unsupported Media Type", 415); } @@ -364,28 +364,39 @@ function determineRoute( headers: Headers, ): RouteResult { const headerSessionId = headers.get(HEADER_SESSION_ID); + const paramsSessionId = sessionIdFromParams(message.params); - if (headerSessionId) { + if (methodRequiresSessionHeader(message.method) && !headerSessionId) { return { - ok: true, - value: { session: headerSessionId }, + ok: false, + status: 400, + message: "Missing Acp-Session-Id", }; } - const paramsSessionId = sessionIdFromParams(message.params); + if ( + headerSessionId !== null && + paramsSessionId !== undefined && + headerSessionId !== paramsSessionId + ) { + return { + ok: false, + status: 400, + message: "Mismatched Acp-Session-Id", + }; + } - if (paramsSessionId) { + if (headerSessionId) { return { ok: true, - value: { session: paramsSessionId }, + value: { session: headerSessionId }, }; } - if (methodRequiresSessionHeader(message.method)) { + if (paramsSessionId) { return { - ok: false, - status: 400, - message: "Missing Acp-Session-Id", + ok: true, + value: { session: paramsSessionId }, }; } @@ -395,6 +406,10 @@ function determineRoute( }; } +function isJsonContentType(contentType: string | null): boolean { + return contentType?.split(";", 1)[0]?.trim().toLowerCase() === JSON_MIME_TYPE; +} + function sseResponse(subscription: OutboundSubscription): Response { return new Response(createSseBody(subscription), { status: 200, From 82b1c0ee986bf8a6e7d5f8174d627b5b442cdc76 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 21:08:36 +1000 Subject: [PATCH 14/19] fix: route session/load resume streams correctly --- src/server-session-sse.test.ts | 104 ++++++++++++++++++++++++++++++++- src/server.test.ts | 4 +- src/server.ts | 20 ++++--- 3 files changed, 118 insertions(+), 10 deletions(-) diff --git a/src/server-session-sse.test.ts b/src/server-session-sse.test.ts index 37141d0..048665d 100644 --- a/src/server-session-sse.test.ts +++ b/src/server-session-sse.test.ts @@ -5,11 +5,18 @@ import { HEADER_SESSION_ID, JSON_MIME_TYPE, } from "./protocol.js"; +import { PROTOCOL_VERSION } from "./schema/index.js"; import { parseSseStream } from "./sse.js"; import { TestAgent } from "./test-support/test-agent.js"; import { startTestServer } from "./test-support/test-http-server.js"; -import type { AgentSideConnection } from "./acp.js"; +import type { + AgentSideConnection, + InitializeRequest, + InitializeResponse, + LoadSessionRequest, + LoadSessionResponse, +} from "./acp.js"; import type { AnyMessage } from "./jsonrpc.js"; const initializeRequest = { @@ -57,6 +64,49 @@ function createForkRequest(id: number, sessionId: string) { }; } +function createLoadSessionRequest(id: number, sessionId: string) { + return { + jsonrpc: "2.0", + id, + method: "session/load", + params: { + cwd: "/tmp", + mcpServers: [], + sessionId, + }, + }; +} + +class LoadSessionAgent extends TestAgent { + constructor(private readonly agentConnection: AgentSideConnection) { + super(agentConnection); + } + + initialize(_params: InitializeRequest): Promise { + return Promise.resolve({ + protocolVersion: PROTOCOL_VERSION, + agentCapabilities: { + loadSession: true, + }, + }); + } + + async loadSession(params: LoadSessionRequest): Promise { + await this.agentConnection.sessionUpdate({ + sessionId: params.sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + type: "text", + text: "replayed-session-history", + }, + }, + }); + + return {}; + } +} + describe("AcpServer session SSE", () => { it("streams prompt updates and responses on the session SSE stream", async () => { const server = await startTestServer( @@ -203,6 +253,58 @@ describe("AcpServer session SSE", () => { } }); + it("routes session/load replay updates to session SSE and final response to connection SSE", async () => { + const server = await startTestServer( + (conn: AgentSideConnection) => new LoadSessionAgent(conn), + ); + + try { + const connectionId = await initialize(server.url); + const sessionId = "existing-session"; + const connectionSse = await openConnectionSse(server.url, connectionId); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + const accepted = await postJson( + server.url, + createLoadSessionRequest(3, sessionId), + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ); + + expect(sessionSse.status).toBe(200); + expect(accepted.status).toBe(202); + expect(await readSseMessages(sessionSse, 1)).toMatchObject([ + { + jsonrpc: "2.0", + method: "session/update", + params: { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { + text: "replayed-session-history", + }, + }, + }, + }, + ]); + expect(await readSseMessages(connectionSse, 1)).toMatchObject([ + { + jsonrpc: "2.0", + id: 3, + result: {}, + }, + ]); + } finally { + await server.close(); + } + }); + it("replays buffered session messages when session SSE attaches after prompt", async () => { const server = await startTestServer(); diff --git a/src/server.test.ts b/src/server.test.ts index 14458f9..a78757e 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -176,7 +176,7 @@ describe("AcpServer", () => { } }); - it("rejects session-scoped GETs for unknown sessions", async () => { + it("opens session-scoped GETs for sessions without local streams", async () => { const server = await startTestServer(); try { @@ -187,7 +187,7 @@ describe("AcpServer", () => { globalThis.crypto.randomUUID(), ); - expect(response.status).toBe(404); + expect(response.status).toBe(200); } finally { await server.close(); } diff --git a/src/server.ts b/src/server.ts index 5277f50..ca4ca72 100644 --- a/src/server.ts +++ b/src/server.ts @@ -14,6 +14,7 @@ import { isRequestMessage, isResponseMessage, } from "./jsonrpc.js"; +import { AGENT_METHODS } from "./schema/index.js"; import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js"; import { handleWebSocketConnection } from "./ws-server.js"; import type { WebSocketServerSocket } from "./ws-server.js"; @@ -179,12 +180,7 @@ export class AcpServer { const sessionId = req.headers.get(HEADER_SESSION_ID); if (sessionId) { - const sessionStream = connection.sessionStreams.get(sessionId); - if (!sessionStream) { - return textResponse("Unknown Acp-Session-Id", 404); - } - - return sseResponse(sessionStream.subscribe()); + return sseResponse(connection.ensureSession(sessionId).subscribe()); } return sseResponse(connection.connectionStream.subscribe()); @@ -336,7 +332,10 @@ async function forwardClientRequest( const key = messageIdKey(message.id); if (key) { - connection.pendingRoutes.set(key, route.value); + connection.pendingRoutes.set( + key, + pendingResponseRoute(message, route.value), + ); } await writeInbound(connection, message); @@ -359,6 +358,13 @@ async function forwardClientNotification( return { ok: true }; } +function pendingResponseRoute( + message: ClientRequestMessage, + route: ResponseRoute, +): ResponseRoute { + return message.method === AGENT_METHODS.session_load ? "connection" : route; +} + function determineRoute( message: ClientRequestMessage, headers: Headers, From acd2227a61b4d92acecbf7416b50b22901ef6968 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 21:21:32 +1000 Subject: [PATCH 15/19] Add connection-scoped cookie support --- src/http-stream.test.ts | 99 ++++++++++++++++++++- src/http-stream.ts | 191 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 279 insertions(+), 11 deletions(-) diff --git a/src/http-stream.test.ts b/src/http-stream.test.ts index d74c5bb..4e70b7c 100644 --- a/src/http-stream.test.ts +++ b/src/http-stream.test.ts @@ -137,6 +137,82 @@ describe("createHttpStream", () => { } }); + it("propagates cookies across initialize, SSE, session POST, and DELETE", async () => { + const controlledFetch = createControlledFetch({ + initializeCookies: ["transport=alpha; Path=/"], + getCookies: ["route=bravo; Path=/"], + }); + const stream = createHttpStream("https://agent.example/acp", { + fetch: controlledFetch.fetch, + headers: { + Cookie: "caller=custom; transport=caller", + }, + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + await writer.write(initializeRequest); + await readMessage(reader); + await controlledFetch.sendSse(0, sessionNewResponse); + await readMessage(reader); + await writer.write(promptRequest); + await writer.close(); + + expect(requestAt(controlledFetch.requests, 0).credentials).toBe( + "include", + ); + expect(requestAt(controlledFetch.requests, 0).headers.get("Cookie")).toBe( + "caller=custom; transport=caller", + ); + expect(requestAt(controlledFetch.requests, 1).headers.get("Cookie")).toBe( + "transport=caller; caller=custom", + ); + expect(requestAt(controlledFetch.requests, 2).headers.get("Cookie")).toBe( + "transport=caller; route=bravo; caller=custom", + ); + expect(requestAt(controlledFetch.requests, 3).headers.get("Cookie")).toBe( + "transport=caller; route=bravo; caller=custom", + ); + expect(requestAt(controlledFetch.requests, 4).headers.get("Cookie")).toBe( + "transport=caller; route=bravo; caller=custom", + ); + expect( + controlledFetch.requests.map((request) => request.credentials), + ).toEqual(["include", "include", "include", "include", "include"]); + } finally { + reader.releaseLock(); + writer.releaseLock(); + } + }); + + it("omits managed cookies when cookie handling is disabled", async () => { + const controlledFetch = createControlledFetch({ + initializeCookies: ["transport=alpha; Path=/"], + }); + const stream = createHttpStream("https://agent.example/acp", { + fetch: controlledFetch.fetch, + cookies: "omit", + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + await writer.write(initializeRequest); + await readMessage(reader); + + expect(requestAt(controlledFetch.requests, 0).credentials).toBe("omit"); + expect(requestAt(controlledFetch.requests, 1).credentials).toBe("omit"); + expect( + requestAt(controlledFetch.requests, 1).headers.get("Cookie"), + ).toBeNull(); + } finally { + reader.releaseLock(); + writer.releaseLock(); + await stream.writable.close(); + } + }); + it("sends DELETE and aborts SSE requests when closed", async () => { const controlledFetch = createControlledFetch(); const stream = createHttpStream("https://agent.example/acp", { @@ -330,6 +406,7 @@ interface RecordedRequest { readonly method: string; readonly headers: Headers; readonly body: string; + readonly credentials: RequestCredentials | undefined; } interface RecordedSseRequest { @@ -349,7 +426,14 @@ interface TestClientState { readonly permissionRequests?: RequestPermissionRequest[]; } -function createControlledFetch(): ControlledFetch { +interface ControlledFetchOptions { + readonly initializeCookies?: readonly string[]; + readonly getCookies?: readonly string[]; +} + +function createControlledFetch( + options: ControlledFetchOptions = {}, +): ControlledFetch { const requests: RecordedRequest[] = []; const sseRequests: RecordedSseRequest[] = []; const encoder = new TextEncoder(); @@ -365,11 +449,13 @@ function createControlledFetch(): ControlledFetch { method, headers, body: bodyToString(init?.body), + credentials: init?.credentials, }); if (method === "POST" && !headers.has(HEADER_CONNECTION_ID)) { return jsonResponse(initializeResponse, 200, { [HEADER_CONNECTION_ID]: "connection-1", + ...setCookieResponseHeaders(options.initializeCookies), }); } @@ -395,7 +481,10 @@ function createControlledFetch(): ControlledFetch { return new Response(stream.readable, { status: 200, - headers: { "Content-Type": EVENT_STREAM_MIME_TYPE }, + headers: { + "Content-Type": EVENT_STREAM_MIME_TYPE, + ...setCookieResponseHeaders(options.getCookies), + }, }); } @@ -485,3 +574,9 @@ function jsonResponse( }, }); } + +function setCookieResponseHeaders( + cookies: readonly string[] | undefined, +): Record { + return cookies ? { "Set-Cookie": cookies.join(", ") } : {}; +} diff --git a/src/http-stream.ts b/src/http-stream.ts index ab432c6..665e6a5 100644 --- a/src/http-stream.ts +++ b/src/http-stream.ts @@ -18,13 +18,15 @@ export interface HttpStreamOptions { readonly fetch?: typeof globalThis.fetch; /** Headers to include on every HTTP/SSE request. */ readonly headers?: Record; + /** Cookie handling policy for transport requests. Defaults to `include`. */ + readonly cookies?: "include" | "omit"; } /** * Creates an ACP Stream over Streamable HTTP. * * Uses POST for client messages and SSE GET streams for server messages. - * Pass a custom `fetch` for cookies, auth, proxies, or non-browser runtimes. + * Cookies are included by default for the lifetime of one stream. */ export function createHttpStream( serverUrl: string, @@ -38,6 +40,8 @@ class HttpStreamTransport { private readonly fetchImpl: typeof globalThis.fetch; private readonly headers: Record; + private readonly cookiePolicy: RequestCredentials; + private readonly cookieJar = new ConnectionCookieJar(); private readonly abortController = new AbortController(); private readonly knownSessions = new Set(); @@ -54,6 +58,7 @@ class HttpStreamTransport { ) { this.fetchImpl = resolveFetch(options.fetch); this.headers = options.headers ?? {}; + this.cookiePolicy = options.cookies ?? "include"; this.stream = { readable: new ReadableStream({ @@ -95,10 +100,9 @@ class HttpStreamTransport { throw new Error("ACP HTTP stream first message must be initialize"); } - const response = await this.fetchImpl(this.serverUrl, { + const response = await this.fetchRequest({ method: "POST", headers: { - ...this.headers, "Content-Type": JSON_MIME_TYPE, }, body: JSON.stringify(message), @@ -130,10 +134,9 @@ class HttpStreamTransport { } const sessionId = sessionIdFromMessageParams(message); - const response = await this.fetchImpl(this.serverUrl, { + const response = await this.fetchRequest({ method: "POST", headers: { - ...this.headers, "Content-Type": JSON_MIME_TYPE, [HEADER_CONNECTION_ID]: connectionId, ...(sessionId ? { [HEADER_SESSION_ID]: sessionId } : {}), @@ -177,10 +180,9 @@ class HttpStreamTransport { private async openSse(headers: Record): Promise { try { - const response = await this.fetchImpl(this.serverUrl, { + const response = await this.fetchRequest({ method: "GET", headers: { - ...this.headers, Accept: EVENT_STREAM_MIME_TYPE, ...headers, }, @@ -216,6 +218,35 @@ class HttpStreamTransport { } } + private async fetchRequest(init: RequestInit): Promise { + const response = await this.fetchImpl(this.serverUrl, { + ...init, + credentials: this.cookiePolicy, + headers: this.createRequestHeaders(init.headers), + }); + + if (this.cookiePolicy === "include") { + this.cookieJar.store(response.headers); + } + + return response; + } + + private createRequestHeaders(headers: HeadersInit | undefined): Headers { + const requestHeaders = new Headers(this.headers); + const transportHeaders = new Headers(headers); + + transportHeaders.forEach((value, key) => { + requestHeaders.set(key, value); + }); + + if (this.cookiePolicy === "include") { + this.cookieJar.apply(requestHeaders); + } + + return requestHeaders; + } + private async close(): Promise { if (this.isClosed) { return; @@ -225,22 +256,23 @@ class HttpStreamTransport { const connectionId = this.connectionId; if (connectionId) { - const response = await this.fetchImpl(this.serverUrl, { + const response = await this.fetchRequest({ method: "DELETE", headers: { - ...this.headers, [HEADER_CONNECTION_ID]: connectionId, }, }); if (!response.ok) { this.abortController.abort(); + this.cookieJar.clear(); this.closeReadable(); throw await httpError("ACP DELETE failed", response); } } this.abortController.abort(); + this.cookieJar.clear(); this.closeReadable(); } @@ -259,6 +291,7 @@ class HttpStreamTransport { this.isClosed = true; this.abortController.abort(); + this.cookieJar.clear(); try { this.readableController?.error(error); @@ -276,6 +309,48 @@ class HttpStreamTransport { } } +class ConnectionCookieJar { + private readonly cookies = new Map(); + + store(headers: Headers): void { + for (const value of setCookieHeaders(headers)) { + const cookie = parseSetCookie(value); + if (!cookie) { + continue; + } + + this.cookies.set(cookie.name, cookie.value); + } + } + + apply(headers: Headers): void { + const merged = mergeCookieHeaders( + this.cookieHeader(), + headers.get("Cookie"), + ); + if (merged) { + headers.set("Cookie", merged); + } + } + + clear(): void { + this.cookies.clear(); + } + + private cookieHeader(): string | undefined { + return this.cookies.size === 0 + ? undefined + : Array.from(this.cookies) + .map(([name, value]) => `${name}=${value}`) + .join("; "); + } +} + +interface CookiePair { + readonly name: string; + readonly value: string; +} + function resolveFetch( fetchImpl: typeof globalThis.fetch | undefined, ): typeof globalThis.fetch { @@ -292,6 +367,104 @@ function resolveFetch( ); } +function setCookieHeaders(headers: Headers): string[] { + const getSetCookie = headers.getSetCookie; + if (typeof getSetCookie === "function") { + return getSetCookie.call(headers); + } + + const setCookie = headers.get("Set-Cookie"); + return setCookie ? splitSetCookieHeader(setCookie) : []; +} + +function splitSetCookieHeader(header: string): string[] { + const result: string[] = []; + let start = 0; + let isInExpires = false; + + for (let index = 0; index < header.length; index += 1) { + const char = header[index]; + + if (char === "," && !isInExpires) { + result.push(header.slice(start, index).trim()); + start = index + 1; + continue; + } + + if (header.slice(index, index + 8).toLowerCase() === "expires=") { + isInExpires = true; + index += 7; + continue; + } + + if (char === ";" && isInExpires) { + isInExpires = false; + } + } + + result.push(header.slice(start).trim()); + return result.filter((value) => value.length > 0); +} + +function parseSetCookie(header: string): CookiePair | undefined { + const pair = header.split(";", 1)[0]; + const separator = pair.indexOf("="); + + if (separator <= 0) { + return undefined; + } + + return { + name: pair.slice(0, separator).trim(), + value: pair.slice(separator + 1).trim(), + }; +} + +function mergeCookieHeaders( + jarCookieHeader: string | undefined, + callerCookieHeader: string | null, +): string | undefined { + const cookies = new Map(); + + for (const cookie of parseCookieHeader(jarCookieHeader)) { + cookies.set(cookie.name, cookie.value); + } + + for (const cookie of parseCookieHeader(callerCookieHeader ?? undefined)) { + cookies.set(cookie.name, cookie.value); + } + + return cookies.size === 0 + ? undefined + : Array.from(cookies) + .map(([name, value]) => `${name}=${value}`) + .join("; "); +} + +function parseCookieHeader(header: string | undefined): CookiePair[] { + if (!header) { + return []; + } + + return header + .split(";") + .map(parseCookiePair) + .filter((cookie): cookie is CookiePair => cookie !== undefined); +} + +function parseCookiePair(value: string): CookiePair | undefined { + const separator = value.indexOf("="); + + if (separator <= 0) { + return undefined; + } + + return { + name: value.slice(0, separator).trim(), + value: value.slice(separator + 1).trim(), + }; +} + async function httpError(prefix: string, response: Response): Promise { const text = await response.text().catch(() => ""); From 50fdedf6f51c20d63c82b0792e8ab98427d4bfb5 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 21:26:20 +1000 Subject: [PATCH 16/19] Update examples and README --- src/examples/README.md | 8 +++++++- src/examples/http-client.ts | 3 +-- src/examples/http-server.ts | 1 + 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/examples/README.md b/src/examples/README.md index d1c19a3..f00c3cd 100644 --- a/src/examples/README.md +++ b/src/examples/README.md @@ -99,4 +99,10 @@ Or run the WebSocket client: npx tsx src/examples/ws-client.ts ``` -The HTTP example sends a bearer token through custom request headers. The WebSocket example passes the Node `ws` constructor so custom headers can be sent during the WebSocket handshake. Browser WebSocket clients can use `createWebSocketStream` too, but browsers do not allow custom WebSocket headers. +The HTTP example sends a bearer token through custom request headers. `createHttpStream` includes cookies by default for the lifetime of one stream: it sends credentials on fetch requests, captures exposed `Set-Cookie` headers, merges them with caller-provided `Cookie` headers, and reuses them for connection SSE, session SSE, POST, and DELETE requests. Pass `cookies: "omit"` to disable this behavior for stateless transports. + +The WebSocket server example uses `createNodeWebSocketUpgradeHandler`, which creates the ACP connection before the upgrade completes and adds `Acp-Connection-Id` to the `101 Switching Protocols` response. Frameworks that only expose an already-upgraded WebSocket socket cannot add that response header, so prefer an upgrade hook when building compliant servers. + +The WebSocket client example passes the Node `ws` constructor so custom headers can be sent during the WebSocket handshake. Browser WebSocket clients can use `createWebSocketStream` too, but browsers do not allow custom WebSocket headers. Use cookies or URL-level authentication for browser WebSocket authentication instead of relying on custom handshake headers. + +The included Node HTTP server is an HTTP/1.1 compatibility adapter. HTTP/2 deployment guidance is still tracked separately in the transport hardening plan. diff --git a/src/examples/http-client.ts b/src/examples/http-client.ts index 514a276..f0c61a5 100644 --- a/src/examples/http-client.ts +++ b/src/examples/http-client.ts @@ -34,8 +34,7 @@ const stream = createHttpStream(serverUrl, { headers: { Authorization: "Bearer example-token", }, - // To use cookies, pass a cookie-aware fetch implementation here instead of relying on a built-in cookie jar. - // fetch: cookieAwareFetch, + // Cookies are included by default and scoped to this stream. Use `cookies: "omit"` for stateless requests. }); const connection = new acp.ClientSideConnection( (_agent) => new HttpExampleClient(), diff --git a/src/examples/http-server.ts b/src/examples/http-server.ts index 6fb85b0..2ba3962 100644 --- a/src/examples/http-server.ts +++ b/src/examples/http-server.ts @@ -71,6 +71,7 @@ const acpServer = new AcpServer({ }); const acpHttpHandler = createNodeHttpHandler(acpServer); const webSocketServer = new WebSocketServer({ noServer: true }); +// Use the ACP upgrade helper so the 101 response includes Acp-Connection-Id. const acpWebSocketUpgradeHandler = createNodeWebSocketUpgradeHandler( acpServer, webSocketServer, From 66da6ed47f40ada583b5ebf4f618e7b43b8768aa Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 21:42:23 +1000 Subject: [PATCH 17/19] Enforce ACP transport routing validation --- src/server-session-sse.test.ts | 63 ++++++++++++++++++++++++++++++---- src/server.test.ts | 19 ++++++++-- src/server.ts | 54 ++++++++++++++--------------- 3 files changed, 100 insertions(+), 36 deletions(-) diff --git a/src/server-session-sse.test.ts b/src/server-session-sse.test.ts index 048665d..1ae524b 100644 --- a/src/server-session-sse.test.ts +++ b/src/server-session-sse.test.ts @@ -77,6 +77,16 @@ function createLoadSessionRequest(id: number, sessionId: string) { }; } +function createCancelNotification(sessionId: string) { + return { + jsonrpc: "2.0", + method: "session/cancel", + params: { + sessionId, + }, + }; +} + class LoadSessionAgent extends TestAgent { constructor(private readonly agentConnection: AgentSideConnection) { super(agentConnection); @@ -218,6 +228,47 @@ describe("AcpServer session SSE", () => { } }); + it("rejects session-scoped notifications without a session header", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const response = await postJson( + server.url, + createCancelNotification(sessionId), + { + [HEADER_CONNECTION_ID]: connectionId, + }, + ); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + + it("rejects session-scoped notifications with mismatched session header and params", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const response = await postJson( + server.url, + createCancelNotification("other-session"), + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, + ); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + it("rejects session-scoped requests without any session identifier", async () => { const server = await startTestServer(); @@ -262,11 +313,6 @@ describe("AcpServer session SSE", () => { const connectionId = await initialize(server.url); const sessionId = "existing-session"; const connectionSse = await openConnectionSse(server.url, connectionId); - const sessionSse = await openSessionSse( - server.url, - connectionId, - sessionId, - ); const accepted = await postJson( server.url, createLoadSessionRequest(3, sessionId), @@ -275,9 +321,14 @@ describe("AcpServer session SSE", () => { [HEADER_SESSION_ID]: sessionId, }, ); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); - expect(sessionSse.status).toBe(200); expect(accepted.status).toBe(202); + expect(sessionSse.status).toBe(200); expect(await readSseMessages(sessionSse, 1)).toMatchObject([ { jsonrpc: "2.0", diff --git a/src/server.test.ts b/src/server.test.ts index a78757e..4977baf 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -176,7 +176,7 @@ describe("AcpServer", () => { } }); - it("opens session-scoped GETs for sessions without local streams", async () => { + it("rejects session-scoped GETs for unknown sessions", async () => { const server = await startTestServer(); try { @@ -187,7 +187,7 @@ describe("AcpServer", () => { globalThis.crypto.randomUUID(), ); - expect(response.status).toBe(200); + expect(response.status).toBe(404); } finally { await server.close(); } @@ -385,6 +385,21 @@ describe("AcpServer", () => { } }); + it("rejects initialize requests on existing connections", async () => { + const server = await startTestServer(); + + try { + const connectionId = await initialize(server.url); + const response = await postJson(server.url, initializeRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }); + + expect(response.status).toBe(400); + } finally { + await server.close(); + } + }); + it("rejects unknown connection IDs", async () => { const server = await startTestServer(); diff --git a/src/server.ts b/src/server.ts index ca4ca72..1a09aa3 100644 --- a/src/server.ts +++ b/src/server.ts @@ -9,11 +9,7 @@ import { methodRequiresSessionHeader, sessionIdFromParams, } from "./protocol.js"; -import { - isJsonRpcMessage, - isRequestMessage, - isResponseMessage, -} from "./jsonrpc.js"; +import { isJsonRpcMessage, isResponseMessage } from "./jsonrpc.js"; import { AGENT_METHODS } from "./schema/index.js"; import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js"; import { handleWebSocketConnection } from "./ws-server.js"; @@ -25,7 +21,12 @@ import type { ResponseRoute, } from "./connection.js"; import type { Agent, AgentSideConnection } from "./acp.js"; -import type { AnyMessage, AnyRequest, AnyResponse } from "./jsonrpc.js"; +import type { + AnyMessage, + AnyNotification, + AnyRequest, + AnyResponse, +} from "./jsonrpc.js"; /** Options for creating an ACP server transport. */ export interface AcpServerOptions { @@ -129,8 +130,12 @@ export class AcpServer { const connectionId = req.headers.get(HEADER_CONNECTION_ID); - if (isInitializeRequest(body.value) && !connectionId) { - return await this.handleInitialize(body.value); + if (isInitializeRequest(body.value)) { + if (!connectionId) { + return await this.handleInitialize(body.value); + } + + return textResponse("Initialize not allowed on existing connection", 400); } if (!connectionId) { @@ -180,7 +185,12 @@ export class AcpServer { const sessionId = req.headers.get(HEADER_SESSION_ID); if (sessionId) { - return sseResponse(connection.ensureSession(sessionId).subscribe()); + const sessionStream = connection.sessionStreams.get(sessionId); + if (!sessionStream) { + return textResponse("Unknown Acp-Session-Id", 404); + } + + return sseResponse(sessionStream.subscribe()); } return sseResponse(connection.connectionStream.subscribe()); @@ -244,15 +254,11 @@ export class AcpServer { message: AnyMessage, headers: Headers, ): Promise { - if (isRequestMessage(message)) { - return await forwardClientRequest(connection, message, headers); - } - if (isResponseMessage(message)) { return await forwardClientResponse(connection, message); } - return await forwardClientNotification(connection, message); + return await forwardClientMethodMessage(connection, message, headers); } } @@ -286,7 +292,7 @@ type RouteResult = message: string; }; -type ClientRequestMessage = AnyRequest; +type ClientMethodMessage = AnyRequest | AnyNotification; async function readJson(req: Request): Promise { try { @@ -314,9 +320,9 @@ async function writeInbound( } } -async function forwardClientRequest( +async function forwardClientMethodMessage( connection: ConnectionState, - message: ClientRequestMessage, + message: ClientMethodMessage, headers: Headers, ): Promise { const route = determineRoute(message, headers); @@ -329,7 +335,7 @@ async function forwardClientRequest( connection.ensureSession(route.value.session); } - const key = messageIdKey(message.id); + const key = "id" in message ? messageIdKey(message.id) : undefined; if (key) { connection.pendingRoutes.set( @@ -350,23 +356,15 @@ async function forwardClientResponse( return { ok: true }; } -async function forwardClientNotification( - connection: ConnectionState, - message: AnyMessage, -): Promise { - await writeInbound(connection, message); - return { ok: true }; -} - function pendingResponseRoute( - message: ClientRequestMessage, + message: ClientMethodMessage, route: ResponseRoute, ): ResponseRoute { return message.method === AGENT_METHODS.session_load ? "connection" : route; } function determineRoute( - message: ClientRequestMessage, + message: ClientMethodMessage, headers: Headers, ): RouteResult { const headerSessionId = headers.get(HEADER_SESSION_ID); From 2a6d4c5d085dc327b061867abaddd3d93d5b3320 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Tue, 19 May 2026 21:58:07 +1000 Subject: [PATCH 18/19] fix: Align HTTP session routing with RFD --- src/connection.ts | 18 ++++++ src/http-stream.test.ts | 104 ++++++++++++++++++++++++++++++++++ src/http-stream.ts | 37 +++++++++++- src/server-permission.test.ts | 63 +++++++++++++++++++- src/server.ts | 27 ++++++++- 5 files changed, 246 insertions(+), 3 deletions(-) diff --git a/src/connection.ts b/src/connection.ts index cee4b77..871b808 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -91,6 +91,7 @@ export class ConnectionState { readonly allOutbound = new OutboundStream(); readonly sessionStreams = new Map(); readonly pendingRoutes = new Map(); + readonly clientResponseRoutes = new Map(); private hasStartedRouter = false; private outboundReader: ReadableStreamDefaultReader | undefined; @@ -163,6 +164,7 @@ export class ConnectionState { this.sessionStreams.clear(); this.pendingRoutes.clear(); + this.clientResponseRoutes.clear(); await Promise.allSettled([ this.inboundTx.close(), @@ -231,13 +233,29 @@ export class ConnectionState { private routeOutboundRequestOrNotification(message: AnyMessage): void { const sessionId = sessionIdFromMessageParams(message); if (sessionId) { + this.trackClientResponseRoute(message, { session: sessionId }); this.ensureSession(sessionId).push(message); return; } + this.trackClientResponseRoute(message, "connection"); this.connectionStream.push(message); } + private trackClientResponseRoute( + message: AnyMessage, + route: ResponseRoute, + ): void { + if (!("id" in message) || !("method" in message)) { + return; + } + + const key = messageIdKey(message.id); + if (key) { + this.clientResponseRoutes.set(key, route); + } + } + private pushToRoute(route: ResponseRoute, message: AnyMessage): void { if (route === "connection") { this.connectionStream.push(message); diff --git a/src/http-stream.test.ts b/src/http-stream.test.ts index 4e70b7c..63be5c9 100644 --- a/src/http-stream.test.ts +++ b/src/http-stream.test.ts @@ -59,6 +59,48 @@ const promptRequest = { }, } satisfies AnyMessage; +const loadSessionRequest = { + jsonrpc: "2.0", + id: 3, + method: "session/load", + params: { + cwd: "/tmp", + mcpServers: [], + sessionId: "existing-session", + }, +} satisfies AnyMessage; + +const permissionRequest = { + jsonrpc: "2.0", + id: 99, + method: "session/request_permission", + params: { + sessionId: "session-1", + toolCall: { + toolCallId: "permission-tool", + title: "Permission tool", + }, + options: [ + { + kind: "allow_once", + name: "Allow once", + optionId: "allow", + }, + ], + }, +} satisfies AnyMessage; + +const permissionResponse = { + jsonrpc: "2.0", + id: 99, + result: { + outcome: { + outcome: "selected", + optionId: "allow", + }, + }, +} satisfies AnyMessage; + describe("createHttpStream", () => { it("posts initialize with custom headers, opens connection SSE, and emits the initialize response", async () => { const controlledFetch = createControlledFetch(); @@ -137,6 +179,68 @@ describe("createHttpStream", () => { } }); + it("opens session SSE before posting session/load for an existing session", async () => { + const controlledFetch = createControlledFetch(); + const stream = createHttpStream("https://agent.example/acp", { + fetch: controlledFetch.fetch, + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + await writer.write(initializeRequest); + await readMessage(reader); + await writer.write(loadSessionRequest); + + const sessionGet = requestAt(controlledFetch.requests, 2); + const loadPost = requestAt(controlledFetch.requests, 3); + + expect(sessionGet.method).toBe("GET"); + expect(sessionGet.headers.get(HEADER_CONNECTION_ID)).toBe("connection-1"); + expect(sessionGet.headers.get(HEADER_SESSION_ID)).toBe( + "existing-session", + ); + expect(loadPost.method).toBe("POST"); + expect(loadPost.headers.get(HEADER_CONNECTION_ID)).toBe("connection-1"); + expect(loadPost.headers.get(HEADER_SESSION_ID)).toBe("existing-session"); + } finally { + reader.releaseLock(); + writer.releaseLock(); + await stream.writable.close(); + } + }); + + it("includes the session header on responses to session-scoped server requests", async () => { + const controlledFetch = createControlledFetch(); + const stream = createHttpStream("https://agent.example/acp", { + fetch: controlledFetch.fetch, + }); + const writer = stream.writable.getWriter(); + const reader = stream.readable.getReader(); + + try { + await writer.write(initializeRequest); + await readMessage(reader); + await controlledFetch.sendSse(0, sessionNewResponse); + await readMessage(reader); + await controlledFetch.sendSse(1, permissionRequest); + await readMessage(reader); + await writer.write(permissionResponse); + + const responsePost = requestAt(controlledFetch.requests, 3); + expect(responsePost.method).toBe("POST"); + expect(responsePost.headers.get(HEADER_CONNECTION_ID)).toBe( + "connection-1", + ); + expect(responsePost.headers.get(HEADER_SESSION_ID)).toBe("session-1"); + expect(JSON.parse(responsePost.body)).toEqual(permissionResponse); + } finally { + reader.releaseLock(); + writer.releaseLock(); + await stream.writable.close(); + } + }); + it("propagates cookies across initialize, SSE, session POST, and DELETE", async () => { const controlledFetch = createControlledFetch({ initializeCookies: ["transport=alpha; Path=/"], diff --git a/src/http-stream.ts b/src/http-stream.ts index 665e6a5..3df6b7a 100644 --- a/src/http-stream.ts +++ b/src/http-stream.ts @@ -5,6 +5,7 @@ import { HEADER_SESSION_ID, JSON_MIME_TYPE, isInitializeRequest, + messageIdKey, sessionIdFromMessageParams, sessionIdFromResponseResult, } from "./protocol.js"; @@ -44,6 +45,7 @@ class HttpStreamTransport { private readonly cookieJar = new ConnectionCookieJar(); private readonly abortController = new AbortController(); private readonly knownSessions = new Set(); + private readonly pendingResponseSessions = new Map(); private readableController: | ReadableStreamDefaultController @@ -133,7 +135,11 @@ class HttpStreamTransport { throw new Error("ACP HTTP stream is not initialized"); } - const sessionId = sessionIdFromMessageParams(message); + const sessionId = this.sessionIdForOutboundMessage(message); + if (sessionId) { + this.openSessionSse(sessionId); + } + const response = await this.fetchRequest({ method: "POST", headers: { @@ -149,6 +155,20 @@ class HttpStreamTransport { } } + private sessionIdForOutboundMessage(message: AnyMessage): string | undefined { + const paramsSessionId = sessionIdFromMessageParams(message); + if (paramsSessionId) { + return paramsSessionId; + } + + if (!("id" in message) || "method" in message) { + return undefined; + } + + const key = messageIdKey(message.id); + return key ? this.pendingResponseSessions.get(key) : undefined; + } + private openConnectionSse(): void { const connectionId = this.connectionId; if (!connectionId) { @@ -207,6 +227,7 @@ class HttpStreamTransport { this.openSessionSse(sessionId); } + this.trackServerRequestRoute(message, headers[HEADER_SESSION_ID]); this.enqueue(message); } } catch (error) { @@ -218,6 +239,20 @@ class HttpStreamTransport { } } + private trackServerRequestRoute( + message: AnyMessage, + streamSessionId: string | undefined, + ): void { + if (!streamSessionId || !("method" in message) || !("id" in message)) { + return; + } + + const key = messageIdKey(message.id); + if (key) { + this.pendingResponseSessions.set(key, streamSessionId); + } + } + private async fetchRequest(init: RequestInit): Promise { const response = await this.fetchImpl(this.serverUrl, { ...init, diff --git a/src/server-permission.test.ts b/src/server-permission.test.ts index 0ec3a9f..7cea5b0 100644 --- a/src/server-permission.test.ts +++ b/src/server-permission.test.ts @@ -45,6 +45,64 @@ function createPromptRequest(id: number, sessionId: string) { } describe("AcpServer permission requests over HTTP", () => { + it("rejects session-scoped client responses without a session header", async () => { + const server = await startTestServer( + (conn: AgentSideConnection) => + new TestAgent(conn, { enablePermission: true }), + ); + + try { + const connectionId = await initialize(server.url); + const sessionId = await createSession(server.url, connectionId); + const sessionSse = await openSessionSse( + server.url, + connectionId, + sessionId, + ); + const sessionEvents = createSseMessageIterator(sessionSse); + + expect( + await postJson(server.url, createPromptRequest(3, sessionId), { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }), + ).toMatchObject({ status: 202 }); + + await readNextSseMessage(sessionEvents); + const permissionRequest = await readNextSseMessage(sessionEvents); + + const permissionResponse = { + jsonrpc: "2.0", + id: readMessageId(permissionRequest), + result: { + outcome: { + outcome: "selected", + optionId: "allow", + }, + }, + }; + + expect( + await postJson(server.url, permissionResponse, { + [HEADER_CONNECTION_ID]: connectionId, + }), + ).toMatchObject({ status: 400 }); + expect( + await postJson(server.url, permissionResponse, { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }), + ).toMatchObject({ status: 202 }); + + await readNextSseMessage(sessionEvents); + await readNextSseMessage(sessionEvents); + await sessionEvents.return?.(); + await sessionSse.body?.cancel(); + } finally { + await server.close(); + } + }, 10_000); + it("routes permission requests over session SSE and accepts client responses", async () => { const server = await startTestServer( (conn: AgentSideConnection) => @@ -122,7 +180,10 @@ describe("AcpServer permission requests over HTTP", () => { }, }, }, - { [HEADER_CONNECTION_ID]: connectionId }, + { + [HEADER_CONNECTION_ID]: connectionId, + [HEADER_SESSION_ID]: sessionId, + }, ), ).toMatchObject({ status: 202 }); diff --git a/src/server.ts b/src/server.ts index 1a09aa3..e50aabc 100644 --- a/src/server.ts +++ b/src/server.ts @@ -255,7 +255,7 @@ export class AcpServer { headers: Headers, ): Promise { if (isResponseMessage(message)) { - return await forwardClientResponse(connection, message); + return await forwardClientResponse(connection, message, headers); } return await forwardClientMethodMessage(connection, message, headers); @@ -351,7 +351,32 @@ async function forwardClientMethodMessage( async function forwardClientResponse( connection: ConnectionState, message: AnyResponse, + headers: Headers, ): Promise { + const key = messageIdKey(message.id); + const route = key ? connection.clientResponseRoutes.get(key) : undefined; + const headerSessionId = headers.get(HEADER_SESSION_ID); + + if (route && route !== "connection" && !headerSessionId) { + return { + ok: false, + status: 400, + message: "Missing Acp-Session-Id", + }; + } + + if (route && route !== "connection" && headerSessionId !== route.session) { + return { + ok: false, + status: 400, + message: "Mismatched Acp-Session-Id", + }; + } + + if (key) { + connection.clientResponseRoutes.delete(key); + } + await writeInbound(connection, message); return { ok: true }; } From 73f02f4096df5e9421f9c66676baa620963a5ca9 Mon Sep 17 00:00:00 2001 From: Federico Ciner Date: Wed, 20 May 2026 13:44:37 +1000 Subject: [PATCH 19/19] Support per-connection agent factories --- src/server-websocket-upgrade.test.ts | 243 +++++++++++++++++++++++++++ src/server.test.ts | 211 +++++++++++++++++++++++ src/server.ts | 46 +++-- 3 files changed, 489 insertions(+), 11 deletions(-) create mode 100644 src/server-websocket-upgrade.test.ts diff --git a/src/server-websocket-upgrade.test.ts b/src/server-websocket-upgrade.test.ts new file mode 100644 index 0000000..098d1b8 --- /dev/null +++ b/src/server-websocket-upgrade.test.ts @@ -0,0 +1,243 @@ +import { describe, expect, it } from "vitest"; + +import { PROTOCOL_VERSION } from "./acp.js"; +import { HEADER_CONNECTION_ID } from "./protocol.js"; +import { AcpServer } from "./server.js"; +import { TestAgent } from "./test-support/test-agent.js"; + +import type { Agent, AgentSideConnection } from "./acp.js"; +import type { AnyMessage } from "./jsonrpc.js"; +import type { WebSocketServerSocket } from "./ws-server.js"; + +const initializeRequest = { + jsonrpc: "2.0", + id: 0, + method: "initialize", + params: { + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }, +} satisfies AnyMessage; + +describe("AcpServer prepared WebSocket upgrades", () => { + it("uses the default factory when no per-upgrade override is provided", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: recordingFactory(createdBy, "default"), + }); + const socket = new FakeServerSocket(); + + try { + server.prepareWebSocketUpgrade().accept(socket); + socket.receive(JSON.stringify(initializeRequest)); + + await expect(readSentMessage(socket)).resolves.toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + expect(createdBy).toEqual(["default"]); + } finally { + socket.close(); + await server.close(); + } + }); + + it("uses a per-upgrade factory override for that WebSocket connection", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: recordingFactory(createdBy, "default"), + }); + const socket = new FakeServerSocket(); + + try { + server + .prepareWebSocketUpgrade({ + createAgent: recordingFactory(createdBy, "override"), + }) + .accept(socket); + socket.receive(JSON.stringify(initializeRequest)); + + await readSentMessage(socket); + expect(createdBy).toEqual(["override"]); + } finally { + socket.close(); + await server.close(); + } + }); + + it("does not leak WebSocket factory overrides to later prepared upgrades", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: recordingFactory(createdBy, "default"), + }); + const overrideSocket = new FakeServerSocket(); + const defaultSocket = new FakeServerSocket(); + + try { + server + .prepareWebSocketUpgrade({ + createAgent: recordingFactory(createdBy, "override"), + }) + .accept(overrideSocket); + server.prepareWebSocketUpgrade().accept(defaultSocket); + + overrideSocket.receive(JSON.stringify(initializeRequest)); + defaultSocket.receive(JSON.stringify({ ...initializeRequest, id: 1 })); + + await Promise.all([ + readSentMessage(overrideSocket), + readSentMessage(defaultSocket), + ]); + expect(createdBy).toEqual(["override", "default"]); + } finally { + overrideSocket.close(); + defaultSocket.close(); + await server.close(); + } + }); + + it("keeps concurrent WebSocket factory overrides isolated", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: recordingFactory(createdBy, "default"), + }); + const firstSocket = new FakeServerSocket(); + const secondSocket = new FakeServerSocket(); + + try { + const first = server.prepareWebSocketUpgrade({ + createAgent: recordingFactory(createdBy, "first"), + }); + const second = server.prepareWebSocketUpgrade({ + createAgent: recordingFactory(createdBy, "second"), + }); + + second.accept(secondSocket); + first.accept(firstSocket); + secondSocket.receive(JSON.stringify({ ...initializeRequest, id: 2 })); + firstSocket.receive(JSON.stringify({ ...initializeRequest, id: 1 })); + + await Promise.all([ + readSentMessage(firstSocket), + readSentMessage(secondSocket), + ]); + expect(createdBy).toEqual(expect.arrayContaining(["first", "second"])); + expect(createdBy).toHaveLength(2); + } finally { + firstSocket.close(); + secondSocket.close(); + await server.close(); + } + }); + + it("removes rejected prepared WebSocket connections", async () => { + const server = new AcpServer({ + createAgent: (conn) => new TestAgent(conn), + }); + const prepared = server.prepareWebSocketUpgrade(); + + try { + prepared.reject(); + const response = await server.handleRequest( + new Request("http://127.0.0.1/acp", { + method: "GET", + headers: { + Accept: "text/event-stream", + [HEADER_CONNECTION_ID]: prepared.connectionId, + }, + }), + ); + + expect(response.status).toBe(404); + } finally { + await server.close(); + } + }); + + it("keeps existing double-settle behavior for prepared WebSocket upgrades", async () => { + const server = new AcpServer({ + createAgent: (conn) => new TestAgent(conn), + }); + const rejected = server.prepareWebSocketUpgrade(); + const accepted = server.prepareWebSocketUpgrade(); + const socket = new FakeServerSocket(); + + try { + rejected.reject(); + expect(() => rejected.accept(new FakeServerSocket())).toThrow( + "ACP WebSocket upgrade has already been settled", + ); + + accepted.accept(socket); + expect(() => accepted.accept(new FakeServerSocket())).toThrow( + "ACP WebSocket upgrade has already been settled", + ); + expect(() => accepted.reject()).not.toThrow(); + } finally { + socket.close(); + await server.close(); + } + }); +}); + +function recordingFactory( + createdBy: string[], + label: string, +): (conn: AgentSideConnection) => Agent { + return (conn) => { + createdBy.push(label); + return new TestAgent(conn); + }; +} + +function readSentMessage(socket: FakeServerSocket): Promise { + const message = socket.sent.shift(); + + if (message) { + return Promise.resolve(JSON.parse(message)); + } + + return new Promise((resolve) => { + socket.onSend = (data) => { + resolve(JSON.parse(data)); + }; + }); +} + +class FakeServerSocket implements WebSocketServerSocket { + readonly sent: string[] = []; + readonly listeners = new Map void>>(); + onSend: ((data: string) => void) | undefined; + + send(data: string): void { + this.sent.push(data); + this.onSend?.(data); + this.onSend = undefined; + } + + close(_code?: number, _reason?: string): void { + this.emit("close", {}); + } + + addEventListener(type: string, listener: (event: unknown) => void): void { + this.listeners.set(type, this.listeners.get(type) ?? new Set()); + this.listeners.get(type)?.add(listener); + } + + removeEventListener(type: string, listener: (event: unknown) => void): void { + this.listeners.get(type)?.delete(listener); + } + + receive(data: string): void { + this.emit("message", { data }); + } + + private emit(type: string, event: unknown): void { + for (const listener of this.listeners.get(type) ?? []) { + listener(event); + } + } +} diff --git a/src/server.test.ts b/src/server.test.ts index 4977baf..e2081dd 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -60,6 +60,193 @@ describe("AcpServer", () => { } }); + it("uses the default factory for direct HTTP initialize requests", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: (conn: AgentSideConnection) => { + createdBy.push("default"); + return new TestAgent(conn); + }, + }); + + try { + const response = await server.handleRequest( + jsonRequest(initializeRequest), + ); + + expect(response.status).toBe(200); + expect(response.headers.get(HEADER_CONNECTION_ID)).toMatch( + /^[0-9a-f-]{36}$/, + ); + expect(createdBy).toEqual(["default"]); + } finally { + await server.close(); + } + }); + + it("uses per-request factory overrides for direct HTTP initialize requests", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: (conn: AgentSideConnection) => { + createdBy.push("default"); + return new TestAgent(conn); + }, + }); + + try { + const response = await server.handleRequest( + jsonRequest(initializeRequest), + { + createAgent: (conn) => { + createdBy.push("override"); + return new TestAgent(conn); + }, + }, + ); + + expect(response.status).toBe(200); + expect(response.headers.get(HEADER_CONNECTION_ID)).toMatch( + /^[0-9a-f-]{36}$/, + ); + expect(createdBy).toEqual(["override"]); + } finally { + await server.close(); + } + }); + + it("does not leak HTTP factory overrides to later initialize requests", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: (conn: AgentSideConnection) => { + createdBy.push("default"); + return new TestAgent(conn); + }, + }); + + try { + await server.handleRequest(jsonRequest(initializeRequest), { + createAgent: (conn) => { + createdBy.push("override"); + return new TestAgent(conn); + }, + }); + await server.handleRequest(jsonRequest({ ...initializeRequest, id: 2 })); + + expect(createdBy).toEqual(["override", "default"]); + } finally { + await server.close(); + } + }); + + it("keeps concurrent HTTP initialize factory overrides isolated", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: (conn: AgentSideConnection) => { + createdBy.push("default"); + return new TestAgent(conn); + }, + }); + + try { + const first = server.handleRequest(jsonRequest(initializeRequest), { + createAgent: (conn) => { + createdBy.push("first"); + return new TestAgent(conn); + }, + }); + const second = server.handleRequest( + jsonRequest({ ...initializeRequest, id: 2 }), + { + createAgent: (conn) => { + createdBy.push("second"); + return new TestAgent(conn); + }, + }, + ); + + await Promise.all([first, second]); + + expect(createdBy).toEqual(expect.arrayContaining(["first", "second"])); + expect(createdBy).toHaveLength(2); + } finally { + await server.close(); + } + }); + + it("ignores HTTP factory overrides for existing-connection POST requests", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: (conn: AgentSideConnection) => { + createdBy.push("default"); + return new TestAgent(conn); + }, + }); + + try { + const connectionId = await initializeDirect(server); + const response = await server.handleRequest( + jsonRequest(sessionNewRequest, { + [HEADER_CONNECTION_ID]: connectionId, + }), + { + createAgent: (conn) => { + createdBy.push("override"); + return new TestAgent(conn); + }, + }, + ); + + expect(response.status).toBe(202); + expect(createdBy).toEqual(["default"]); + } finally { + await server.close(); + } + }); + + it("ignores HTTP factory overrides for GET and DELETE requests", async () => { + const createdBy: string[] = []; + const server = new AcpServer({ + createAgent: (conn: AgentSideConnection) => { + createdBy.push("default"); + return new TestAgent(conn); + }, + }); + + try { + const connectionId = await initializeDirect(server); + const createAgent = (conn: AgentSideConnection): TestAgent => { + createdBy.push("override"); + return new TestAgent(conn); + }; + const getResponse = await server.handleRequest( + new Request("http://127.0.0.1/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + }), + { createAgent }, + ); + + expect(getResponse.status).toBe(200); + await getResponse.body?.cancel(); + + const deleteResponse = await server.handleRequest( + new Request("http://127.0.0.1/acp", { + method: "DELETE", + headers: { [HEADER_CONNECTION_ID]: connectionId }, + }), + { createAgent }, + ); + + expect(deleteResponse.status).toBe(202); + expect(createdBy).toEqual(["default"]); + } finally { + await server.close(); + } + }); + it("streams session/new responses over the connection SSE stream", async () => { const server = await startTestServer(); @@ -472,6 +659,30 @@ describe("AcpServer", () => { }); }); +async function initializeDirect(server: AcpServer): Promise { + const response = await server.handleRequest(jsonRequest(initializeRequest)); + const connectionId = response.headers.get(HEADER_CONNECTION_ID); + + expect(response.status).toBe(200); + expect(connectionId).toMatch(/^[0-9a-f-]{36}$/); + + return connectionId ?? ""; +} + +function jsonRequest( + body: unknown, + headers: Record = {}, +): Request { + return new Request("http://127.0.0.1/acp", { + method: "POST", + headers: { + "Content-Type": JSON_MIME_TYPE, + ...headers, + }, + body: JSON.stringify(body), + }); +} + async function initialize(url: string): Promise { const response = await postJson(url, initializeRequest); const connectionId = response.headers.get(HEADER_CONNECTION_ID); diff --git a/src/server.ts b/src/server.ts index e50aabc..b7aa0e2 100644 --- a/src/server.ts +++ b/src/server.ts @@ -28,10 +28,20 @@ import type { AnyResponse, } from "./jsonrpc.js"; +export type AgentFactory = (conn: AgentSideConnection) => Agent; + /** Options for creating an ACP server transport. */ export interface AcpServerOptions { /** Creates the agent implementation for each accepted ACP connection. */ - createAgent: (conn: AgentSideConnection) => Agent; + createAgent: AgentFactory; +} + +export interface HandleRequestOptions { + readonly createAgent?: AgentFactory; +} + +export interface PrepareWebSocketUpgradeOptions { + readonly createAgent?: AgentFactory; } export interface PreparedWebSocketUpgrade { @@ -48,7 +58,7 @@ export interface PreparedWebSocketUpgrade { * the `101 Switching Protocols` response. */ export class AcpServer { - private readonly createAgent: (conn: AgentSideConnection) => Agent; + private readonly createAgent: AgentFactory; private readonly registry = new ConnectionRegistry(); constructor(options: AcpServerOptions) { @@ -56,9 +66,12 @@ export class AcpServer { } /** Handles one Streamable HTTP ACP request. */ - async handleRequest(req: Request): Promise { + async handleRequest( + req: Request, + options: HandleRequestOptions = {}, + ): Promise { if (req.method === "POST") { - return await this.handlePost(req); + return await this.handlePost(req, options); } if (req.method === "GET") { @@ -73,8 +86,11 @@ export class AcpServer { } /** Creates a WebSocket connection before accepting the HTTP upgrade. */ - prepareWebSocketUpgrade(): PreparedWebSocketUpgrade { - const connection = this.registry.createConnection(this.createAgent); + prepareWebSocketUpgrade( + options: PrepareWebSocketUpgradeOptions = {}, + ): PreparedWebSocketUpgrade { + const createAgent = options.createAgent ?? this.createAgent; + const connection = this.registry.createConnection(createAgent); let isSettled = false; return { @@ -87,7 +103,7 @@ export class AcpServer { isSettled = true; handleWebSocketConnection(socket, { registry: this.registry, - createAgent: this.createAgent, + createAgent, connection, }); }, @@ -107,7 +123,10 @@ export class AcpServer { this.registry.closeAll(); } - private async handlePost(req: Request): Promise { + private async handlePost( + req: Request, + options: HandleRequestOptions, + ): Promise { const contentType = req.headers.get("Content-Type"); if (!isJsonContentType(contentType)) { @@ -132,7 +151,7 @@ export class AcpServer { if (isInitializeRequest(body.value)) { if (!connectionId) { - return await this.handleInitialize(body.value); + return await this.handleInitialize(body.value, options); } return textResponse("Initialize not allowed on existing connection", 400); @@ -210,7 +229,10 @@ export class AcpServer { return emptyResponse(202); } - private async handleInitialize(message: AnyMessage): Promise { + private async handleInitialize( + message: AnyMessage, + options: HandleRequestOptions, + ): Promise { if (!("id" in message) || message.id === null) { return textResponse("Initialize request must include an ID", 400); } @@ -220,7 +242,9 @@ export class AcpServer { | undefined; try { - connection = this.registry.createConnection(this.createAgent); + connection = this.registry.createConnection( + options.createAgent ?? this.createAgent, + ); await writeInbound(connection, message); const initialResponse = await connection.recvInitial(message.id);