Skip to content

Commit e8b1f76

Browse files
author
Chojan Shang
committed
refactor(examples): make basic example better
Signed-off-by: Chojan Shang <chojan.shang@vesoft.com>
1 parent d5feb89 commit e8b1f76

File tree

2 files changed

+130
-280
lines changed

2 files changed

+130
-280
lines changed

examples/agent.py

Lines changed: 38 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from dataclasses import dataclass, field
2+
import logging
33
from typing import Any
44

55
from acp import (
@@ -10,189 +10,100 @@
1010
CancelNotification,
1111
InitializeRequest,
1212
InitializeResponse,
13+
LoadSessionRequest,
14+
LoadSessionResponse,
1315
NewSessionRequest,
1416
NewSessionResponse,
1517
PromptRequest,
1618
PromptResponse,
17-
SessionNotification,
1819
SetSessionModeRequest,
1920
SetSessionModeResponse,
2021
stdio_streams,
2122
PROTOCOL_VERSION,
2223
)
2324
from acp.schema import (
25+
AgentCapabilities,
2426
AgentMessageChunk,
25-
AllowedOutcome,
26-
ContentToolCallContent,
27-
PermissionOption,
28-
RequestPermissionRequest,
27+
McpCapabilities,
28+
PromptCapabilities,
29+
SessionNotification,
2930
TextContentBlock,
30-
ToolCallUpdate,
3131
)
3232

3333

34-
@dataclass
35-
class SessionState:
36-
cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
37-
prompt_counter: int = 0
38-
39-
def begin_prompt(self) -> None:
40-
self.prompt_counter += 1
41-
self.cancel_event.clear()
42-
43-
def cancel(self) -> None:
44-
self.cancel_event.set()
45-
46-
4734
class ExampleAgent(Agent):
4835
def __init__(self, conn: AgentSideConnection) -> None:
4936
self._conn = conn
5037
self._next_session_id = 0
51-
self._sessions: dict[str, SessionState] = {}
52-
53-
def _session(self, session_id: str) -> SessionState:
54-
state = self._sessions.get(session_id)
55-
if state is None:
56-
state = SessionState()
57-
self._sessions[session_id] = state
58-
return state
5938

60-
async def _send_text(self, session_id: str, text: str) -> None:
39+
async def _send_chunk(self, session_id: str, content: Any) -> None:
6140
await self._conn.sessionUpdate(
6241
SessionNotification(
6342
sessionId=session_id,
6443
update=AgentMessageChunk(
6544
sessionUpdate="agent_message_chunk",
66-
content=TextContentBlock(type="text", text=text),
45+
content=content,
6746
),
6847
)
6948
)
7049

