Skip to content

Commit 446e856

Browse files
committed
fix: allow omitting mcpServers in session requests
1 parent adaff6d commit 446e856

File tree

7 files changed

+191
-16
lines changed

7 files changed

+191
-16
lines changed

scripts/gen_schema.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,23 @@
22
from __future__ import annotations
33

44
import ast
5+
import contextlib
56
import json
67
import re
78
import subprocess
89
import sys
10+
import tempfile
911
import textwrap
1012
from collections.abc import Callable
1113
from dataclasses import dataclass
1214
from pathlib import Path
1315

1416
ROOT = Path(__file__).resolve().parents[1]
17+
if str(ROOT) not in sys.path:
18+
sys.path.append(str(ROOT))
19+
20+
from scripts.schema_patches import apply_schema_patches # noqa: E402
21+
1522
SCHEMA_DIR = ROOT / "schema"
1623
SCHEMA_JSON = SCHEMA_DIR / "schema.json"
1724
VERSION_FILE = SCHEMA_DIR / "VERSION"
@@ -135,12 +142,23 @@ def generate_schema() -> None:
135142
)
136143
sys.exit(1)
137144

145+
schema_payload = json.loads(SCHEMA_JSON.read_text(encoding="utf-8"))
146+
schema_payload, patch_warnings = apply_schema_patches(schema_payload)
147+
for warning in patch_warnings:
148+
print(f"Warning: {warning.message}", file=sys.stderr)
149+
150+
patched_schema_path: Path | None = None
151+
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as handle:
152+
json.dump(schema_payload, handle, indent=2)
153+
handle.write("\n")
154+
patched_schema_path = Path(handle.name)
155+
138156
cmd = [
139157
sys.executable,
140158
"-m",
141159
"datamodel_code_generator",
142160
"--input",
143-
str(SCHEMA_JSON),
161+
str(patched_schema_path),
144162
"--input-file-type",
145163
"jsonschema",
146164
"--output",
@@ -154,10 +172,15 @@ def generate_schema() -> None:
154172
"--snake-case-field",
155173
]
156174

157-
subprocess.check_call(cmd) # noqa: S603
158-
warnings = postprocess_generated_schema(SCHEMA_OUT)
159-
for warning in warnings:
160-
print(f"Warning: {warning}", file=sys.stderr)
175+
try:
176+
subprocess.check_call(cmd) # noqa: S603
177+
warnings = postprocess_generated_schema(SCHEMA_OUT)
178+
for warning in warnings:
179+
print(f"Warning: {warning}", file=sys.stderr)
180+
finally:
181+
if patched_schema_path is not None:
182+
with contextlib.suppress(OSError):
183+
patched_schema_path.unlink()
161184

162185

163186
def postprocess_generated_schema(output_path: Path) -> list[str]:

scripts/schema_patches.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any
5+
6+
7+
@dataclass(frozen=True, slots=True)
8+
class PatchWarning:
9+
message: str
10+
11+
12+
def apply_schema_patches(schema: dict[str, Any]) -> tuple[dict[str, Any], list[PatchWarning]]:
13+
patched = schema
14+
warnings: list[PatchWarning] = []
15+
16+
patched, warning = _make_defs_field_optional(patched, "NewSessionRequest", "mcpServers")
17+
if warning is not None:
18+
warnings.append(warning)
19+
20+
patched, warning = _make_defs_field_optional(patched, "LoadSessionRequest", "mcpServers")
21+
if warning is not None:
22+
warnings.append(warning)
23+
24+
return patched, warnings
25+
26+
27+
def _make_defs_field_optional(
28+
schema: dict[str, Any],
29+
model_name: str,
30+
field_name: str,
31+
) -> tuple[dict[str, Any], PatchWarning | None]:
32+
defs = schema.get("$defs")
33+
if not isinstance(defs, dict):
34+
return schema, PatchWarning("schema.$defs missing or invalid; cannot apply patches")
35+
36+
model = defs.get(model_name)
37+
if not isinstance(model, dict):
38+
return schema, PatchWarning(f"schema.$defs.{model_name} missing or invalid; cannot patch {field_name}")
39+
40+
required = model.get("required")
41+
if required is None:
42+
return schema, None
43+
if not isinstance(required, list):
44+
return schema, PatchWarning(f"schema.$defs.{model_name}.required invalid; cannot patch {field_name}")
45+
46+
new_required = [item for item in required if item != field_name]
47+
if new_required == required:
48+
return schema, None
49+
50+
model["required"] = new_required
51+
return schema, None

src/acp/client/connection.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
__all__ = ["ClientSideConnection"]
4747
_CLIENT_CONNECTION_ERROR = "ClientSideConnection requires asyncio StreamWriter/StreamReader"
48+
_MISSING = object()
4849

4950

