Skip to content

Commit 05db6c8

Browse files
committed
add Generic types to be used with the new BaseClientSession class
1 parent 272b3bd commit 05db6c8

File tree

10 files changed

+75
-71
lines changed

10 files changed

+75
-71
lines changed

src/mcp/client/base_client_session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from abc import abstractmethod
2-
from typing import Any
2+
from typing import Any, TypeVar
33

44
from mcp import types
55
from mcp.shared.session import CommonBaseSession, ProgressFnT
66
from mcp.types._types import RequestParamsMeta
77

8-
# from mcp.shared.session import CommonBaseSession
8+
ClientSessionT_contra = TypeVar("ClientSessionT_contra", bound="BaseClientSession", contravariant=True)
9+
910

1011
class BaseClientSession(
1112
CommonBaseSession[

src/mcp/client/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ async def main():
7676
read_timeout_seconds: float | None = None
7777
"""Timeout for read operations."""
7878

79-
sampling_callback: SamplingFnT | None = None
79+
sampling_callback: SamplingFnT[ClientSession] | None = None
8080
"""Callback for handling sampling requests."""
8181

82-
list_roots_callback: ListRootsFnT | None = None
82+
list_roots_callback: ListRootsFnT[ClientSession] | None = None
8383
"""Callback for handling list roots requests."""
8484

8585
logging_callback: LoggingFnT | None = None
@@ -92,7 +92,7 @@ async def main():
9292
client_info: Implementation | None = None
9393
"""Client implementation info to send to server."""
9494

95-
elicitation_callback: ElicitationFnT | None = None
95+
elicitation_callback: ElicitationFnT[ClientSession] | None = None
9696
"""Callback for handling elicitation requests."""
9797

9898
_session: ClientSession | None = field(init=False, default=None)

src/mcp/client/context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Request context for MCP client handlers."""
22

3-
from mcp.client.session import ClientSession
3+
from mcp.client import BaseClientSession
4+
5+
# from mcp.client.session import ClientSession
46
from mcp.shared._context import RequestContext
57

6-
ClientRequestContext = RequestContext[ClientSession]
8+
ClientRequestContext = RequestContext[BaseClientSession]
79
"""Context for handling incoming requests in a client session.
810
911
This context is passed to client-side callbacks (sampling, elicitation, list_roots) when the server sends requests

src/mcp/client/experimental/task_handlers.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,71 +14,69 @@
1414
from __future__ import annotations
1515

1616
from dataclasses import dataclass, field
17-
from typing import TYPE_CHECKING, Protocol
17+
from typing import Generic, Protocol
1818

1919
from pydantic import TypeAdapter
2020

2121
from mcp import types
22+
from mcp.client.base_client_session import BaseClientSession, ClientSessionT_contra
2223
from mcp.shared._context import RequestContext
2324
from mcp.shared.session import RequestResponder
2425

25-
if TYPE_CHECKING:
26-
from mcp.client.session import ClientSession
2726

28-
29-
class GetTaskHandlerFnT(Protocol):
27+
class GetTaskHandlerFnT(Protocol[ClientSessionT_contra]):
3028
"""Handler for tasks/get requests from server.
3129
3230
WARNING: This is experimental and may change without notice.
3331
"""
3432

3533
async def __call__(
3634
self,
37-
context: RequestContext[ClientSession],
35+
context: RequestContext[ClientSessionT_contra],
3836
params: types.GetTaskRequestParams,
3937
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
4038

4139

42-
class GetTaskResultHandlerFnT(Protocol):
40+
class GetTaskResultHandlerFnT(Protocol[ClientSessionT_contra]):
4341
"""Handler for tasks/result requests from server.
4442
4543
WARNING: This is experimental and may change without notice.
4644
"""
4745

4846
async def __call__(
4947
self,
50-
context: RequestContext[ClientSession],
48+
context: RequestContext[ClientSessionT_contra],
5149
params: types.GetTaskPayloadRequestParams,
5250
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
5351

5452

55-
class ListTasksHandlerFnT(Protocol):
53+
class ListTasksHandlerFnT(Protocol[ClientSessionT_contra]):
5654
"""Handler for tasks/list requests from server.
5755
5856
WARNING: This is experimental and may change without notice.
5957
"""
6058

6159
async def __call__(
6260
self,
63-
context: RequestContext[ClientSession],
61+
context: RequestContext[ClientSessionT_contra],
6462
params: types.PaginatedRequestParams | None,
6563
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
6664

6765

68-
class CancelTaskHandlerFnT(Protocol):
66+
class CancelTaskHandlerFnT(Protocol[ClientSessionT_contra]):
6967
"""Handler for tasks/cancel requests from server.
7068
7169
WARNING: This is experimental and may change without notice.
7270
"""
7371

7472
async def __call__(
7573
self,
76-
context: RequestContext[ClientSession],
74+
context: RequestContext[ClientSessionT_contra],
7775
params: types.CancelTaskRequestParams,
7876
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
7977

8078

81-
class TaskAugmentedSamplingFnT(Protocol):
79+
class TaskAugmentedSamplingFnT(Protocol[ClientSessionT_contra]):
8280
"""Handler for task-augmented sampling/createMessage requests from server.
8381
8482
When server sends a CreateMessageRequest with task field, this callback
@@ -90,13 +88,13 @@ class TaskAugmentedSamplingFnT(Protocol):
9088

9189
async def __call__(
9290
self,
93-
context: RequestContext[ClientSession],
91+
context: RequestContext[ClientSessionT_contra],
9492
params: types.CreateMessageRequestParams,
9593
task_metadata: types.TaskMetadata,
9694
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
9795

9896

99-
class TaskAugmentedElicitationFnT(Protocol):
97+
class TaskAugmentedElicitationFnT(Protocol[ClientSessionT_contra]):
10098
"""Handler for task-augmented elicitation/create requests from server.
10199
102100
When server sends an ElicitRequest with task field, this callback
@@ -108,14 +106,14 @@ class TaskAugmentedElicitationFnT(Protocol):
108106

109107
async def __call__(
110108
self,
111-
context: RequestContext[ClientSession],
109+
context: RequestContext[ClientSessionT_contra],
112110
params: types.ElicitRequestParams,
113111
task_metadata: types.TaskMetadata,
114112
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
115113

116114

117115
async def default_get_task_handler(
118-
context: RequestContext[ClientSession],
116+
context: RequestContext[BaseClientSession],
119117
params: types.GetTaskRequestParams,
120118
) -> types.GetTaskResult | types.ErrorData:
121119
return types.ErrorData(
@@ -125,7 +123,7 @@ async def default_get_task_handler(
125123

126124

127125
async def default_get_task_result_handler(
128-
context: RequestContext[ClientSession],
126+
context: RequestContext[BaseClientSession],
129127
params: types.GetTaskPayloadRequestParams,
130128
) -> types.GetTaskPayloadResult | types.ErrorData:
131129
return types.ErrorData(
@@ -135,7 +133,7 @@ async def default_get_task_result_handler(
135133

136134

137135
async def default_list_tasks_handler(
138-
context: RequestContext[ClientSession],
136+
context: RequestContext[BaseClientSession],
139137
params: types.PaginatedRequestParams | None,
140138
) -> types.ListTasksResult | types.ErrorData:
141139
return types.ErrorData(
@@ -145,7 +143,7 @@ async def default_list_tasks_handler(
145143

146144

147145
async def default_cancel_task_handler(
148-
context: RequestContext[ClientSession],
146+
context: RequestContext[BaseClientSession],
149147
params: types.CancelTaskRequestParams,
150148
) -> types.CancelTaskResult | types.ErrorData:
151149
return types.ErrorData(
@@ -155,7 +153,7 @@ async def default_cancel_task_handler(
155153

156154

157155
async def default_task_augmented_sampling(
158-
context: RequestContext[ClientSession],
156+
context: RequestContext[BaseClientSession],
159157
params: types.CreateMessageRequestParams,
160158
task_metadata: types.TaskMetadata,
161159
) -> types.CreateTaskResult | types.ErrorData:
@@ -166,7 +164,7 @@ async def default_task_augmented_sampling(
166164

167165

168166
async def default_task_augmented_elicitation(
169-
context: RequestContext[ClientSession],
167+
context: RequestContext[BaseClientSession],
170168
params: types.ElicitRequestParams,
171169
task_metadata: types.TaskMetadata,
172170
) -> types.CreateTaskResult | types.ErrorData:
@@ -177,7 +175,7 @@ async def default_task_augmented_elicitation(
177175

178176

179177
@dataclass
180-
class ExperimentalTaskHandlers:
178+
class ExperimentalTaskHandlers(Generic[ClientSessionT_contra]):
181179
"""Container for experimental task handlers.
182180
183181
Groups all task-related handlers that handle server -> client requests.
@@ -195,14 +193,16 @@ class ExperimentalTaskHandlers:
195193
"""
196194

197195
# Pure task request handlers
198-
get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
199-
get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
200-
list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
201-
cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
196+
get_task: GetTaskHandlerFnT[ClientSessionT_contra] = field(default=default_get_task_handler)
197+
get_task_result: GetTaskResultHandlerFnT[ClientSessionT_contra] = field(default=default_get_task_result_handler)
198+
list_tasks: ListTasksHandlerFnT[ClientSessionT_contra] = field(default=default_list_tasks_handler)
199+
cancel_task: CancelTaskHandlerFnT[ClientSessionT_contra] = field(default=default_cancel_task_handler)
202200

203201
# Task-augmented request handlers
204-
augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
205-
augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)
202+
augmented_sampling: TaskAugmentedSamplingFnT[ClientSessionT_contra] = field(default=default_task_augmented_sampling)
203+
augmented_elicitation: TaskAugmentedElicitationFnT[ClientSessionT_contra] = field(
204+
default=default_task_augmented_elicitation
205+
)
206206

207207
def build_capability(self) -> types.ClientTasksCapability | None:
208208
"""Build ClientTasksCapability from the configured handlers.
@@ -250,7 +250,7 @@ def handles_request(request: types.ServerRequest) -> bool:
250250

251251
async def handle_request(
252252
self,
253-
ctx: RequestContext[ClientSession],
253+
ctx: RequestContext[ClientSessionT_contra],
254254
responder: RequestResponder[types.ServerRequest, types.ClientResult],
255255
) -> None:
256256
"""Handle a task-related request from the server.

src/mcp/client/session.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pydantic import TypeAdapter
99

1010
from mcp import types
11-
from mcp.client.base_client_session import BaseClientSession
11+
from mcp.client.base_client_session import BaseClientSession, ClientSessionT_contra
1212
from mcp.client.experimental import ExperimentalClientFeatures
1313
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1414
from mcp.shared._context import RequestContext
@@ -22,25 +22,25 @@
2222
logger = logging.getLogger("client")
2323

2424

25-
class SamplingFnT(Protocol):
25+
class SamplingFnT(Protocol[ClientSessionT_contra]):
2626
async def __call__(
2727
self,
28-
context: RequestContext[ClientSession],
28+
context: RequestContext[ClientSessionT_contra],
2929
params: types.CreateMessageRequestParams,
3030
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
3131

3232

33-
class ElicitationFnT(Protocol):
33+
class ElicitationFnT(Protocol[ClientSessionT_contra]):
3434
async def __call__(
3535
self,
36-
context: RequestContext[ClientSession],
36+
context: RequestContext[ClientSessionT_contra],
3737
params: types.ElicitRequestParams,
3838
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
3939

4040

41-
class ListRootsFnT(Protocol):
41+
class ListRootsFnT(Protocol[ClientSessionT_contra]):
4242
async def __call__(
43-
self, context: RequestContext[ClientSession]
43+
self, context: RequestContext[ClientSessionT_contra]
4444
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
4545

4646

@@ -62,7 +62,7 @@ async def _default_message_handler(
6262

6363

6464
async def _default_sampling_callback(
65-
context: RequestContext[ClientSession],
65+
context: RequestContext[BaseClientSession],
6666
params: types.CreateMessageRequestParams,
6767
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
6868
return types.ErrorData(
@@ -72,7 +72,7 @@ async def _default_sampling_callback(
7272

7373

7474
async def _default_elicitation_callback(
75-
context: RequestContext[ClientSession],
75+
context: RequestContext[BaseClientSession],
7676
params: types.ElicitRequestParams,
7777
) -> types.ElicitResult | types.ErrorData:
7878
return types.ErrorData( # pragma: no cover
@@ -82,7 +82,7 @@ async def _default_elicitation_callback(
8282

8383

8484
async def _default_list_roots_callback(
85-
context: RequestContext[ClientSession],
85+
context: RequestContext[BaseClientSession],
8686
) -> types.ListRootsResult | types.ErrorData:
8787
return types.ErrorData(
8888
code=types.INVALID_REQUEST,
@@ -114,15 +114,15 @@ def __init__(
114114
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
115115
write_stream: MemoryObjectSendStream[SessionMessage],
116116
read_timeout_seconds: float | None = None,
117-
sampling_callback: SamplingFnT | None = None,
118-
elicitation_callback: ElicitationFnT | None = None,
119-
list_roots_callback: ListRootsFnT | None = None,
117+
sampling_callback: SamplingFnT[ClientSession] | None = None,
118+
elicitation_callback: ElicitationFnT[ClientSession] | None = None,
119+
list_roots_callback: ListRootsFnT[ClientSession] | None = None,
120120
logging_callback: LoggingFnT | None = None,
121121
message_handler: MessageHandlerFnT | None = None,
122122
client_info: types.Implementation | None = None,
123123
*,
124124
sampling_capabilities: types.SamplingCapability | None = None,
125-
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
125+
experimental_task_handlers: ExperimentalTaskHandlers[ClientSession] | None = None,
126126
) -> None:
127127
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
128128
self._client_info = client_info or DEFAULT_CLIENT_INFO

src/mcp/client/session_group.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from dataclasses import dataclass
1313
from types import TracebackType
14-
from typing import Any, TypeAlias
14+
from typing import Any, Generic, TypeAlias
1515

1616
import anyio
1717
import httpx
@@ -20,7 +20,8 @@
2020

2121
import mcp
2222
from mcp import types
23-
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
23+
from mcp.client.base_client_session import ClientSessionT_contra
24+
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
2425
from mcp.client.sse import sse_client
2526
from mcp.client.stdio import StdioServerParameters
2627
from mcp.client.streamable_http import streamable_http_client
@@ -70,13 +71,13 @@ class StreamableHttpParameters(BaseModel):
7071
# Use dataclass instead of pydantic BaseModel
7172
# because pydantic BaseModel cannot handle Protocol fields.
7273
@dataclass
73-
class ClientSessionParameters:
74+
class ClientSessionParameters(Generic[ClientSessionT_contra]):
7475
"""Parameters for establishing a client session to an MCP server."""
7576

7677
read_timeout_seconds: float | None = None
77-
sampling_callback: SamplingFnT | None = None
78-
elicitation_callback: ElicitationFnT | None = None
79-
list_roots_callback: ListRootsFnT | None = None
78+
sampling_callback: SamplingFnT[ClientSessionT_contra] | None = None
79+
elicitation_callback: ElicitationFnT[ClientSessionT_contra] | None = None
80+
list_roots_callback: ListRootsFnT[ClientSessionT_contra] | None = None
8081
logging_callback: LoggingFnT | None = None
8182
message_handler: MessageHandlerFnT | None = None
8283
client_info: types.Implementation | None = None
@@ -254,7 +255,7 @@ async def connect_with_session(
254255
async def connect_to_server(
255256
self,
256257
server_params: ServerParameters,
257-
session_params: ClientSessionParameters | None = None,
258+
session_params: ClientSessionParameters[ClientSession] | None = None,
258259
) -> mcp.ClientSession:
259260
"""Connects to a single MCP server."""
260261
server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters())
@@ -263,7 +264,7 @@ async def connect_to_server(
263264
async def _establish_session(
264265
self,
265266
server_params: ServerParameters,
266-
session_params: ClientSessionParameters,
267+
session_params: ClientSessionParameters[ClientSession],
267268
) -> tuple[types.Implementation, mcp.ClientSession]:
268269
"""Establish a client session to an MCP server."""
269270

0 commit comments

Comments
 (0)