Skip to content

Commit 2ab20a2

Browse files
committed
add BaseSession and BaseClientSession classes
1 parent dda845a commit 2ab20a2

File tree

5 files changed

+201
-9
lines changed

5 files changed

+201
-9
lines changed

src/mcp/client/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""MCP Client module."""
22

33
from mcp.client._transport import Transport
4+
from mcp.client.base_client_session import BaseClientSession
45
from mcp.client.client import Client
56
from mcp.client.context import ClientRequestContext
67
from mcp.client.session import ClientSession
78

8-
__all__ = ["Client", "ClientRequestContext", "ClientSession", "Transport"]
9+
__all__ = ["BaseClientSession", "Client", "ClientRequestContext", "ClientSession", "Transport"]
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from abc import abstractmethod
2+
from typing import Any
3+
4+
from mcp import types
5+
from mcp.shared.session import ProgressFnT
6+
from mcp.types._types import RequestParamsMeta
7+
8+
# from mcp.shared.session import CommonBaseSession
9+
10+
class BaseClientSession:
11+
"""Base class for client transport sessions.
12+
13+
The class provides all the methods that a client session should implement,
14+
irrespective of the transport used.
15+
"""
16+
17+
@abstractmethod
18+
async def initialize(self) -> types.InitializeResult:
19+
"""Initialize the client session."""
20+
raise NotImplementedError
21+
22+
@abstractmethod
23+
async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
24+
"""Send a ping request."""
25+
raise NotImplementedError
26+
27+
@abstractmethod
28+
async def send_progress_notification(
29+
self,
30+
progress_token: types.ProgressToken,
31+
progress: float,
32+
total: float | None = None,
33+
message: str | None = None,
34+
) -> None:
35+
"""Sends a progress notification for a request that is currently being processed."""
36+
raise NotImplementedError
37+
38+
@abstractmethod
39+
async def list_resources(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListResourcesResult:
40+
"""Send a resources/list request.
41+
42+
Args:
43+
params: Full pagination parameters including cursor and any future fields
44+
"""
45+
raise NotImplementedError
46+
47+
@abstractmethod
48+
async def list_resource_templates(
49+
self, *, params: types.PaginatedRequestParams | None = None
50+
) -> types.ListResourceTemplatesResult:
51+
"""Send a resources/templates/list request.
52+
53+
Args:
54+
params: Full pagination parameters including cursor and any future fields
55+
"""
56+
raise NotImplementedError
57+
58+
@abstractmethod
59+
async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.ReadResourceResult:
60+
"""Send a resources/read request."""
61+
raise NotImplementedError
62+
63+
@abstractmethod
64+
async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
65+
"""Send a resources/subscribe request."""
66+
raise NotImplementedError
67+
68+
@abstractmethod
69+
async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
70+
"""Send a resources/unsubscribe request."""
71+
raise NotImplementedError
72+
73+
@abstractmethod
74+
async def call_tool(
75+
self,
76+
name: str,
77+
arguments: dict[str, Any] | None = None,
78+
read_timeout_seconds: float | None = None,
79+
progress_callback: ProgressFnT | None = None,
80+
*,
81+
meta: RequestParamsMeta | None = None,
82+
) -> types.CallToolResult:
83+
"""Send a tools/call request with optional progress callback support."""
84+
raise NotImplementedError
85+
86+
@abstractmethod
87+
async def list_prompts(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListPromptsResult:
88+
"""Send a prompts/list request.
89+
90+
Args:
91+
params: Full pagination parameters including cursor and any future fields
92+
"""
93+
raise NotImplementedError
94+
95+
@abstractmethod
96+
async def get_prompt(
97+
self,
98+
name: str,
99+
arguments: dict[str, str] | None = None,
100+
*,
101+
meta: RequestParamsMeta | None = None,
102+
) -> types.GetPromptResult:
103+
"""Send a prompts/get request."""
104+
raise NotImplementedError
105+
106+
@abstractmethod
107+
async def list_tools(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListToolsResult:
108+
"""Send a tools/list request.
109+
110+
Args:
111+
params: Full pagination parameters including cursor and any future fields
112+
"""
113+
raise NotImplementedError
114+
115+
@abstractmethod
116+
async def send_roots_list_changed(self) -> None: # pragma: no cover
117+
"""Send a roots/list_changed notification."""
118+
raise NotImplementedError

src/mcp/client/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pydantic import TypeAdapter
99

1010
from mcp import types
11+
from mcp.client.base_client_session import BaseClientSession
1112
from mcp.client.experimental import ExperimentalClientFeatures
1213
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1314
from mcp.shared._context import RequestContext
@@ -105,7 +106,8 @@ class ClientSession(
105106
types.ClientResult,
106107
types.ServerRequest,
107108
types.ServerNotification,
108-
]
109+
],
110+
BaseClientSession,
109111
):
110112
def __init__(
111113
self,

src/mcp/shared/_context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
from typing_extensions import TypeVar
77

8-
from mcp.shared.session import BaseSession
8+
from mcp.client import BaseClientSession
9+
from mcp.shared.session import CommonBaseSession
910
from mcp.types import RequestId, RequestParamsMeta
1011

11-
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
12-
12+
SessionT = TypeVar("SessionT", bound=CommonBaseSession[Any, Any, Any, Any, Any] | BaseClientSession)
1313

1414
@dataclass(kw_only=True)
1515
class RequestContext(Generic[SessionT]):

src/mcp/shared/session.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
from abc import ABC, abstractmethod
45
from collections.abc import Callable
56
from contextlib import AsyncExitStack
67
from types import TracebackType
@@ -154,14 +155,84 @@ def cancelled(self) -> bool:
154155
return self._cancel_scope.cancel_called
155156

156157

157-
class BaseSession(
158+
class CommonBaseSession(
159+
ABC,
158160
Generic[
159161
SendRequestT,
160162
SendNotificationT,
161163
SendResultT,
162164
ReceiveRequestT,
163165
ReceiveNotificationT,
164166
],
167+
):
168+
"""Common base class for sessions agnostic to message types.
169+
170+
The class optionally takes in read and write streams, to provide flexibility without streams for transports that
171+
don't require them.
172+
173+
Sessions that do require read-write streams can inherit from this class, and impose the mandatory streams in their
174+
respective constructors.
175+
"""
176+
177+
def __init__(
178+
self,
179+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None,
180+
write_stream: MemoryObjectSendStream[SessionMessage] | None = None,
181+
# If none, reading will never time out
182+
read_timeout_seconds: float | None = None,
183+
) -> None:
184+
self._read_stream = read_stream
185+
self._write_stream = write_stream
186+
self._session_read_timeout_seconds = read_timeout_seconds
187+
self._response_streams = {}
188+
189+
@abstractmethod
190+
async def send_request(
191+
self,
192+
request: SendRequestT,
193+
result_type: type[ReceiveResultT],
194+
request_read_timeout_seconds: float | None = None,
195+
metadata: MessageMetadata = None,
196+
progress_callback: ProgressFnT | None = None,
197+
) -> ReceiveResultT:
198+
"""Sends a request and wait for a response.
199+
200+
Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take
201+
precedence over the session read timeout.
202+
203+
Do not use this method to emit notifications! Use send_notification() instead.
204+
"""
205+
raise NotImplementedError
206+
207+
@abstractmethod
208+
async def send_notification(
209+
self,
210+
notification: SendNotificationT,
211+
related_request_id: RequestId | None = None,
212+
) -> None:
213+
"""Emits a notification, which is a one-way message that does not expect a response."""
214+
raise NotImplementedError
215+
216+
@abstractmethod
217+
async def send_progress_notification(
218+
self,
219+
progress_token: ProgressToken,
220+
progress: float,
221+
total: float | None = None,
222+
message: str | None = None,
223+
) -> None:
224+
"""Sends a progress notification for a request that is currently being processed."""
225+
raise NotImplementedError
226+
227+
228+
class BaseSession(
229+
CommonBaseSession[
230+
SendRequestT,
231+
SendNotificationT,
232+
SendResultT,
233+
ReceiveRequestT,
234+
ReceiveNotificationT,
235+
],
165236
):
166237
"""Implements an MCP "session" on top of read/write streams, including features
167238
like request/response linking, notifications, and progress.
@@ -170,6 +241,8 @@ class BaseSession(
170241
messages when entered.
171242
"""
172243

244+
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
245+
_write_stream: MemoryObjectSendStream[SessionMessage]
173246
_response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
174247
_request_id: int
175248
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
@@ -183,11 +256,9 @@ def __init__(
183256
# If none, reading will never time out
184257
read_timeout_seconds: float | None = None,
185258
) -> None:
186-
self._read_stream = read_stream
187-
self._write_stream = write_stream
259+
super().__init__(read_stream, write_stream, read_timeout_seconds)
188260
self._response_streams = {}
189261
self._request_id = 0
190-
self._session_read_timeout_seconds = read_timeout_seconds
191262
self._in_flight = {}
192263
self._progress_callbacks = {}
193264
self._response_routers = []

0 commit comments

Comments
 (0)