Skip to content

Commit 0e504fa

Browse files
committed
Address feedback
🏠 Remote-Dev: homespace
1 parent 51372a3 commit 0e504fa

File tree

20 files changed

+322
-284
lines changed

20 files changed

+322
-284
lines changed

src/mcp/client/_memory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import uuid
56
from collections.abc import AsyncIterator
67
from contextlib import AbstractAsyncContextManager, asynccontextmanager
78
from types import TracebackType
@@ -50,12 +51,14 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:
5051

5152
async with anyio.create_task_group() as tg:
5253
# Start server in background
54+
memory_session_id = uuid.uuid4().hex
5355
tg.start_soon(
5456
lambda: actual_server.run(
5557
server_read,
5658
server_write,
5759
actual_server.create_initialization_options(),
5860
raise_exceptions=self._raise_exceptions,
61+
session_id=memory_session_id,
5962
)
6063
)
6164

src/mcp/server/experimental/request_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ async def work(task: ServerTaskContext) -> CallToolResult:
188188
task_group = support.task_group
189189

190190
session_id = self._session.session_id
191+
if session_id is None:
192+
raise RuntimeError("Session ID is required for task operations but session has no ID.")
191193
task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id)
192194

193195
task_ctx = ServerTaskContext(

src/mcp/server/experimental/task_context.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,11 @@ def __init__(
8888
queue: The message queue for elicitation/sampling
8989
handler: The result handler for response routing (required for elicit/create_message)
9090
"""
91-
self._ctx = TaskContext(task=task, store=store)
91+
session_id = session.session_id
92+
if session_id is None:
93+
raise RuntimeError("Session ID is required for task operations but session has no ID.")
94+
self._session_id = session_id
95+
self._ctx = TaskContext(task=task, store=store, session_id=session_id)
9296
self._session = session
9397
self._queue = queue
9498
self._handler = handler
@@ -210,7 +214,7 @@ async def elicit(
210214
raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.")
211215

212216
# Update status to input_required
213-
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
217+
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)
214218

215219
# Build the request using session's helper
216220
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
@@ -234,12 +238,12 @@ async def elicit(
234238
try:
235239
# Wait for response (routed back via TaskResultHandler)
236240
response_data = await resolver.wait()
237-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
241+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
238242
return ElicitResult.model_validate(response_data)
239243
except anyio.get_cancelled_exc_class():
240244
# This path is tested in test_elicit_restores_status_on_cancellation
241245
# which verifies status is restored to "working" after cancellation.
242-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
246+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
243247
raise
244248

245249
async def elicit_url(
@@ -279,7 +283,7 @@ async def elicit_url(
279283
raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")
280284

281285
# Update status to input_required
282-
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
286+
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)
283287

284288
# Build the request using session's helper
285289
request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage]
@@ -304,10 +308,10 @@ async def elicit_url(
304308
try:
305309
# Wait for response (routed back via TaskResultHandler)
306310
response_data = await resolver.wait()
307-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
311+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
308312
return ElicitResult.model_validate(response_data)
309313
except anyio.get_cancelled_exc_class(): # pragma: no cover
310-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
314+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
311315
raise
312316

313317
async def create_message(
@@ -362,7 +366,7 @@ async def create_message(
362366
raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")
363367

364368
# Update status to input_required
365-
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
369+
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)
366370

367371
# Build the request using session's helper
368372
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
@@ -394,12 +398,12 @@ async def create_message(
394398
try:
395399
# Wait for response (routed back via TaskResultHandler)
396400
response_data = await resolver.wait()
397-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
401+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
398402
return CreateMessageResult.model_validate(response_data)
399403
except anyio.get_cancelled_exc_class():
400404
# This path is tested in test_create_message_restores_status_on_cancellation
401405
# which verifies status is restored to "working" after cancellation.
402-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
406+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
403407
raise
404408

405409
async def elicit_as_task(
@@ -435,7 +439,7 @@ async def elicit_as_task(
435439
raise RuntimeError("handler is required for elicit_as_task()")
436440

437441
# Update status to input_required
438-
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
442+
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)
439443

440444
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
441445
message=message,
@@ -472,11 +476,11 @@ async def elicit_as_task(
472476
ElicitResult,
473477
)
474478

475-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
479+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
476480
return result
477481

478482
except anyio.get_cancelled_exc_class(): # pragma: no cover
479-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
483+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
480484
raise
481485

482486
async def create_message_as_task(
@@ -531,7 +535,7 @@ async def create_message_as_task(
531535
raise RuntimeError("handler is required for create_message_as_task()")
532536

533537
# Update status to input_required
534-
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
538+
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)
535539

536540
# Build request WITH task field for task-augmented sampling
537541
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
@@ -577,9 +581,9 @@ async def create_message_as_task(
577581
CreateMessageResult,
578582
)
579583

580-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
584+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
581585
return result
582586

583587
except anyio.get_cancelled_exc_class(): # pragma: no cover
584-
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
588+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
585589
raise

src/mcp/server/experimental/task_result_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def handle(
8080
request: GetTaskPayloadRequest,
8181
session: ServerSession,
8282
request_id: RequestId,
83-
session_id: str | None = None,
83+
session_id: str,
8484
) -> GetTaskPayloadResult:
8585
"""Handle a tasks/result request.
8686
@@ -95,7 +95,7 @@ async def handle(
9595
request: The GetTaskPayloadRequest
9696
session: The server session for sending messages
9797
request_id: The request ID for relatedRequestId routing
98-
session_id: Optional session identifier for access control.
98+
session_id: Session identifier for access control.
9999
100100
Returns:
101101
GetTaskPayloadResult with the task's final payload

src/mcp/server/lowlevel/experimental.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,22 @@ def enable_tasks(
147147
if on_cancel_task is not None:
148148
self._add_request_handler("tasks/cancel", on_cancel_task)
149149

150+
def _require_session_id(ctx: ServerRequestContext[LifespanResultT]) -> str:
151+
session_id = ctx.session.session_id
152+
if session_id is None:
153+
raise MCPError(
154+
code=INVALID_PARAMS,
155+
message="Session ID is required for task operations.",
156+
)
157+
return session_id
158+
150159
# Fill in defaults for any not provided
151160
if not self._has_handler("tasks/get"):
152161

153162
async def _default_get_task(
154163
ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams
155164
) -> GetTaskResult:
156-
session_id = ctx.session.session_id
165+
session_id = _require_session_id(ctx)
157166
task = await task_support.store.get_task(params.task_id, session_id=session_id)
158167
if task is None:
159168
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}")
@@ -175,7 +184,7 @@ async def _default_get_task_result(
175184
ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams
176185
) -> GetTaskPayloadResult:
177186
assert ctx.request_id is not None
178-
session_id = ctx.session.session_id
187+
session_id = _require_session_id(ctx)
179188
req = GetTaskPayloadRequest(params=params)
180189
result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id)
181190
return result
@@ -188,7 +197,7 @@ async def _default_list_tasks(
188197
ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None
189198
) -> ListTasksResult:
190199
cursor = params.cursor if params else None
191-
session_id = ctx.session.session_id
200+
session_id = _require_session_id(ctx)
192201
tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id)
193202
return ListTasksResult(tasks=tasks, next_cursor=next_cursor)
194203

@@ -199,7 +208,7 @@ async def _default_list_tasks(
199208
async def _default_cancel_task(
200209
ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams
201210
) -> CancelTaskResult:
202-
session_id = ctx.session.session_id
211+
session_id = _require_session_id(ctx)
203212
result = await cancel_task(task_support.store, params.task_id, session_id=session_id)
204213
return result
205214

src/mcp/server/lowlevel/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,6 @@ async def run(
374374
# the initialization lifecycle, but can do so with any available node
375375
# rather than requiring initialization for each connection.
376376
stateless: bool = False,
377-
# Optional session identifier for task isolation. When provided (e.g.,
378-
# from the transport's mcp_session_id), tasks are bound to this ID.
379377
session_id: str | None = None,
380378
):
381379
async with AsyncExitStack() as stack:

src/mcp/server/session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(
8383
write_stream: MemoryObjectSendStream[SessionMessage],
8484
init_options: InitializationOptions,
8585
stateless: bool = False,
86+
*,
8687
session_id: str | None = None,
8788
) -> None:
8889
super().__init__(read_stream, write_stream)

src/mcp/shared/experimental/tasks/context.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,20 @@ class TaskContext:
2121
use ServerTaskContext from mcp.server.experimental.
2222
2323
Example (distributed worker):
24-
async def worker_job(task_id: str):
24+
async def worker_job(task_id: str, session_id: str):
2525
store = RedisTaskStore(redis_url)
26-
task = await store.get_task(task_id)
27-
ctx = TaskContext(task=task, store=store)
26+
task = await store.get_task(task_id, session_id=session_id)
27+
ctx = TaskContext(task=task, store=store, session_id=session_id)
2828
2929
await ctx.update_status("Working...")
3030
result = await do_work()
3131
await ctx.complete(result)
3232
"""
3333

34-
def __init__(self, task: Task, store: TaskStore):
34+
def __init__(self, task: Task, store: TaskStore, *, session_id: str):
3535
self._task = task
3636
self._store = store
37+
self._session_id = session_id
3738
self._cancelled = False
3839

3940
@property
@@ -68,6 +69,7 @@ async def update_status(self, message: str) -> None:
6869
self._task = await self._store.update_task(
6970
self.task_id,
7071
status_message=message,
72+
session_id=self._session_id,
7173
)
7274

7375
async def complete(self, result: Result) -> None:
@@ -76,10 +78,11 @@ async def complete(self, result: Result) -> None:
7678
Args:
7779
result: The task result
7880
"""
79-
await self._store.store_result(self.task_id, result)
81+
await self._store.store_result(self.task_id, result, session_id=self._session_id)
8082
self._task = await self._store.update_task(
8183
self.task_id,
8284
status=TASK_STATUS_COMPLETED,
85+
session_id=self._session_id,
8386
)
8487

8588
async def fail(self, error: str) -> None:
@@ -92,4 +95,5 @@ async def fail(self, error: str) -> None:
9295
self.task_id,
9396
status=TASK_STATUS_FAILED,
9497
status_message=error,
98+
session_id=self._session_id,
9599
)

src/mcp/shared/experimental/tasks/helpers.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def is_terminal(status: TaskStatus) -> bool:
5050
async def cancel_task(
5151
store: TaskStore,
5252
task_id: str,
53-
session_id: str | None = None,
53+
*,
54+
session_id: str,
5455
) -> CancelTaskResult:
5556
"""Cancel a task with spec-compliant validation.
5657
@@ -63,7 +64,7 @@ async def cancel_task(
6364
Args:
6465
store: The task store
6566
task_id: The task identifier to cancel
66-
session_id: Optional session identifier for access control.
67+
session_id: Session identifier for access control.
6768
6869
Returns:
6970
CancelTaskResult with the cancelled task state
@@ -75,7 +76,7 @@ async def cancel_task(
7576
7677
Example:
7778
async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult:
78-
return await cancel_task(store, params.task_id)
79+
return await cancel_task(store, params.task_id, session_id=ctx.session.session_id)
7980
"""
8081
task = await store.get_task(task_id, session_id=session_id)
8182
if task is None:
@@ -124,6 +125,8 @@ def create_task_state(
124125
async def task_execution(
125126
task_id: str,
126127
store: TaskStore,
128+
*,
129+
session_id: str,
127130
) -> AsyncIterator[TaskContext]:
128131
"""Context manager for safe task execution (pure, no server dependencies).
129132
@@ -136,6 +139,7 @@ async def task_execution(
136139
Args:
137140
task_id: The task identifier to execute
138141
store: The task store (must be accessible by the worker)
142+
session_id: Session identifier for access control.
139143
140144
Yields:
141145
TaskContext for updating status and completing/failing the task
@@ -144,18 +148,18 @@ async def task_execution(
144148
ValueError: If the task is not found in the store
145149
146150
Example (distributed worker):
147-
async def worker_process(task_id: str):
151+
async def worker_process(task_id: str, session_id: str):
148152
store = RedisTaskStore(redis_url)
149-
async with task_execution(task_id, store) as ctx:
153+
async with task_execution(task_id, store, session_id=session_id) as ctx:
150154
await ctx.update_status("Working...")
151155
result = await do_work()
152156
await ctx.complete(result)
153157
"""
154-
task = await store.get_task(task_id)
158+
task = await store.get_task(task_id, session_id=session_id)
155159
if task is None:
156160
raise ValueError(f"Task {task_id} not found")
157161

158-
ctx = TaskContext(task, store)
162+
ctx = TaskContext(task, store, session_id=session_id)
159163
try:
160164
yield ctx
161165
except Exception as e:

0 commit comments

Comments
 (0)