Skip to content

Commit b466beb

Browse files
committed
refactor: make ExperimentalHandlers generic on LifespanResultT
1 parent 9b77527 commit b466beb

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

src/mcp/server/lowlevel/experimental.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
import logging
99
from collections.abc import Awaitable, Callable
10-
from typing import Any
10+
from typing import Any, Generic
11+
12+
from typing_extensions import TypeVar
1113

1214
from mcp.server.context import ServerRequestContext
1315
from mcp.server.experimental.task_support import TaskSupport
@@ -38,16 +40,18 @@
3840

3941
logger = logging.getLogger(__name__)
4042

43+
LifespanResultT = TypeVar("LifespanResultT", default=Any)
44+
4145

42-
class ExperimentalHandlers:
46+
class ExperimentalHandlers(Generic[LifespanResultT]):
4347
"""Experimental request/notification handlers.
4448
4549
WARNING: These APIs are experimental and may change without notice.
4650
"""
4751

4852
def __init__(
4953
self,
50-
add_request_handler: Callable[[str, Callable[[ServerRequestContext[Any], Any], Awaitable[Any]]], None],
54+
add_request_handler: Callable[[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None],
5155
has_handler: Callable[[str], bool],
5256
) -> None:
5357
self._add_request_handler = add_request_handler
@@ -79,15 +83,15 @@ def enable_tasks(
7983
store: TaskStore | None = None,
8084
queue: TaskMessageQueue | None = None,
8185
*,
82-
on_get_task: Callable[[ServerRequestContext[Any], GetTaskRequestParams], Awaitable[GetTaskResult]]
86+
on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]]
8387
| None = None,
8488
on_task_result: Callable[
85-
[ServerRequestContext[Any], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult]
89+
[ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult]
8690
]
8791
| None = None,
88-
on_list_tasks: Callable[[ServerRequestContext[Any], PaginatedRequestParams | None], Awaitable[ListTasksResult]]
92+
on_list_tasks: Callable[[ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult]]
8993
| None = None,
90-
on_cancel_task: Callable[[ServerRequestContext[Any], CancelTaskRequestParams], Awaitable[CancelTaskResult]]
94+
on_cancel_task: Callable[[ServerRequestContext[LifespanResultT], CancelTaskRequestParams], Awaitable[CancelTaskResult]]
9195
| None = None,
9296
) -> TaskSupport:
9397
"""Enable experimental task support.
@@ -139,7 +143,7 @@ def enable_tasks(
139143
# Fill in defaults for any not provided
140144
if not self._has_handler("tasks/get"):
141145

142-
async def _default_get_task(ctx: ServerRequestContext[Any], params: GetTaskRequestParams) -> GetTaskResult:
146+
async def _default_get_task(ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams) -> GetTaskResult:
143147
task = await self._task_support.store.get_task(params.task_id)
144148
if task is None:
145149
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}")
@@ -158,7 +162,7 @@ async def _default_get_task(ctx: ServerRequestContext[Any], params: GetTaskReque
158162
if not self._has_handler("tasks/result"):
159163

160164
async def _default_get_task_result(
161-
ctx: ServerRequestContext[Any], params: GetTaskPayloadRequestParams
165+
ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams
162166
) -> GetTaskPayloadResult:
163167
assert ctx.request_id is not None
164168
req = GetTaskPayloadRequest(params=params)
@@ -170,7 +174,7 @@ async def _default_get_task_result(
170174
if not self._has_handler("tasks/list"):
171175

172176
async def _default_list_tasks(
173-
ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None
177+
ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None
174178
) -> ListTasksResult:
175179
cursor = params.cursor if params else None
176180
tasks, next_cursor = await self._task_support.store.list_tasks(cursor)
@@ -181,7 +185,7 @@ async def _default_list_tasks(
181185
if not self._has_handler("tasks/cancel"):
182186

183187
async def _default_cancel_task(
184-
ctx: ServerRequestContext[Any], params: CancelTaskRequestParams
188+
ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams
185189
) -> CancelTaskResult:
186190
result = await cancel_task(self._task_support.store, params.task_id)
187191
return result

src/mcp/server/lowlevel/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def __init__(
197197
self._notification_handlers: dict[
198198
str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]]
199199
] = {}
200-
self._experimental_handlers: ExperimentalHandlers | None = None
200+
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
201201
self._session_manager: StreamableHTTPSessionManager | None = None
202202
logger.debug("Initializing server %r", name)
203203

@@ -339,7 +339,7 @@ def get_capabilities(
339339
return capabilities
340340

341341
@property
342-
def experimental(self) -> ExperimentalHandlers:
342+
def experimental(self) -> ExperimentalHandlers[LifespanResultT]:
343343
"""Experimental APIs for tasks and other features.
344344
345345
WARNING: These APIs are experimental and may change without notice.

0 commit comments

Comments
 (0)