1111- Server polls client's task status via tasks/get, tasks/result, etc.
1212"""
1313
14+ from __future__ import annotations
15+
1416from dataclasses import dataclass , field
15- from typing import TYPE_CHECKING , Any , Protocol
17+ from typing import TYPE_CHECKING , Protocol
1618
1719from pydantic import TypeAdapter
1820
@@ -32,7 +34,7 @@ class GetTaskHandlerFnT(Protocol):
3234
3335 async def __call__ (
3436 self ,
35- context : RequestContext [" ClientSession" , Any ],
37+ context : RequestContext [ClientSession ],
3638 params : types .GetTaskRequestParams ,
3739 ) -> types .GetTaskResult | types .ErrorData : ... # pragma: no branch
3840
@@ -45,7 +47,7 @@ class GetTaskResultHandlerFnT(Protocol):
4547
4648 async def __call__ (
4749 self ,
48- context : RequestContext [" ClientSession" , Any ],
50+ context : RequestContext [ClientSession ],
4951 params : types .GetTaskPayloadRequestParams ,
5052 ) -> types .GetTaskPayloadResult | types .ErrorData : ... # pragma: no branch
5153
@@ -58,7 +60,7 @@ class ListTasksHandlerFnT(Protocol):
5860
5961 async def __call__ (
6062 self ,
61- context : RequestContext [" ClientSession" , Any ],
63+ context : RequestContext [ClientSession ],
6264 params : types .PaginatedRequestParams | None ,
6365 ) -> types .ListTasksResult | types .ErrorData : ... # pragma: no branch
6466
@@ -71,7 +73,7 @@ class CancelTaskHandlerFnT(Protocol):
7173
7274 async def __call__ (
7375 self ,
74- context : RequestContext [" ClientSession" , Any ],
76+ context : RequestContext [ClientSession ],
7577 params : types .CancelTaskRequestParams ,
7678 ) -> types .CancelTaskResult | types .ErrorData : ... # pragma: no branch
7779
@@ -88,7 +90,7 @@ class TaskAugmentedSamplingFnT(Protocol):
8890
8991 async def __call__ (
9092 self ,
91- context : RequestContext [" ClientSession" , Any ],
93+ context : RequestContext [ClientSession ],
9294 params : types .CreateMessageRequestParams ,
9395 task_metadata : types .TaskMetadata ,
9496 ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
@@ -106,14 +108,14 @@ class TaskAugmentedElicitationFnT(Protocol):
106108
107109 async def __call__ (
108110 self ,
109- context : RequestContext [" ClientSession" , Any ],
111+ context : RequestContext [ClientSession ],
110112 params : types .ElicitRequestParams ,
111113 task_metadata : types .TaskMetadata ,
112114 ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
113115
114116
115117async def default_get_task_handler (
116- context : RequestContext [" ClientSession" , Any ],
118+ context : RequestContext [ClientSession ],
117119 params : types .GetTaskRequestParams ,
118120) -> types .GetTaskResult | types .ErrorData :
119121 return types .ErrorData (
@@ -123,7 +125,7 @@ async def default_get_task_handler(
123125
124126
125127async def default_get_task_result_handler (
126- context : RequestContext [" ClientSession" , Any ],
128+ context : RequestContext [ClientSession ],
127129 params : types .GetTaskPayloadRequestParams ,
128130) -> types .GetTaskPayloadResult | types .ErrorData :
129131 return types .ErrorData (
@@ -133,7 +135,7 @@ async def default_get_task_result_handler(
133135
134136
135137async def default_list_tasks_handler (
136- context : RequestContext [" ClientSession" , Any ],
138+ context : RequestContext [ClientSession ],
137139 params : types .PaginatedRequestParams | None ,
138140) -> types .ListTasksResult | types .ErrorData :
139141 return types .ErrorData (
@@ -143,7 +145,7 @@ async def default_list_tasks_handler(
143145
144146
145147async def default_cancel_task_handler (
146- context : RequestContext [" ClientSession" , Any ],
148+ context : RequestContext [ClientSession ],
147149 params : types .CancelTaskRequestParams ,
148150) -> types .CancelTaskResult | types .ErrorData :
149151 return types .ErrorData (
@@ -153,7 +155,7 @@ async def default_cancel_task_handler(
153155
154156
155157async def default_task_augmented_sampling (
156- context : RequestContext [" ClientSession" , Any ],
158+ context : RequestContext [ClientSession ],
157159 params : types .CreateMessageRequestParams ,
158160 task_metadata : types .TaskMetadata ,
159161) -> types .CreateTaskResult | types .ErrorData :
@@ -164,7 +166,7 @@ async def default_task_augmented_sampling(
164166
165167
166168async def default_task_augmented_elicitation (
167- context : RequestContext [" ClientSession" , Any ],
169+ context : RequestContext [ClientSession ],
168170 params : types .ElicitRequestParams ,
169171 task_metadata : types .TaskMetadata ,
170172) -> types .CreateTaskResult | types .ErrorData :
@@ -248,7 +250,7 @@ def handles_request(request: types.ServerRequest) -> bool:
248250
249251 async def handle_request (
250252 self ,
251- ctx : RequestContext [" ClientSession" , Any ],
253+ ctx : RequestContext [ClientSession ],
252254 responder : RequestResponder [types .ServerRequest , types .ClientResult ],
253255 ) -> None :
254256 """Handle a task-related request from the server.
0 commit comments