|
1 | 1 | import asyncio |
2 | | -from dataclasses import dataclass, field |
| 2 | +import logging |
3 | 3 | from typing import Any |
4 | 4 |
|
5 | 5 | from acp import ( |
|
10 | 10 | CancelNotification, |
11 | 11 | InitializeRequest, |
12 | 12 | InitializeResponse, |
| 13 | + LoadSessionRequest, |
| 14 | + LoadSessionResponse, |
13 | 15 | NewSessionRequest, |
14 | 16 | NewSessionResponse, |
15 | 17 | PromptRequest, |
16 | 18 | PromptResponse, |
17 | | - SessionNotification, |
18 | 19 | SetSessionModeRequest, |
19 | 20 | SetSessionModeResponse, |
20 | 21 | stdio_streams, |
21 | 22 | PROTOCOL_VERSION, |
22 | 23 | ) |
23 | 24 | from acp.schema import ( |
| 25 | + AgentCapabilities, |
24 | 26 | AgentMessageChunk, |
25 | | - AllowedOutcome, |
26 | | - ContentToolCallContent, |
27 | | - PermissionOption, |
28 | | - RequestPermissionRequest, |
| 27 | + McpCapabilities, |
| 28 | + PromptCapabilities, |
| 29 | + SessionNotification, |
29 | 30 | TextContentBlock, |
30 | | - ToolCallUpdate, |
31 | 31 | ) |
32 | 32 |
|
33 | 33 |
|
34 | | -@dataclass |
35 | | -class SessionState: |
36 | | - cancel_event: asyncio.Event = field(default_factory=asyncio.Event) |
37 | | - prompt_counter: int = 0 |
38 | | - |
39 | | - def begin_prompt(self) -> None: |
40 | | - self.prompt_counter += 1 |
41 | | - self.cancel_event.clear() |
42 | | - |
43 | | - def cancel(self) -> None: |
44 | | - self.cancel_event.set() |
45 | | - |
46 | | - |
47 | 34 | class ExampleAgent(Agent): |
48 | 35 | def __init__(self, conn: AgentSideConnection) -> None: |
49 | 36 | self._conn = conn |
50 | 37 | self._next_session_id = 0 |
51 | | - self._sessions: dict[str, SessionState] = {} |
52 | | - |
53 | | - def _session(self, session_id: str) -> SessionState: |
54 | | - state = self._sessions.get(session_id) |
55 | | - if state is None: |
56 | | - state = SessionState() |
57 | | - self._sessions[session_id] = state |
58 | | - return state |
59 | 38 |
|
60 | | - async def _send_text(self, session_id: str, text: str) -> None: |
| 39 | + async def _send_chunk(self, session_id: str, content: Any) -> None: |
61 | 40 | await self._conn.sessionUpdate( |
62 | 41 | SessionNotification( |
63 | 42 | sessionId=session_id, |
64 | 43 | update=AgentMessageChunk( |
65 | 44 | sessionUpdate="agent_message_chunk", |
66 | | - content=TextContentBlock(type="text", text=text), |
| 45 | + content=content, |
67 | 46 | ), |
68 | 47 | ) |
69 | 48 | ) |
70 | 49 |
|
71 | | - def _format_prompt_preview(self, blocks: list[Any]) -> str: |
72 | | - parts: list[str] = [] |
73 | | - for block in blocks: |
74 | | - if isinstance(block, dict): |
75 | | - if block.get("type") == "text": |
76 | | - parts.append(str(block.get("text", ""))) |
77 | | - else: |
78 | | - parts.append(f"<{block.get('type', 'content')}>") |
79 | | - else: |
80 | | - parts.append(getattr(block, "text", "<content>")) |
81 | | - preview = " \n".join(filter(None, parts)).strip() |
82 | | - return preview or "<empty prompt>" |
83 | | - |
84 | | - async def _request_permission(self, session_id: str, preview: str, state: SessionState) -> str: |
85 | | - state.prompt_counter += 1 |
86 | | - request = RequestPermissionRequest( |
87 | | - sessionId=session_id, |
88 | | - toolCall=ToolCallUpdate( |
89 | | - toolCallId=f"echo-{state.prompt_counter}", |
90 | | - title="Echo input", |
91 | | - kind="echo", |
92 | | - status="pending", |
93 | | - content=[ |
94 | | - ContentToolCallContent( |
95 | | - type="content", |
96 | | - content=TextContentBlock(type="text", text=preview), |
97 | | - ) |
98 | | - ], |
| 50 | + async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002 |
| 51 | + logging.info("Received initialize request") |
| 52 | + return InitializeResponse( |
| 53 | + protocolVersion=PROTOCOL_VERSION, |
| 54 | + agentCapabilities=AgentCapabilities( |
| 55 | + loadSession=False, |
| 56 | + mcpCapabilities=McpCapabilities(http=False, sse=False), |
| 57 | + promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False), |
99 | 58 | ), |
100 | | - options=[ |
101 | | - PermissionOption(optionId="allow-once", name="Allow once", kind="allow_once"), |
102 | | - PermissionOption(optionId="deny", name="Deny", kind="reject_once"), |
103 | | - ], |
104 | 59 | ) |
105 | 60 |
|
106 | | - permission_task = asyncio.create_task(self._conn.requestPermission(request)) |
107 | | - cancel_task = asyncio.create_task(state.cancel_event.wait()) |
108 | | - |
109 | | - done, pending = await asyncio.wait({permission_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED) |
110 | | - |
111 | | - for task in pending: |
112 | | - task.cancel() |
113 | | - |
114 | | - if cancel_task in done: |
115 | | - permission_task.cancel() |
116 | | - return "cancelled" |
117 | | - |
118 | | - try: |
119 | | - response = await permission_task |
120 | | - except asyncio.CancelledError: |
121 | | - return "cancelled" |
122 | | - except Exception as exc: # noqa: BLE001 |
123 | | - await self._send_text(session_id, f"Permission failed: {exc}") |
124 | | - return "error" |
125 | | - |
126 | | - if isinstance(response.outcome, AllowedOutcome): |
127 | | - option_id = response.outcome.optionId |
128 | | - if option_id.startswith("allow"): |
129 | | - return "allowed" |
130 | | - return "denied" |
131 | | - return "cancelled" |
132 | | - |
133 | | - async def initialize(self, params: InitializeRequest) -> InitializeResponse: |
134 | | - return InitializeResponse(protocolVersion=PROTOCOL_VERSION, agentCapabilities=None, authMethods=[]) |
135 | | - |
136 | 61 | async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002 |
137 | | - return {} |
| 62 | + logging.info("Received authenticate request") |
| 63 | + return AuthenticateResponse() |
138 | 64 |
|
139 | 65 | async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002 |
140 | | - session_id = f"sess-{self._next_session_id}" |
| 66 | + logging.info("Received new session request") |
| 67 | + session_id = str(self._next_session_id) |
141 | 68 | self._next_session_id += 1 |
142 | | - self._sessions[session_id] = SessionState() |
143 | 69 | return NewSessionResponse(sessionId=session_id) |
144 | 70 |
|
145 | | - async def loadSession(self, params): # type: ignore[override] |
146 | | - return None |
| 71 | + async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002 |
| 72 | + logging.info("Received load session request") |
| 73 | + return LoadSessionResponse() |
147 | 74 |
|
148 | 75 | async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002 |
149 | | - return {} |
| 76 | + logging.info("Received set session mode request") |
| 77 | + return SetSessionModeResponse() |
150 | 78 |
|
151 | 79 | async def prompt(self, params: PromptRequest) -> PromptResponse: |
152 | | - state = self._session(params.sessionId) |
153 | | - state.begin_prompt() |
154 | | - |
155 | | - preview = self._format_prompt_preview(list(params.prompt)) |
156 | | - await self._send_text(params.sessionId, "Agent received a prompt. Checking permissions...") |
157 | | - |
158 | | - decision = await self._request_permission(params.sessionId, preview, state) |
159 | | - if decision == "cancelled": |
160 | | - await self._send_text(params.sessionId, "Prompt cancelled before permission decided.") |
161 | | - return PromptResponse(stopReason="cancelled") |
162 | | - if decision == "denied": |
163 | | - await self._send_text(params.sessionId, "Permission denied by the client.") |
164 | | - return PromptResponse(stopReason="permission_denied") |
165 | | - if decision == "error": |
166 | | - return PromptResponse(stopReason="error") |
167 | | - |
168 | | - await self._send_text(params.sessionId, "Permission granted. Echoing content:") |
| 80 | + logging.info("Received prompt request") |
169 | 81 |
|
| 82 | + # Notify the client what it just sent and then echo each content block back. |
| 83 | + await self._send_chunk( |
| 84 | + params.sessionId, |
| 85 | + TextContentBlock(type="text", text="Client sent:"), |
| 86 | + ) |
170 | 87 | for block in params.prompt: |
171 | | - if state.cancel_event.is_set(): |
172 | | - await self._send_text(params.sessionId, "Prompt interrupted by cancellation.") |
173 | | - return PromptResponse(stopReason="cancelled") |
174 | | - text = self._format_prompt_preview([block]) |
175 | | - await self._send_text(params.sessionId, text) |
176 | | - await asyncio.sleep(0.4) |
| 88 | + await self._send_chunk(params.sessionId, block) |
177 | 89 |
|
178 | 90 | return PromptResponse(stopReason="end_turn") |
179 | 91 |
|
180 | 92 | async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002 |
181 | | - state = self._sessions.get(params.sessionId) |
182 | | - if state: |
183 | | - state.cancel() |
184 | | - await self._send_text(params.sessionId, "Agent received cancel signal.") |
| 93 | + logging.info("Received cancel notification") |
185 | 94 |
|
186 | 95 | async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 |
| 96 | + logging.info("Received extension method call: %s", method) |
187 | 97 | return {"example": "response"} |
188 | 98 |
|
189 | 99 | async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002 |
190 | | - return None |
| 100 | + logging.info("Received extension notification: %s", method) |
191 | 101 |
|
192 | 102 |
|
193 | 103 | async def main() -> None: |
| 104 | + logging.basicConfig(level=logging.INFO) |
194 | 105 | reader, writer = await stdio_streams() |
195 | | - AgentSideConnection(lambda conn: ExampleAgent(conn), writer, reader) |
| 106 | + AgentSideConnection(ExampleAgent, writer, reader) |
196 | 107 | await asyncio.Event().wait() |
197 | 108 |
|
198 | 109 |
|
|
0 commit comments