|
20 | 20 | PromptRequest, |
21 | 21 | PromptResponse, |
22 | 22 | SessionNotification, |
| 23 | + SetSessionModeRequest, |
| 24 | + SetSessionModeResponse, |
23 | 25 | stdio_streams, |
24 | 26 | PROTOCOL_VERSION, |
25 | 27 | ) |
@@ -138,15 +140,51 @@ def _send_cost_hint(self) -> None: |
138 | 140 | except RuntimeError: |
139 | 141 | self._schedule(self._send(hint)) |
140 | 142 |
|
| 143 | + async def on_tool_start(self, title: str, command: str, tool_call_id: str) -> None: |
| 144 | + """Send a tool_call start notification for a bash command.""" |
| 145 | + update = SessionUpdate4( |
| 146 | + sessionUpdate="tool_call", |
| 147 | + toolCallId=tool_call_id, |
| 148 | + title=title, |
| 149 | + kind="execute", |
| 150 | + status="pending", |
| 151 | + content=[ |
| 152 | + ToolCallContent1( |
| 153 | + type="content", content=ContentBlock1(type="text", text=f"```bash\n{command}\n```") |
| 154 | + ) |
| 155 | + ], |
| 156 | + rawInput={"command": command}, |
| 157 | + ) |
| 158 | + await self._send(update) |
| 159 | + |
| 160 | + async def on_tool_complete( |
| 161 | + self, |
| 162 | + tool_call_id: str, |
| 163 | + output: str, |
| 164 | + returncode: int, |
| 165 | + *, |
| 166 | + status: str = "completed", |
| 167 | + ) -> None: |
| 168 | + """Send a tool_call_update with the final output and return code.""" |
| 169 | + update = SessionUpdate5( |
| 170 | + sessionUpdate="tool_call_update", |
| 171 | + toolCallId=tool_call_id, |
| 172 | + status=status, |
| 173 | + content=[ |
| 174 | + ToolCallContent1( |
| 175 | + type="content", content=ContentBlock1(type="text", text=f"```ansi\n{output}\n```") |
| 176 | + ) |
| 177 | + ], |
| 178 | + rawOutput={"output": output, "returncode": returncode}, |
| 179 | + ) |
| 180 | + await self._send(update) |
| 181 | + |
141 | 182 | def add_message(self, role: str, content: str, **kwargs): |
142 | 183 | super().add_message(role, content, **kwargs) |
143 | | - # Only the client should send user_message_chunk. The agent reports its own text via agent_message_chunk. |
144 | | - if not getattr(self, "_emit_updates", True): |
| 184 | + # Only stream LM output as agent_message_chunk; tool output is handled via tool_call_update. |
| 185 | + if not getattr(self, "_emit_updates", True) or role != "assistant": |
145 | 186 | return |
146 | | - # Avoid duplicating tool outputs as a separate "Observation" agent message; rely on tool_call_update |
147 | 187 | text = str(content) |
148 | | - if role == "user" and text.strip().startswith("Observation:"): |
149 | | - return |
150 | 188 | block = ContentBlock1(type="text", text=text) |
151 | 189 | update = SessionUpdate2(sessionUpdate="agent_message_chunk", content=block) |
152 | 190 | try: |
@@ -324,6 +362,15 @@ async def loadSession(self, params) -> None: # type: ignore[override] |
324 | 362 | async def authenticate(self, _params: AuthenticateRequest) -> None: |
325 | 363 | return None |
326 | 364 |
|
| 365 | + async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # type: ignore[override] |
| 366 | + sess = self._sessions.get(params.sessionId) |
| 367 | + if not sess: |
| 368 | + return SetSessionModeResponse() |
| 369 | + mode = params.modeId.lower() |
| 370 | + if mode in ("confirm", "yolo", "human"): |
| 371 | + sess["config"].mode = mode # type: ignore[attr-defined] |
| 372 | + return SetSessionModeResponse() |
| 373 | + |
327 | 374 | def _extract_mode_from_blocks(self, blocks) -> Literal["confirm", "yolo", "human"] | None: |
328 | 375 | for b in blocks: |
329 | 376 | if getattr(b, "type", None) == "text": |
|
0 commit comments