71-
def _format_prompt_preview(self, blocks: list[Any]) -> str:
72-
parts: list[str] = []
73-
for block in blocks:
74-
if isinstance(block, dict):
75-
if block.get("type") == "text":
76-
parts.append(str(block.get("text", "")))
77-
else:
78-
parts.append(f"<{block.get('type', 'content')}>")
79-
else:
80-
parts.append(getattr(block, "text", "<content>"))
81-
preview = " \n".join(filter(None, parts)).strip()
82-
return preview or "<empty prompt>"
83-
84-
async def _request_permission(self, session_id: str, preview: str, state: SessionState) -> str:
85-
state.prompt_counter += 1
86-
request = RequestPermissionRequest(
87-
sessionId=session_id,
88-
toolCall=ToolCallUpdate(
89-
toolCallId=f"echo-{state.prompt_counter}",
90-
title="Echo input",
91-
kind="echo",
92-
status="pending",
93-
content=[
94-
ContentToolCallContent(
95-
type="content",
96-
content=TextContentBlock(type="text", text=preview),
97-
)
98-
],
50+
async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002
51+
logging.info("Received initialize request")
52+
return InitializeResponse(
53+
protocolVersion=PROTOCOL_VERSION,
54+
agentCapabilities=AgentCapabilities(
55+
loadSession=False,
56+
mcpCapabilities=McpCapabilities(http=False, sse=False),
57+
promptCapabilities=PromptCapabilities(audio=False, embeddedContext=False, image=False),
9958
),
100-
options=[
101-
PermissionOption(optionId="allow-once", name="Allow once", kind="allow_once"),
102-
PermissionOption(optionId="deny", name="Deny", kind="reject_once"),
103-
],
10459
)
10560

106-
permission_task = asyncio.create_task(self._conn.requestPermission(request))
107-
cancel_task = asyncio.create_task(state.cancel_event.wait())
108-
109-
done, pending = await asyncio.wait({permission_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
110-
111-
for task in pending:
112-
task.cancel()
113-
114-
if cancel_task in done:
115-
permission_task.cancel()
116-
return "cancelled"
117-
118-
try:
119-
response = await permission_task
120-
except asyncio.CancelledError:
121-
return "cancelled"
122-
except Exception as exc: # noqa: BLE001
123-
await self._send_text(session_id, f"Permission failed: {exc}")
124-
return "error"
125-
126-
if isinstance(response.outcome, AllowedOutcome):
127-
option_id = response.outcome.optionId
128-
if option_id.startswith("allow"):
129-
return "allowed"
130-
return "denied"
131-
return "cancelled"
132-
133-
async def initialize(self, params: InitializeRequest) -> InitializeResponse:
134-
return InitializeResponse(protocolVersion=PROTOCOL_VERSION, agentCapabilities=None, authMethods=[])
135-
13661
async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002
137-
return {}
62+
logging.info("Received authenticate request")
63+
return AuthenticateResponse()
13864

13965
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002
140-
session_id = f"sess-{self._next_session_id}"
66+
logging.info("Received new session request")
67+
session_id = str(self._next_session_id)
14168
self._next_session_id += 1
142-
self._sessions[session_id] = SessionState()
14369
return NewSessionResponse(sessionId=session_id)
14470

145-
async def loadSession(self, params): # type: ignore[override]
146-
return None
71+
async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002
72+
logging.info("Received load session request")
73+
return LoadSessionResponse()
14774

14875
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002
149-
return {}
76+
logging.info("Received set session mode request")
77+
return SetSessionModeResponse()
15078

15179
async def prompt(self, params: PromptRequest) -> PromptResponse:
152-
state = self._session(params.sessionId)
153-
state.begin_prompt()
154-
155-
preview = self._format_prompt_preview(list(params.prompt))
156-
await self._send_text(params.sessionId, "Agent received a prompt. Checking permissions...")
157-
158-
decision = await self._request_permission(params.sessionId, preview, state)
159-
if decision == "cancelled":
160-
await self._send_text(params.sessionId, "Prompt cancelled before permission decided.")
161-
return PromptResponse(stopReason="cancelled")
162-
if decision == "denied":
163-
await self._send_text(params.sessionId, "Permission denied by the client.")
164-
return PromptResponse(stopReason="permission_denied")
165-
if decision == "error":
166-
return PromptResponse(stopReason="error")
167-
168-
await self._send_text(params.sessionId, "Permission granted. Echoing content:")
80+
logging.info("Received prompt request")
16981

82+
# Notify the client what it just sent and then echo each content block back.
83+
await self._send_chunk(
84+
params.sessionId,
85+
TextContentBlock(type="text", text="Client sent:"),
86+
)
17087
for block in params.prompt:
171-
if state.cancel_event.is_set():
172-
await self._send_text(params.sessionId, "Prompt interrupted by cancellation.")
173-
return PromptResponse(stopReason="cancelled")
174-
text = self._format_prompt_preview([block])
175-
await self._send_text(params.sessionId, text)
176-
await asyncio.sleep(0.4)
88+
await self._send_chunk(params.sessionId, block)
17789

17890
return PromptResponse(stopReason="end_turn")
17991

18092
async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002
181-
state = self._sessions.get(params.sessionId)
182-
if state:
183-
state.cancel()
184-
await self._send_text(params.sessionId, "Agent received cancel signal.")
93+
logging.info("Received cancel notification")
18594

18695
async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002
96+
logging.info("Received extension method call: %s", method)
18797
return {"example": "response"}
18898

18999
async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002
190-
return None
100+
logging.info("Received extension notification: %s", method)
191101

192102

193103
async def main() -> None:
104+
logging.basicConfig(level=logging.INFO)
194105
reader, writer = await stdio_streams()
195-
AgentSideConnection(lambda conn: ExampleAgent(conn), writer, reader)
106+
AgentSideConnection(ExampleAgent, writer, reader)
196107
await asyncio.Event().wait()
197108

198109

0 commit comments

Comments
 (0)