5051
@final
@@ -93,7 +94,10 @@ async def initialize(
9394

9495
@param_model(NewSessionRequest)
9596
async def new_session(
96-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
97+
self,
98+
cwd: str,
99+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
100+
**kwargs: Any,
97101
) -> NewSessionResponse:
98102
return await request_model(
99103
self._conn,
@@ -104,12 +108,27 @@ async def new_session(
104108

105109
@param_model(LoadSessionRequest)
106110
async def load_session(
107-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
111+
self,
112+
cwd: str,
113+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | str | None = None,
114+
session_id: str | object = _MISSING,
115+
**kwargs: Any,
108116
) -> LoadSessionResponse:
117+
if session_id is _MISSING:
118+
if isinstance(mcp_servers, str):
119+
session_id = mcp_servers
120+
mcp_servers = None
121+
else:
122+
raise TypeError("load_session() missing required argument: 'session_id'")
109123
return await request_model_from_dict(
110124
self._conn,
111125
AGENT_METHODS["session_load"],
112-
LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None),
126+
LoadSessionRequest(
127+
cwd=cwd,
128+
mcp_servers=cast(list[HttpMcpServer | SseMcpServer | McpServerStdio] | None, mcp_servers),
129+
session_id=cast(str, session_id),
130+
field_meta=kwargs or None,
131+
),
113132
LoadSessionResponse,
114133
)
115134

src/acp/interfaces.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,19 @@ async def initialize(
152152

153153
@param_model(NewSessionRequest)
154154
async def new_session(
155-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
155+
self,
156+
cwd: str,
157+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
158+
**kwargs: Any,
156159
) -> NewSessionResponse: ...
157160

158161
@param_model(LoadSessionRequest)
159162
async def load_session(
160-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
163+
self,
164+
cwd: str,
165+
session_id: str,
166+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
167+
**kwargs: Any,
161168
) -> LoadSessionResponse | None: ...
162169

163170
@param_model(ListSessionsRequest)

src/acp/schema.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,12 +1337,12 @@ class NewSessionRequest(BaseModel):
13371337
]
13381338
# List of MCP (Model Context Protocol) servers the agent should connect to.
13391339
mcp_servers: Annotated[
1340-
List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]],
1340+
Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]],
13411341
Field(
13421342
alias="mcpServers",
13431343
description="List of MCP (Model Context Protocol) servers the agent should connect to.",
13441344
),
1345-
]
1345+
] = None
13461346

13471347

13481348
class PermissionOption(BaseModel):
@@ -1985,12 +1985,12 @@ class LoadSessionRequest(BaseModel):
19851985
cwd: Annotated[str, Field(description="The working directory for this session.")]
19861986
# List of MCP servers to connect to for this session.
19871987
mcp_servers: Annotated[
1988-
List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]],
1988+
Optional[List[Union[HttpMcpServer, SseMcpServer, McpServerStdio]]],
19891989
Field(
19901990
alias="mcpServers",
19911991
description="List of MCP servers to connect to for this session.",
19921992
),
1993-
]
1993+
] = None
19941994
# The ID of the session to load.
19951995
session_id: Annotated[str, Field(alias="sessionId", description="The ID of the session to load.")]
19961996

tests/conftest.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,19 @@ async def initialize(
243243
return InitializeResponse(protocol_version=protocol_version, agent_capabilities=None, auth_methods=[])
244244

245245
async def new_session(
246-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
246+
self,
247+
cwd: str,
248+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
249+
**kwargs: Any,
247250
) -> NewSessionResponse:
248251
return NewSessionResponse(session_id="test-session-123")
249252

250253
async def load_session(
251-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
254+
self,
255+
cwd: str,
256+
session_id: str,
257+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
258+
**kwargs: Any,
252259
) -> LoadSessionResponse | None:
253260
return LoadSessionResponse()
254261

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import asyncio
2+
from typing import Any
3+
4+
import pytest
5+
6+
from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse
7+
from acp.core import AgentSideConnection, ClientSideConnection
8+
from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer
9+
from tests.conftest import TestAgent, TestClient
10+
11+
# Regression from a real-world client run where `mcpServers` is omitted from session requests.
12+
13+
14+
class Issue55Agent(TestAgent):
15+
def __init__(self) -> None:
16+
super().__init__()
17+
self.seen_new_session: tuple[str, Any] | None = None
18+
self.seen_load_session: tuple[str, str, Any] | None = None
19+
20+
async def new_session(
21+
self,
22+
cwd: str,
23+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
24+
**kwargs: Any,
25+
) -> NewSessionResponse:
26+
self.seen_new_session = (cwd, mcp_servers)
27+
return await super().new_session(cwd=cwd, mcp_servers=mcp_servers, **kwargs)
28+
29+
async def load_session(
30+
self,
31+
cwd: str,
32+
session_id: str,
33+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
34+
**kwargs: Any,
35+
) -> LoadSessionResponse | None:
36+
self.seen_load_session = (cwd, session_id, mcp_servers)
37+
return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=mcp_servers, **kwargs)
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_session_requests_allow_missing_mcp_servers(server) -> None:
42+
client = TestClient()
43+
captured_agent: list[Issue55Agent] = []
44+
45+
agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type]
46+
_agent_side = AgentSideConnection(
47+
lambda _conn: captured_agent.append(Issue55Agent()) or captured_agent[-1],
48+
server._server_writer,
49+
server._server_reader,
50+
listening=True,
51+
)
52+
53+
init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0)
54+
assert isinstance(init, InitializeResponse)
55+
56+
new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0)
57+
assert isinstance(new_session, NewSessionResponse)
58+
59+
load_session = await asyncio.wait_for(
60+
agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id),
61+
timeout=1.0,
62+
)
63+
assert isinstance(load_session, LoadSessionResponse)
64+
65+
assert captured_agent, "Agent was not constructed"
66+
[agent] = captured_agent
67+
assert agent.seen_new_session == ("/workspace", None)
68+
assert agent.seen_load_session == ("/workspace", new_session.session_id, None)

0 commit comments

Comments
 (0)