1414from __future__ import annotations
1515
1616from dataclasses import dataclass , field
17- from typing import TYPE_CHECKING , Protocol
17+ from typing import Generic , Protocol
1818
1919from pydantic import TypeAdapter
2020
2121from mcp import types
22+ from mcp .client .base_client_session import BaseClientSession , ClientSessionT_contra
2223from mcp .shared ._context import RequestContext
2324from mcp .shared .session import RequestResponder
2425
25- if TYPE_CHECKING :
26- from mcp .client .session import ClientSession
2726
28-
29- class GetTaskHandlerFnT (Protocol ):
27+ class GetTaskHandlerFnT (Protocol [ClientSessionT_contra ]):
3028 """Handler for tasks/get requests from server.
3129
3230 WARNING: This is experimental and may change without notice.
3331 """
3432
3533 async def __call__ (
3634 self ,
37- context : RequestContext [ClientSession ],
35+ context : RequestContext [ClientSessionT_contra ],
3836 params : types .GetTaskRequestParams ,
3937 ) -> types .GetTaskResult | types .ErrorData : ... # pragma: no branch
4038
4139
42- class GetTaskResultHandlerFnT (Protocol ):
40+ class GetTaskResultHandlerFnT (Protocol [ ClientSessionT_contra ] ):
4341 """Handler for tasks/result requests from server.
4442
4543 WARNING: This is experimental and may change without notice.
4644 """
4745
4846 async def __call__ (
4947 self ,
50- context : RequestContext [ClientSession ],
48+ context : RequestContext [ClientSessionT_contra ],
5149 params : types .GetTaskPayloadRequestParams ,
5250 ) -> types .GetTaskPayloadResult | types .ErrorData : ... # pragma: no branch
5351
5452
55- class ListTasksHandlerFnT (Protocol ):
53+ class ListTasksHandlerFnT (Protocol [ ClientSessionT_contra ] ):
5654 """Handler for tasks/list requests from server.
5755
5856 WARNING: This is experimental and may change without notice.
5957 """
6058
6159 async def __call__ (
6260 self ,
63- context : RequestContext [ClientSession ],
61+ context : RequestContext [ClientSessionT_contra ],
6462 params : types .PaginatedRequestParams | None ,
6563 ) -> types .ListTasksResult | types .ErrorData : ... # pragma: no branch
6664
6765
68- class CancelTaskHandlerFnT (Protocol ):
66+ class CancelTaskHandlerFnT (Protocol [ ClientSessionT_contra ] ):
6967 """Handler for tasks/cancel requests from server.
7068
7169 WARNING: This is experimental and may change without notice.
7270 """
7371
7472 async def __call__ (
7573 self ,
76- context : RequestContext [ClientSession ],
74+ context : RequestContext [ClientSessionT_contra ],
7775 params : types .CancelTaskRequestParams ,
7876 ) -> types .CancelTaskResult | types .ErrorData : ... # pragma: no branch
7977
8078
81- class TaskAugmentedSamplingFnT (Protocol ):
79+ class TaskAugmentedSamplingFnT (Protocol [ ClientSessionT_contra ] ):
8280 """Handler for task-augmented sampling/createMessage requests from server.
8381
8482 When server sends a CreateMessageRequest with task field, this callback
@@ -90,13 +88,13 @@ class TaskAugmentedSamplingFnT(Protocol):
9088
9189 async def __call__ (
9290 self ,
93- context : RequestContext [ClientSession ],
91+ context : RequestContext [ClientSessionT_contra ],
9492 params : types .CreateMessageRequestParams ,
9593 task_metadata : types .TaskMetadata ,
9694 ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
9795
9896
99- class TaskAugmentedElicitationFnT (Protocol ):
97+ class TaskAugmentedElicitationFnT (Protocol [ ClientSessionT_contra ] ):
10098 """Handler for task-augmented elicitation/create requests from server.
10199
102100 When server sends an ElicitRequest with task field, this callback
@@ -108,14 +106,14 @@ class TaskAugmentedElicitationFnT(Protocol):
108106
109107 async def __call__ (
110108 self ,
111- context : RequestContext [ClientSession ],
109+ context : RequestContext [ClientSessionT_contra ],
112110 params : types .ElicitRequestParams ,
113111 task_metadata : types .TaskMetadata ,
114112 ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
115113
116114
117115async def default_get_task_handler (
118- context : RequestContext [ClientSession ],
116+ context : RequestContext [BaseClientSession ],
119117 params : types .GetTaskRequestParams ,
120118) -> types .GetTaskResult | types .ErrorData :
121119 return types .ErrorData (
@@ -125,7 +123,7 @@ async def default_get_task_handler(
125123
126124
127125async def default_get_task_result_handler (
128- context : RequestContext [ClientSession ],
126+ context : RequestContext [BaseClientSession ],
129127 params : types .GetTaskPayloadRequestParams ,
130128) -> types .GetTaskPayloadResult | types .ErrorData :
131129 return types .ErrorData (
@@ -135,7 +133,7 @@ async def default_get_task_result_handler(
135133
136134
137135async def default_list_tasks_handler (
138- context : RequestContext [ClientSession ],
136+ context : RequestContext [BaseClientSession ],
139137 params : types .PaginatedRequestParams | None ,
140138) -> types .ListTasksResult | types .ErrorData :
141139 return types .ErrorData (
@@ -145,7 +143,7 @@ async def default_list_tasks_handler(
145143
146144
147145async def default_cancel_task_handler (
148- context : RequestContext [ClientSession ],
146+ context : RequestContext [BaseClientSession ],
149147 params : types .CancelTaskRequestParams ,
150148) -> types .CancelTaskResult | types .ErrorData :
151149 return types .ErrorData (
@@ -155,7 +153,7 @@ async def default_cancel_task_handler(
155153
156154
157155async def default_task_augmented_sampling (
158- context : RequestContext [ClientSession ],
156+ context : RequestContext [BaseClientSession ],
159157 params : types .CreateMessageRequestParams ,
160158 task_metadata : types .TaskMetadata ,
161159) -> types .CreateTaskResult | types .ErrorData :
@@ -166,7 +164,7 @@ async def default_task_augmented_sampling(
166164
167165
168166async def default_task_augmented_elicitation (
169- context : RequestContext [ClientSession ],
167+ context : RequestContext [BaseClientSession ],
170168 params : types .ElicitRequestParams ,
171169 task_metadata : types .TaskMetadata ,
172170) -> types .CreateTaskResult | types .ErrorData :
@@ -177,7 +175,7 @@ async def default_task_augmented_elicitation(
177175
178176
179177@dataclass
180- class ExperimentalTaskHandlers :
178+ class ExperimentalTaskHandlers ( Generic [ ClientSessionT_contra ]) :
181179 """Container for experimental task handlers.
182180
183181 Groups all task-related handlers that handle server -> client requests.
@@ -195,14 +193,16 @@ class ExperimentalTaskHandlers:
195193 """
196194
197195 # Pure task request handlers
198- get_task : GetTaskHandlerFnT = field (default = default_get_task_handler )
199- get_task_result : GetTaskResultHandlerFnT = field (default = default_get_task_result_handler )
200- list_tasks : ListTasksHandlerFnT = field (default = default_list_tasks_handler )
201- cancel_task : CancelTaskHandlerFnT = field (default = default_cancel_task_handler )
196+ get_task : GetTaskHandlerFnT [ ClientSessionT_contra ] = field (default = default_get_task_handler )
197+ get_task_result : GetTaskResultHandlerFnT [ ClientSessionT_contra ] = field (default = default_get_task_result_handler )
198+ list_tasks : ListTasksHandlerFnT [ ClientSessionT_contra ] = field (default = default_list_tasks_handler )
199+ cancel_task : CancelTaskHandlerFnT [ ClientSessionT_contra ] = field (default = default_cancel_task_handler )
202200
203201 # Task-augmented request handlers
204- augmented_sampling : TaskAugmentedSamplingFnT = field (default = default_task_augmented_sampling )
205- augmented_elicitation : TaskAugmentedElicitationFnT = field (default = default_task_augmented_elicitation )
202+ augmented_sampling : TaskAugmentedSamplingFnT [ClientSessionT_contra ] = field (default = default_task_augmented_sampling )
203+ augmented_elicitation : TaskAugmentedElicitationFnT [ClientSessionT_contra ] = field (
204+ default = default_task_augmented_elicitation
205+ )
206206
207207 def build_capability (self ) -> types .ClientTasksCapability | None :
208208 """Build ClientTasksCapability from the configured handlers.
@@ -250,7 +250,7 @@ def handles_request(request: types.ServerRequest) -> bool:
250250
251251 async def handle_request (
252252 self ,
253- ctx : RequestContext [ClientSession ],
253+ ctx : RequestContext [ClientSessionT_contra ],
254254 responder : RequestResponder [types .ServerRequest , types .ClientResult ],
255255 ) -> None :
256256 """Handle a task-related request from the server.
0 commit comments