Skip to content

Commit 797cca8

Browse files
committed
Fix the session binding logic for tasks.
1 parent 43d709c commit 797cca8

File tree

7 files changed

+233
-32
lines changed

7 files changed

+233
-32
lines changed

src/mcp/server/experimental/request_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ async def work(task: ServerTaskContext) -> CallToolResult:
187187
# Access task_group via TaskSupport - raises if not in run() context
188188
task_group = support.task_group
189189

190-
task = await support.store.create_task(self.task_metadata, task_id)
190+
session_id = str(id(self._session))
191+
task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id)
191192

192193
task_ctx = ServerTaskContext(
193194
task=task,

src/mcp/server/experimental/task_result_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def handle(
8080
request: GetTaskPayloadRequest,
8181
session: ServerSession,
8282
request_id: RequestId,
83+
session_id: str | None = None,
8384
) -> GetTaskPayloadResult:
8485
"""Handle a tasks/result request.
8586
@@ -94,22 +95,23 @@ async def handle(
9495
request: The GetTaskPayloadRequest
9596
session: The server session for sending messages
9697
request_id: The request ID for relatedRequestId routing
98+
session_id: Optional session identifier for access control.
9799
98100
Returns:
99101
GetTaskPayloadResult with the task's final payload
100102
"""
101103
task_id = request.params.task_id
102104

103105
while True:
104-
task = await self._store.get_task(task_id)
106+
task = await self._store.get_task(task_id, session_id=session_id)
105107
if task is None:
106108
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}")
107109

108110
await self._deliver_queued_messages(task_id, session, request_id)
109111

110112
# If task is terminal, return result
111113
if is_terminal(task.status):
112-
result = await self._store.get_result(task_id)
114+
result = await self._store.get_result(task_id, session_id=session_id)
113115
# GetTaskPayloadResult is a Result with extra="allow"
114116
# The stored result contains the actual payload data
115117
# Per spec: tasks/result MUST include _meta with related-task metadata

src/mcp/server/lowlevel/experimental.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ def enable_tasks(
153153
async def _default_get_task(
154154
ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams
155155
) -> GetTaskResult:
156-
task = await task_support.store.get_task(params.task_id)
156+
session_id = str(id(ctx.session))
157+
task = await task_support.store.get_task(params.task_id, session_id=session_id)
157158
if task is None:
158159
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}")
159160
return GetTaskResult(
@@ -174,8 +175,9 @@ async def _default_get_task_result(
174175
ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams
175176
) -> GetTaskPayloadResult:
176177
assert ctx.request_id is not None
178+
session_id = str(id(ctx.session))
177179
req = GetTaskPayloadRequest(params=params)
178-
result = await task_support.handler.handle(req, ctx.session, ctx.request_id)
180+
result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id)
179181
return result
180182

181183
self._add_request_handler("tasks/result", _default_get_task_result)
@@ -186,7 +188,8 @@ async def _default_list_tasks(
186188
ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None
187189
) -> ListTasksResult:
188190
cursor = params.cursor if params else None
189-
tasks, next_cursor = await task_support.store.list_tasks(cursor)
191+
session_id = str(id(ctx.session))
192+
tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id)
190193
return ListTasksResult(tasks=tasks, next_cursor=next_cursor)
191194

192195
self._add_request_handler("tasks/list", _default_list_tasks)
@@ -196,7 +199,8 @@ async def _default_list_tasks(
196199
async def _default_cancel_task(
197200
ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams
198201
) -> CancelTaskResult:
199-
result = await cancel_task(task_support.store, params.task_id)
202+
session_id = str(id(ctx.session))
203+
result = await cancel_task(task_support.store, params.task_id, session_id=session_id)
200204
return result
201205

202206
self._add_request_handler("tasks/cancel", _default_cancel_task)

src/mcp/shared/experimental/tasks/helpers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def is_terminal(status: TaskStatus) -> bool:
5050
async def cancel_task(
5151
store: TaskStore,
5252
task_id: str,
53+
session_id: str | None = None,
5354
) -> CancelTaskResult:
5455
"""Cancel a task with spec-compliant validation.
5556
@@ -62,28 +63,29 @@ async def cancel_task(
6263
Args:
6364
store: The task store
6465
task_id: The task identifier to cancel
66+
session_id: Optional session identifier for access control.
6567
6668
Returns:
6769
CancelTaskResult with the cancelled task state
6870
6971
Raises:
7072
MCPError: With INVALID_PARAMS (-32602) if:
71-
- Task does not exist
73+
- Task does not exist or is not accessible by this session
7274
- Task is already in a terminal state (completed, failed, cancelled)
7375
7476
Example:
7577
async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult:
7678
return await cancel_task(store, params.task_id)
7779
"""
78-
task = await store.get_task(task_id)
80+
task = await store.get_task(task_id, session_id=session_id)
7981
if task is None:
8082
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}")
8183

