Skip to content

Commit d0b11ad

Browse files
sreenithikziemski
authored andcommitted
refactor: address PR review feedback on transport abstractions
- Remove __init__ and WireMessageT from AbstractBaseSession, making it a pure abstract interface without state management. BaseSession now owns all state. - Convert BaseClientSession from ABC to @runtime_checkable Protocol with structural subtyping. This eliminates the diamond inheritance in ClientSession and removes the need for inheritance to satisfy the interface. - Add missing methods to BaseClientSession protocol: complete(), set_logging_level(), and explicitly declare send_request(), send_notification(), send_progress_notification(). - Remove commented-out import in context.py. - Add comprehensive tests for BaseClientSession protocol satisfaction and E2E usage. - Document all breaking changes in docs/migration.md including: * ClientRequestContext type change to BaseClientSession * Generic callback protocols (SamplingFnT[ClientSession], etc.) * SessionT renamed to SessionT_co * AbstractBaseSession simplification * BaseClientSession Protocol conversion Fixes structural issues identified in PR review: - AbstractBaseSession no longer manages state it doesn't own - Eliminates double initialization of _response_streams and _task_group - Removes Any leakage for WireMessageT - BaseClientSession now earns its place as a pure interface
1 parent d0705df commit d0b11ad

File tree

8 files changed

+388
-123
lines changed

8 files changed

+388
-123
lines changed

