@@ -22,6 +22,7 @@ class StoredTask:
2222 """Internal storage representation of a task."""
2323
2424 task : Task
25+ session_id : str | None = None
2526 result : Result | None = None
2627 # Time when this task should be removed (None = never)
2728 expires_at : datetime | None = field (default = None )
@@ -32,6 +33,7 @@ class InMemoryTaskStore(TaskStore):
3233
3334 Features:
3435 - Automatic TTL-based cleanup (lazy expiration)
36+ - Session isolation (tasks are scoped to their creating session)
3537 - Thread-safe for single-process async use
3638 - Pagination support for list_tasks
3739
@@ -66,10 +68,25 @@ def _cleanup_expired(self) -> None:
6668 for task_id in expired_ids :
6769 del self ._tasks [task_id ]
6870
71+ def _get_stored_task (self , task_id : str , session_id : str | None = None ) -> StoredTask | None :
72+ """Retrieve a stored task, enforcing session ownership when a session_id is provided.
73+
74+ Returns None if the task does not exist or belongs to a different session.
75+ When either the caller's session_id or the stored task's session_id is None,
76+ no filtering occurs (backward compatibility).
77+ """
78+ stored = self ._tasks .get (task_id )
79+ if stored is None :
80+ return None
81+ if session_id is not None and stored .session_id is not None and stored .session_id != session_id :
82+ return None
83+ return stored
84+
6985 async def create_task (
7086 self ,
7187 metadata : TaskMetadata ,
7288 task_id : str | None = None ,
89+ session_id : str | None = None ,
7390 ) -> Task :
7491 """Create a new task with the given metadata."""
7592 # Cleanup expired tasks on access
@@ -82,19 +99,20 @@ async def create_task(
8299
83100 stored = StoredTask (
84101 task = task ,
102+ session_id = session_id ,
85103 expires_at = self ._calculate_expiry (metadata .ttl ),
86104 )
87105 self ._tasks [task .task_id ] = stored
88106
89107 # Return a copy to prevent external modification
90108 return Task (** task .model_dump ())
91109
92- async def get_task (self , task_id : str ) -> Task | None :
110+ async def get_task (self , task_id : str , session_id : str | None = None ) -> Task | None :
93111 """Get a task by ID."""
94112 # Cleanup expired tasks on access
95113 self ._cleanup_expired ()
96114
97- stored = self ._tasks . get (task_id )
115+ stored = self ._get_stored_task (task_id , session_id )
98116 if stored is None :
99117 return None
100118
@@ -106,9 +124,10 @@ async def update_task(
106124 task_id : str ,
107125 status : TaskStatus | None = None ,
108126 status_message : str | None = None ,
127+ session_id : str | None = None ,
109128 ) -> Task :
110129 """Update a task's status and/or message."""
111- stored = self ._tasks . get (task_id )
130+ stored = self ._get_stored_task (task_id , session_id )
112131 if stored is None :
113132 raise ValueError (f"Task with ID { task_id } not found" )
114133
@@ -137,17 +156,17 @@ async def update_task(
137156
138157 return Task (** stored .task .model_dump ())
139158
140- async def store_result (self , task_id : str , result : Result ) -> None :
159+ async def store_result (self , task_id : str , result : Result , session_id : str | None = None ) -> None :
141160 """Store the result for a task."""
142- stored = self ._tasks . get (task_id )
161+ stored = self ._get_stored_task (task_id , session_id )
143162 if stored is None :
144163 raise ValueError (f"Task with ID { task_id } not found" )
145164
146165 stored .result = result
147166
148- async def get_result (self , task_id : str ) -> Result | None :
167+ async def get_result (self , task_id : str , session_id : str | None = None ) -> Result | None :
149168 """Get the stored result for a task."""
150- stored = self ._tasks . get (task_id )
169+ stored = self ._get_stored_task (task_id , session_id )
151170 if stored is None :
152171 return None
153172
@@ -156,34 +175,41 @@ async def get_result(self, task_id: str) -> Result | None:
156175 async def list_tasks (
157176 self ,
158177 cursor : str | None = None ,
178+ session_id : str | None = None ,
159179 ) -> tuple [list [Task ], str | None ]:
160180 """List tasks with pagination."""
161181 # Cleanup expired tasks on access
162182 self ._cleanup_expired ()
163183
164- all_task_ids = list (self ._tasks .keys ())
184+ # Filter tasks by session ownership before pagination
185+ filtered_task_ids = [
186+ task_id
187+ for task_id , stored in self ._tasks .items ()
188+ if session_id is None or stored .session_id is None or stored .session_id == session_id
189+ ]
165190
166191 start_index = 0
167192 if cursor is not None :
168193 try :
169- cursor_index = all_task_ids .index (cursor )
194+ cursor_index = filtered_task_ids .index (cursor )
170195 start_index = cursor_index + 1
171196 except ValueError :
172197 raise ValueError (f"Invalid cursor: { cursor } " )
173198
174- page_task_ids = all_task_ids [start_index : start_index + self ._page_size ]
199+ page_task_ids = filtered_task_ids [start_index : start_index + self ._page_size ]
175200 tasks = [Task (** self ._tasks [tid ].task .model_dump ()) for tid in page_task_ids ]
176201
177202 # Determine next cursor
178203 next_cursor = None
179- if start_index + self ._page_size < len (all_task_ids ) and page_task_ids :
204+ if start_index + self ._page_size < len (filtered_task_ids ) and page_task_ids :
180205 next_cursor = page_task_ids [- 1 ]
181206
182207 return tasks , next_cursor
183208
184- async def delete_task (self , task_id : str ) -> bool :
209+ async def delete_task (self , task_id : str , session_id : str | None = None ) -> bool :
185210 """Delete a task."""
186- if task_id not in self ._tasks :
211+ stored = self ._get_stored_task (task_id , session_id )
212+ if stored is None :
187213 return False
188214
189215 del self ._tasks [task_id ]
0 commit comments