8284
if is_terminal(task.status):
8385
raise MCPError(code=INVALID_PARAMS, message=f"Cannot cancel task in terminal state '{task.status}'")
8486

8587
# Update task to cancelled status
86-
cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED)
88+
cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED, session_id=session_id)
8789
return CancelTaskResult(**cancelled_task.model_dump())
8890

8991

src/mcp/shared/experimental/tasks/in_memory_task_store.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class StoredTask:
2222
"""Internal storage representation of a task."""
2323

2424
task: Task
25+
session_id: str | None = None
2526
result: Result | None = None
2627
# Time when this task should be removed (None = never)
2728
expires_at: datetime | None = field(default=None)
@@ -32,6 +33,7 @@ class InMemoryTaskStore(TaskStore):
3233
3334
Features:
3435
- Automatic TTL-based cleanup (lazy expiration)
36+
- Session isolation (tasks are scoped to their creating session)
3537
- Thread-safe for single-process async use
3638
- Pagination support for list_tasks
3739
@@ -66,10 +68,25 @@ def _cleanup_expired(self) -> None:
6668
for task_id in expired_ids:
6769
del self._tasks[task_id]
6870

71+
def _get_stored_task(self, task_id: str, session_id: str | None = None) -> StoredTask | None:
72+
"""Retrieve a stored task, enforcing session ownership when a session_id is provided.
73+
74+
Returns None if the task does not exist or belongs to a different session.
75+
When either the caller's session_id or the stored task's session_id is None,
76+
no filtering occurs (backward compatibility).
77+
"""
78+
stored = self._tasks.get(task_id)
79+
if stored is None:
80+
return None
81+
if session_id is not None and stored.session_id is not None and stored.session_id != session_id:
82+
return None
83+
return stored
84+
6985
async def create_task(
7086
self,
7187
metadata: TaskMetadata,
7288
task_id: str | None = None,
89+
session_id: str | None = None,
7390
) -> Task:
7491
"""Create a new task with the given metadata."""
7592
# Cleanup expired tasks on access
@@ -82,19 +99,20 @@ async def create_task(
8299

83100
stored = StoredTask(
84101
task=task,
102+
session_id=session_id,
85103
expires_at=self._calculate_expiry(metadata.ttl),
86104
)
87105
self._tasks[task.task_id] = stored
88106

89107
# Return a copy to prevent external modification
90108
return Task(**task.model_dump())
91109

92-
async def get_task(self, task_id: str) -> Task | None:
110+
async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None:
93111
"""Get a task by ID."""
94112
# Cleanup expired tasks on access
95113
self._cleanup_expired()
96114

97-
stored = self._tasks.get(task_id)
115+
stored = self._get_stored_task(task_id, session_id)
98116
if stored is None:
99117
return None
100118

@@ -106,9 +124,10 @@ async def update_task(
106124
task_id: str,
107125
status: TaskStatus | None = None,
108126
status_message: str | None = None,
127+
session_id: str | None = None,
109128
) -> Task:
110129
"""Update a task's status and/or message."""
111-
stored = self._tasks.get(task_id)
130+
stored = self._get_stored_task(task_id, session_id)
112131
if stored is None:
113132
raise ValueError(f"Task with ID {task_id} not found")
114133

@@ -137,17 +156,17 @@ async def update_task(
137156

138157
return Task(**stored.task.model_dump())
139158

140-
async def store_result(self, task_id: str, result: Result) -> None:
159+
async def store_result(self, task_id: str, result: Result, session_id: str | None = None) -> None:
141160
"""Store the result for a task."""
142-
stored = self._tasks.get(task_id)
161+
stored = self._get_stored_task(task_id, session_id)
143162
if stored is None:
144163
raise ValueError(f"Task with ID {task_id} not found")
145164

146165
stored.result = result
147166

148-
async def get_result(self, task_id: str) -> Result | None:
167+
async def get_result(self, task_id: str, session_id: str | None = None) -> Result | None:
149168
"""Get the stored result for a task."""
150-
stored = self._tasks.get(task_id)
169+
stored = self._get_stored_task(task_id, session_id)
151170
if stored is None:
152171
return None
153172

@@ -156,34 +175,41 @@ async def get_result(self, task_id: str) -> Result | None:
156175
async def list_tasks(
157176
self,
158177
cursor: str | None = None,
178+
session_id: str | None = None,
159179
) -> tuple[list[Task], str | None]:
160180
"""List tasks with pagination."""
161181
# Cleanup expired tasks on access
162182
self._cleanup_expired()
163183

164-
all_task_ids = list(self._tasks.keys())
184+
# Filter tasks by session ownership before pagination
185+
filtered_task_ids = [
186+
task_id
187+
for task_id, stored in self._tasks.items()
188+
if session_id is None or stored.session_id is None or stored.session_id == session_id
189+
]
165190

166191
start_index = 0
167192
if cursor is not None:
168193
try:
169-
cursor_index = all_task_ids.index(cursor)
194+
cursor_index = filtered_task_ids.index(cursor)
170195
start_index = cursor_index + 1
171196
except ValueError:
172197
raise ValueError(f"Invalid cursor: {cursor}")
173198

174-
page_task_ids = all_task_ids[start_index : start_index + self._page_size]
199+
page_task_ids = filtered_task_ids[start_index : start_index + self._page_size]
175200
tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]
176201

177202
# Determine next cursor
178203
next_cursor = None
179-
if start_index + self._page_size < len(all_task_ids) and page_task_ids:
204+
if start_index + self._page_size < len(filtered_task_ids) and page_task_ids:
180205
next_cursor = page_task_ids[-1]
181206

182207
return tasks, next_cursor
183208

184-
async def delete_task(self, task_id: str) -> bool:
209+
async def delete_task(self, task_id: str, session_id: str | None = None) -> bool:
185210
"""Delete a task."""
186-
if task_id not in self._tasks:
211+
stored = self._get_stored_task(task_id, session_id)
212+
if stored is None:
187213
return False
188214

189215
del self._tasks[task_id]

src/mcp/shared/experimental/tasks/store.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@ async def create_task(
1919
self,
2020
metadata: TaskMetadata,
2121
task_id: str | None = None,
22+
session_id: str | None = None,
2223
) -> Task:
2324
"""Create a new task.
2425
2526
Args:
2627
metadata: Task metadata (ttl, etc.)
2728
task_id: Optional task ID. If None, implementation should generate one.
29+
session_id: Optional session identifier. When provided, the task is
30+
bound to this session for isolation purposes.
2831
2932
Returns:
3033
The created Task with status="working"
@@ -34,14 +37,15 @@ async def create_task(
3437
"""
3538

3639
@abstractmethod
37-
async def get_task(self, task_id: str) -> Task | None:
40+
async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None:
3841
"""Get a task by ID.
3942
4043
Args:
4144
task_id: The task identifier
45+
session_id: Optional session identifier for access control.
4246
4347
Returns:
44-
The Task, or None if not found
48+
The Task, or None if not found or not accessible by this session.
4549
"""
4650

4751
@abstractmethod
@@ -50,70 +54,78 @@ async def update_task(
5054
task_id: str,
5155
status: TaskStatus | None = None,
5256
status_message: str | None = None,
57+
session_id: str | None = None,
5358
) -> Task:
5459
"""Update a task's status and/or message.
5560
5661
Args:
5762
task_id: The task identifier
5863
status: New status (if changing)
5964
status_message: New status message (if changing)
65+
session_id: Optional session identifier for access control.
6066
6167
Returns:
6268
The updated Task
6369
6470
Raises:
65-
ValueError: If task not found
71+
ValueError: If task not found or not accessible by this session.
6672
ValueError: If attempting to transition from a terminal status
6773
(completed, failed, cancelled). Per spec, terminal states
6874
MUST NOT transition to any other status.
6975
"""
7076

7177
@abstractmethod
72-
async def store_result(self, task_id: str, result: Result) -> None:
78+
async def store_result(self, task_id: str, result: Result, session_id: str | None = None) -> None:
7379
"""Store the result for a task.
7480
7581
Args:
7682
task_id: The task identifier
7783
result: The result to store
84+
session_id: Optional session identifier for access control.
7885
7986
Raises:
80-
ValueError: If task not found
87+
ValueError: If task not found or not accessible by this session.
8188
"""
8289

8390
@abstractmethod
84-
async def get_result(self, task_id: str) -> Result | None:
91+
async def get_result(self, task_id: str, session_id: str | None = None) -> Result | None:
8592
"""Get the stored result for a task.
8693
8794
Args:
8895
task_id: The task identifier
96+
session_id: Optional session identifier for access control.
8997
9098
Returns:
91-
The stored Result, or None if not available
99+
The stored Result, or None if not available.
92100
"""
93101

94102
@abstractmethod
95103
async def list_tasks(
96104
self,
97105
cursor: str | None = None,
106+
session_id: str | None = None,
98107
) -> tuple[list[Task], str | None]:
99108
"""List tasks with pagination.
100109
101110
Args:
102111
cursor: Optional cursor for pagination
112+
session_id: Optional session identifier. When provided, only tasks
113+
belonging to this session are returned.
103114
104115
Returns:
105116
Tuple of (tasks, next_cursor). next_cursor is None if no more pages.
106117
"""
107118

108119
@abstractmethod
109-
async def delete_task(self, task_id: str) -> bool:
120+
async def delete_task(self, task_id: str, session_id: str | None = None) -> bool:
110121
"""Delete a task.
111122
112123
Args:
113124
task_id: The task identifier
125+
session_id: Optional session identifier for access control.
114126
115127
Returns:
116-
True if deleted, False if not found
128+
True if deleted, False if not found or not accessible by this session.
117129
"""
118130

119131
@abstractmethod

0 commit comments

Comments
 (0)