docs/migration.md

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,138 @@ await client.read_resource("test://resource")
471471
await client.read_resource(str(my_any_url))
472472
```
473473

474+
### Transport Abstractions Refactored
475+
476+
The session hierarchy has been refactored to support pluggable transport implementations. This introduces several breaking changes:
477+
478+
#### `ClientRequestContext` type changed
479+
480+
`ClientRequestContext` is now `RequestContext[BaseClientSession]` instead of `RequestContext[ClientSession]`. This means callbacks receive the more general `BaseClientSession` type, which may not have all methods available on `ClientSession`.
481+
482+
**Before:**
483+
484+
```python
485+
from mcp.client.context import ClientRequestContext
486+
from mcp.client.session import ClientSession
487+
488+
async def my_callback(context: ClientRequestContext) -> None:
489+
# Could access ClientSession-specific methods
490+
caps = context.session.get_server_capabilities()
491+
```
492+
493+
**After:**
494+
495+
```python
496+
from mcp.client.context import ClientRequestContext
497+
from mcp.client.session import ClientSession
498+
499+
async def my_callback(context: ClientRequestContext) -> None:
500+
# context.session is BaseClientSession - narrow the type if needed
501+
if isinstance(context.session, ClientSession):
502+
caps = context.session.get_server_capabilities()
503+
```
504+
505+
#### Callback protocols are now generic
506+
507+
`sampling_callback`, `elicitation_callback`, and `list_roots_callback` protocols now require explicit type parameters.
508+
509+
**Before:**
510+
511+
```python
512+
from mcp.client.session import SamplingFnT
513+
514+
async def my_sampling(context, params) -> CreateMessageResult:
515+
...
516+
517+
# Type inferred as SamplingFnT
518+
session = ClientSession(..., sampling_callback=my_sampling)
519+
```
520+
521+
**After:**
522+
523+
```python
524+
from mcp.client.session import SamplingFnT, ClientSession
525+
526+
async def my_sampling(
527+
context: RequestContext[ClientSession],
528+
params: CreateMessageRequestParams
529+
) -> CreateMessageResult:
530+
...
531+
532+
# Explicit type annotation recommended
533+
my_sampling_typed: SamplingFnT[ClientSession] = my_sampling
534+
session = ClientSession(..., sampling_callback=my_sampling_typed)
535+
```
536+
537+
#### `SessionT` renamed to `SessionT_co`
538+
539+
In `mcp.shared._context` and `mcp.shared.progress`, the `SessionT` TypeVar has been renamed to `SessionT_co` to follow naming conventions for covariant type variables.
540+
541+
**Before:**
542+
543+
```python
544+
from mcp.shared._context import SessionT
545+
```
546+
547+
**After:**
548+
549+
```python
550+
from mcp.shared._context import SessionT_co
551+
```
552+
553+
#### `AbstractBaseSession` simplified
554+
555+
`AbstractBaseSession` is now a pure abstract interface with no `__init__` method and no `WireMessageT` type parameter. If you were subclassing it directly, you now need to manage all state in your subclass.
556+
557+
**Before:**
558+
559+
```python
560+
from mcp.shared.session import AbstractBaseSession
561+
562+
class MySession(AbstractBaseSession[MyMessage, ...]):
563+
def __init__(self):
564+
super().__init__() # Would set up _response_streams, _task_group
565+
```
566+
567+
**After:**
568+
569+
```python
570+
from mcp.shared.session import AbstractBaseSession
571+
572+
class MySession(AbstractBaseSession[...]):
573+
def __init__(self):
574+
# Manage your own state - no super().__init__() to call
575+
self._my_state = {}
576+
```
577+
578+
#### `BaseClientSession` is now a Protocol
579+
580+
`BaseClientSession` is now a `typing.Protocol` (structural subtyping) instead of an abstract base class. It no longer inherits from `AbstractBaseSession` and requires no inheritance to satisfy.
581+
582+
**Before:**
583+
584+
```python
585+
from mcp.client.base_client_session import BaseClientSession
586+
587+
class MyClientSession(BaseClientSession):
588+
async def initialize(self) -> InitializeResult:
589+
...
590+
```
591+
592+
**After:**
593+
594+
```python
595+
from mcp.client.base_client_session import BaseClientSession
596+
597+
class MyClientSession:
598+
# Just implement the methods - no inheritance needed
599+
async def initialize(self) -> InitializeResult:
600+
...
601+
602+
# Verify protocol satisfaction at runtime
603+
assert isinstance(MyClientSession(), BaseClientSession)
604+
```
605+
474606
## Deprecations
475607

476608
<!-- Add deprecations below -->
Lines changed: 71 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,68 @@
1-
from abc import abstractmethod
21
from typing import Any, TypeVar
32

3+
from typing_extensions import Protocol, runtime_checkable
4+
45
from mcp import types
5-
from mcp.shared.session import CommonBaseSession, ProgressFnT
6+
from mcp.shared.session import ProgressFnT
67
from mcp.types._types import RequestParamsMeta
78

8-
ClientSessionT_contra = TypeVar("ClientSessionT_contra", bound="BaseClientSession", contravariant=True)
9+
ClientSessionT_contra = TypeVar("ClientSessionT_contra", contravariant=True)
910

1011

11-
class BaseClientSession(
12-
CommonBaseSession[
13-
Any,
14-
types.ClientRequest,
15-
types.ClientNotification,
16-
types.ClientResult,
17-
types.ServerRequest,
18-
types.ServerNotification,
19-
]
20-
):
21-
"""Base class for client transport sessions.
12+
@runtime_checkable
13+
class BaseClientSession(Protocol):
14+
"""Protocol defining the interface for MCP client sessions.
2215
23-
The class provides all the methods that a client session should implement,
24-
irrespective of the transport used.
16+
This protocol specifies all methods that a client session must implement,
17+
irrespective of the transport used. Implementations satisfy this protocol
18+
through structural subtyping — no inheritance required.
2519
"""
2620

27-
@abstractmethod
28-
async def initialize(self) -> types.InitializeResult:
29-
"""Initialize the client session."""
30-
raise NotImplementedError
21+
# Methods from AbstractBaseSession (must be explicitly declared in Protocol)
22+
async def send_request(
23+
self,
24+
request: types.ClientRequest,
25+
result_type: type,
26+
request_read_timeout_seconds: float | None = None,
27+
metadata: Any = None,
28+
progress_callback: ProgressFnT | None = None,
29+
) -> Any: ...
30+
31+
async def send_notification(
32+
self,
33+
notification: types.ClientNotification,
34+
related_request_id: Any = None,
35+
) -> None: ...
36+
37+
async def send_progress_notification(
38+
self,
39+
progress_token: types.ProgressToken,
40+
progress: float,
41+
total: float | None = None,
42+
message: str | None = None,
43+
*,
44+
meta: RequestParamsMeta | None = None,
45+
) -> None: ...
3146

32-
@abstractmethod
33-
async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
34-
"""Send a ping request."""
35-
raise NotImplementedError
47+
# Client-specific methods
48+
async def initialize(self) -> types.InitializeResult: ...
3649

37-
@abstractmethod
38-
async def list_resources(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListResourcesResult:
39-
"""Send a resources/list request.
50+
async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: ...
4051

41-
Args:
42-
params: Full pagination parameters including cursor and any future fields
43-
"""
44-
raise NotImplementedError
52+
async def list_resources(
53+
self, *, params: types.PaginatedRequestParams | None = None
54+
) -> types.ListResourcesResult: ...
4555

46-
@abstractmethod
4756
async def list_resource_templates(
4857
self, *, params: types.PaginatedRequestParams | None = None
49-
) -> types.ListResourceTemplatesResult:
50-
"""Send a resources/templates/list request.
51-
52-
Args:
53-
params: Full pagination parameters including cursor and any future fields
54-
"""
55-
raise NotImplementedError
56-
57-
@abstractmethod
58-
async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.ReadResourceResult:
59-
"""Send a resources/read request."""
60-
raise NotImplementedError
61-
62-
@abstractmethod
63-
async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
64-
"""Send a resources/subscribe request."""
65-
raise NotImplementedError
66-
67-
@abstractmethod
68-
async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
69-
"""Send a resources/unsubscribe request."""
70-
raise NotImplementedError
71-
72-
@abstractmethod
58+
) -> types.ListResourceTemplatesResult: ...
59+
60+
async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.ReadResourceResult: ...
61+
62+
async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: ...
63+
64+
async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: ...
65+
7366
async def call_tool(
7467
self,
7568
name: str,
@@ -78,40 +71,33 @@ async def call_tool(
7871
progress_callback: ProgressFnT | None = None,
7972
*,
8073
meta: RequestParamsMeta | None = None,
81-
) -> types.CallToolResult:
82-
"""Send a tools/call request with optional progress callback support."""
83-
raise NotImplementedError
84-
85-
@abstractmethod
86-
async def list_prompts(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListPromptsResult:
87-
"""Send a prompts/list request.
74+
) -> types.CallToolResult: ...
8875

89-
Args:
90-
params: Full pagination parameters including cursor and any future fields
91-
"""
92-
raise NotImplementedError
76+
async def list_prompts(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListPromptsResult: ...
9377

94-
@abstractmethod
9578
async def get_prompt(
9679
self,
9780
name: str,
9881
arguments: dict[str, str] | None = None,
9982
*,
10083
meta: RequestParamsMeta | None = None,
101-
) -> types.GetPromptResult:
102-
"""Send a prompts/get request."""
103-
raise NotImplementedError
104-
105-
@abstractmethod
106-
async def list_tools(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListToolsResult:
107-
"""Send a tools/list request.
108-
109-
Args:
110-
params: Full pagination parameters including cursor and any future fields
111-
"""
112-
raise NotImplementedError
113-
114-
@abstractmethod
115-
async def send_roots_list_changed(self) -> None: # pragma: no cover
116-
"""Send a roots/list_changed notification."""
117-
raise NotImplementedError
84+
) -> types.GetPromptResult: ...
85+
86+
async def list_tools(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListToolsResult: ...
87+
88+
# Missing methods added per review
89+
async def complete(
90+
self,
91+
ref: types.ResourceTemplateReference | types.PromptReference,
92+
argument: dict[str, str],
93+
context_arguments: dict[str, str] | None = None,
94+
) -> types.CompleteResult: ...
95+
96+
async def set_logging_level(
97+
self,
98+
level: types.LoggingLevel,
99+
*,
100+
meta: RequestParamsMeta | None = None,
101+
) -> types.EmptyResult: ...
102+
103+
async def send_roots_list_changed(self) -> None: ...

src/mcp/client/context.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Request context for MCP client handlers."""
22

33
from mcp.client import BaseClientSession
4-
5-
# from mcp.client.session import ClientSession
64
from mcp.shared._context import RequestContext
75

86
ClientRequestContext = RequestContext[BaseClientSession]

src/mcp/client/session.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ class ClientSession(
107107
types.ServerRequest,
108108
types.ServerNotification,
109109
],
110-
BaseClientSession,
111110
):
112111
def __init__(
113112
self,

src/mcp/shared/_context.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
"""Request context for MCP handlers."""
22

33
from dataclasses import dataclass
4-
from typing import Any, Generic
4+
from typing import Generic
55

66
from typing_extensions import TypeVar
77

8-
from mcp.shared.session import CommonBaseSession
98
from mcp.types import RequestId, RequestParamsMeta
109

11-
SessionT_co = TypeVar("SessionT_co", bound=CommonBaseSession[Any, Any, Any, Any, Any, Any], covariant=True)
10+
SessionT_co = TypeVar("SessionT_co", covariant=True)
1211

1312

1413
@dataclass(kw_only=True)

src/mcp/shared/message.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from collections.abc import Awaitable, Callable
88
from dataclasses import dataclass
9-
from typing import TypeVar
109

1110
from mcp.types import JSONRPCMessage, RequestId
1211

@@ -41,8 +40,6 @@ class ServerMessageMetadata:
4140

4241
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
4342

44-
WireMessageT = TypeVar("WireMessageT")
45-
4643

4744
@dataclass
4845
class SessionMessage:

0 commit comments

Comments
 (0)