Skip to content

Commit e21d8d0

Browse files
authored
fix: allow omitting mcpServers in session requests (#58)
* fix: allow omitting mcpServers in session requests * fix: warn on positional load_session session_id * Revert "fix: warn on positional load_session session_id" This reverts commit 1f21da8. * Revert "fix: allow omitting mcpServers in session requests" This reverts commit 73499a8. * fix: default mcpServers to empty list
1 parent b4f253c commit e21d8d0

File tree

4 files changed

+109
-13
lines changed

4 files changed

+109
-13
lines changed

scripts/gen_signature.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,19 @@
99

1010
from acp import schema
1111

12+
SIGNATURE_OPTIONAL_FIELDS: set[tuple[str, str]] = {
13+
("LoadSessionRequest", "mcp_servers"),
14+
("NewSessionRequest", "mcp_servers"),
15+
}
16+
1217

1318
class NodeTransformer(ast.NodeTransformer):
1419
def __init__(self) -> None:
1520
self._type_import_node: ast.ImportFrom | None = None
1621
self._schema_import_node: ast.ImportFrom | None = None
1722
self._should_rewrite = False
1823
self._literals = {name: value for name, value in schema.__dict__.items() if t.get_origin(value) is t.Literal}
24+
self._current_model_name: str | None = None
1925

2026
def _add_typing_import(self, name: str) -> None:
2127
if not self._type_import_node:
@@ -71,9 +77,13 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
7177
self._should_rewrite = True
7278
model_name = t.cast(ast.Name, decorator.args[0]).id
7379
model = t.cast(type[schema.BaseModel], getattr(schema, model_name))
74-
param_defaults = [
75-
self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta"
76-
]
80+
self._current_model_name = model_name
81+
try:
82+
param_defaults = [
83+
self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta"
84+
]
85+
finally:
86+
self._current_model_name = None
7787
param_defaults.sort(key=lambda x: x[1] is not None)
7888
node.args.args[1:] = [param for param, _ in param_defaults]
7989
node.args.defaults = [default for _, default in param_defaults if default is not None]
@@ -84,12 +94,18 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
8494
def _to_param_def(self, name: str, field: FieldInfo) -> tuple[ast.arg, ast.expr | None]:
8595
arg = ast.arg(arg=name)
8696
ann = field.annotation
87-
if field.default is PydanticUndefined:
88-
default = None
89-
elif isinstance(field.default, dict | BaseModel):
97+
override_optional = (self._current_model_name, name) in SIGNATURE_OPTIONAL_FIELDS
98+
if override_optional:
99+
if ann is not None:
100+
ann = ann | None
90101
default = ast.Constant(None)
91102
else:
92-
default = ast.Constant(value=field.default)
103+
if field.default is PydanticUndefined:
104+
default = None
105+
elif isinstance(field.default, dict | BaseModel):
106+
default = ast.Constant(None)
107+
else:
108+
default = ast.Constant(value=field.default)
93109
if ann is not None:
94110
arg.annotation = self._format_annotation(ann)
95111
return arg, default

src/acp/client/connection.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,31 @@ async def initialize(
9393

9494
@param_model(NewSessionRequest)
9595
async def new_session(
96-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
96+
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any
9797
) -> NewSessionResponse:
98+
resolved_mcp_servers = mcp_servers or []
9899
return await request_model(
99100
self._conn,
100101
AGENT_METHODS["session_new"],
101-
NewSessionRequest(cwd=cwd, mcp_servers=mcp_servers, field_meta=kwargs or None),
102+
NewSessionRequest(cwd=cwd, mcp_servers=resolved_mcp_servers, field_meta=kwargs or None),
102103
NewSessionResponse,
103104
)
104105

105106
@param_model(LoadSessionRequest)
106107
async def load_session(
107-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
108+
self,
109+
cwd: str,
110+
session_id: str,
111+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
112+
**kwargs: Any,
108113
) -> LoadSessionResponse:
114+
resolved_mcp_servers = mcp_servers or []
109115
return await request_model_from_dict(
110116
self._conn,
111117
AGENT_METHODS["session_load"],
112-
LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None),
118+
LoadSessionRequest(
119+
cwd=cwd, mcp_servers=resolved_mcp_servers, session_id=session_id, field_meta=kwargs or None
120+
),
113121
LoadSessionResponse,
114122
)
115123

src/acp/interfaces.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,16 @@ async def initialize(
154154

155155
@param_model(NewSessionRequest)
156156
async def new_session(
157-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
157+
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any
158158
) -> NewSessionResponse: ...
159159

160160
@param_model(LoadSessionRequest)
161161
async def load_session(
162-
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
162+
self,
163+
cwd: str,
164+
session_id: str,
165+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
166+
**kwargs: Any,
163167
) -> LoadSessionResponse | None: ...
164168

165169
@param_model(ListSessionsRequest)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import asyncio
2+
from typing import Any
3+
4+
import pytest
5+
6+
from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse
7+
from acp.core import AgentSideConnection, ClientSideConnection
8+
from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer
9+
from tests.conftest import TestAgent, TestClient
10+
11+
12+
class McpOptionalAgent(TestAgent):
13+
def __init__(self) -> None:
14+
super().__init__()
15+
self.seen_new_session: tuple[str, Any] | None = None
16+
self.seen_load_session: tuple[str, str, Any] | None = None
17+
18+
async def new_session(
19+
self,
20+
cwd: str,
21+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
22+
**kwargs: Any,
23+
) -> NewSessionResponse:
24+
resolved_mcp_servers = mcp_servers or []
25+
self.seen_new_session = (cwd, resolved_mcp_servers)
26+
return await super().new_session(cwd=cwd, mcp_servers=resolved_mcp_servers, **kwargs)
27+
28+
async def load_session(
29+
self,
30+
cwd: str,
31+
session_id: str,
32+
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
33+
**kwargs: Any,
34+
) -> LoadSessionResponse | None:
35+
resolved_mcp_servers = mcp_servers or []
36+
self.seen_load_session = (cwd, session_id, resolved_mcp_servers)
37+
return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=resolved_mcp_servers, **kwargs)
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_session_requests_default_empty_mcp_servers(server) -> None:
42+
client = TestClient()
43+
captured_agent: list[McpOptionalAgent] = []
44+
45+
agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type]
46+
_agent_side = AgentSideConnection(
47+
lambda _conn: captured_agent.append(McpOptionalAgent()) or captured_agent[-1],
48+
server._server_writer,
49+
server._server_reader,
50+
listening=True,
51+
)
52+
53+
init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0)
54+
assert isinstance(init, InitializeResponse)
55+
56+
new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0)
57+
assert isinstance(new_session, NewSessionResponse)
58+
59+
load_session = await asyncio.wait_for(
60+
agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id),
61+
timeout=1.0,
62+
)
63+
assert isinstance(load_session, LoadSessionResponse)
64+
65+
assert captured_agent, "Agent was not constructed"
66+
[agent] = captured_agent
67+
assert agent.seen_new_session == ("/workspace", [])
68+
assert agent.seen_load_session == ("/workspace", new_session.session_id, [])

0 commit comments

Comments
 (0)