77
88import logging
99from collections .abc import Awaitable , Callable
10- from typing import Any
10+ from typing import Any , Generic
11+
12+ from typing_extensions import TypeVar
1113
1214from mcp .server .context import ServerRequestContext
1315from mcp .server .experimental .task_support import TaskSupport
3840
3941logger = logging .getLogger (__name__ )
4042
43+ LifespanResultT = TypeVar ("LifespanResultT" , default = Any )
44+
4145
42- class ExperimentalHandlers :
46+ class ExperimentalHandlers ( Generic [ LifespanResultT ]) :
4347 """Experimental request/notification handlers.
4448
4549 WARNING: These APIs are experimental and may change without notice.
4650 """
4751
4852 def __init__ (
4953 self ,
50- add_request_handler : Callable [[str , Callable [[ServerRequestContext [Any ], Any ], Awaitable [Any ]]], None ],
54+ add_request_handler : Callable [[str , Callable [[ServerRequestContext [LifespanResultT ], Any ], Awaitable [Any ]]], None ],
5155 has_handler : Callable [[str ], bool ],
5256 ) -> None :
5357 self ._add_request_handler = add_request_handler
@@ -79,15 +83,15 @@ def enable_tasks(
7983 store : TaskStore | None = None ,
8084 queue : TaskMessageQueue | None = None ,
8185 * ,
82- on_get_task : Callable [[ServerRequestContext [Any ], GetTaskRequestParams ], Awaitable [GetTaskResult ]]
86+ on_get_task : Callable [[ServerRequestContext [LifespanResultT ], GetTaskRequestParams ], Awaitable [GetTaskResult ]]
8387 | None = None ,
8488 on_task_result : Callable [
85- [ServerRequestContext [Any ], GetTaskPayloadRequestParams ], Awaitable [GetTaskPayloadResult ]
89+ [ServerRequestContext [LifespanResultT ], GetTaskPayloadRequestParams ], Awaitable [GetTaskPayloadResult ]
8690 ]
8791 | None = None ,
88- on_list_tasks : Callable [[ServerRequestContext [Any ], PaginatedRequestParams | None ], Awaitable [ListTasksResult ]]
92+ on_list_tasks : Callable [[ServerRequestContext [LifespanResultT ], PaginatedRequestParams | None ], Awaitable [ListTasksResult ]]
8993 | None = None ,
90- on_cancel_task : Callable [[ServerRequestContext [Any ], CancelTaskRequestParams ], Awaitable [CancelTaskResult ]]
94+ on_cancel_task : Callable [[ServerRequestContext [LifespanResultT ], CancelTaskRequestParams ], Awaitable [CancelTaskResult ]]
9195 | None = None ,
9296 ) -> TaskSupport :
9397 """Enable experimental task support.
@@ -139,7 +143,7 @@ def enable_tasks(
139143 # Fill in defaults for any not provided
140144 if not self ._has_handler ("tasks/get" ):
141145
142- async def _default_get_task (ctx : ServerRequestContext [Any ], params : GetTaskRequestParams ) -> GetTaskResult :
146+ async def _default_get_task (ctx : ServerRequestContext [LifespanResultT ], params : GetTaskRequestParams ) -> GetTaskResult :
143147 task = await self ._task_support .store .get_task (params .task_id )
144148 if task is None :
145149 raise MCPError (code = INVALID_PARAMS , message = f"Task not found: { params .task_id } " )
@@ -158,7 +162,7 @@ async def _default_get_task(ctx: ServerRequestContext[Any], params: GetTaskReque
158162 if not self ._has_handler ("tasks/result" ):
159163
160164 async def _default_get_task_result (
161- ctx : ServerRequestContext [Any ], params : GetTaskPayloadRequestParams
165+ ctx : ServerRequestContext [LifespanResultT ], params : GetTaskPayloadRequestParams
162166 ) -> GetTaskPayloadResult :
163167 assert ctx .request_id is not None
164168 req = GetTaskPayloadRequest (params = params )
@@ -170,7 +174,7 @@ async def _default_get_task_result(
170174 if not self ._has_handler ("tasks/list" ):
171175
172176 async def _default_list_tasks (
173- ctx : ServerRequestContext [Any ], params : PaginatedRequestParams | None
177+ ctx : ServerRequestContext [LifespanResultT ], params : PaginatedRequestParams | None
174178 ) -> ListTasksResult :
175179 cursor = params .cursor if params else None
176180 tasks , next_cursor = await self ._task_support .store .list_tasks (cursor )
@@ -181,7 +185,7 @@ async def _default_list_tasks(
181185 if not self ._has_handler ("tasks/cancel" ):
182186
183187 async def _default_cancel_task (
184- ctx : ServerRequestContext [Any ], params : CancelTaskRequestParams
188+ ctx : ServerRequestContext [LifespanResultT ], params : CancelTaskRequestParams
185189 ) -> CancelTaskResult :
186190 result = await cancel_task (self ._task_support .store , params .task_id )
187191 return result
0 commit comments