Skip to content

Commit 5f89866

Browse files
Fix: Raise RuntimeError when ClientSession is used without context manager
This change ensures ClientSession is always used within an 'async with' block by checking the context manager state before allowing operations that depend on the background receive loop. Without this check, calling initialize() outside of a context manager causes the program to hang indefinitely because the _receive_loop task is only started in __aenter__. Changes: - Add _entered flag to track context manager state - Override __aenter__ and __aexit__ to manage the flag - Add _check_is_active() method to enforce proper usage - Call _check_is_active() in initialize() to prevent hangs - Add test case for invalid usage Fixes #1564 Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 8f669a7 commit 5f89866

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

src/mcp/client/session.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
from types import TracebackType
45
from typing import Any, Protocol
56

67
import anyio.lowlevel
@@ -107,6 +108,8 @@ class ClientSession(
107108
types.ServerNotification,
108109
]
109110
):
111+
_entered: bool
112+
110113
def __init__(
111114
self,
112115
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -133,10 +136,29 @@ def __init__(
133136
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
134137
self._server_capabilities: types.ServerCapabilities | None = None
135138
self._experimental_features: ExperimentalClientFeatures | None = None
139+
self._entered = False
136140

137141
# Experimental: Task handlers (use defaults if not provided)
138142
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
139143

144+
async def __aenter__(self) -> ClientSession:
145+
self._entered = True
146+
await super().__aenter__()
147+
return self
148+
149+
async def __aexit__(
150+
self,
151+
exc_type: type[BaseException] | None,
152+
exc_value: BaseException | None,
153+
traceback: TracebackType | None,
154+
) -> None:
155+
self._entered = False
156+
await super().__aexit__(exc_type, exc_value, traceback)
157+
158+
def _check_is_active(self) -> None:
159+
if not self._entered:
160+
raise RuntimeError("ClientSession must be used within an 'async with' block.")
161+
140162
@property
141163
def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]:
142164
return types.server_request_adapter
@@ -146,6 +168,7 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]
146168
return types.server_notification_adapter
147169

148170
async def initialize(self) -> types.InitializeResult:
171+
self._check_is_active()
149172
sampling = (
150173
(self._sampling_capabilities or types.SamplingCapability())
151174
if self._sampling_callback is not _default_sampling_callback

tests/client/test_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Any, cast
4+
35
import anyio
46
import pytest
57

@@ -706,3 +708,18 @@ async def mock_server():
706708
await session.initialize()
707709

708710
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)
711+
712+
713+
@pytest.mark.anyio
714+
async def test_initialize_without_context_manager_raises_error():
715+
"""Test that calling initialize() without entering the context manager raises RuntimeError."""
716+
send_stream, receive_stream = anyio.create_memory_object_stream[Any](0)
717+
718+
read_stream = cast(Any, receive_stream)
719+
write_stream = cast(Any, send_stream)
720+
721+
async with send_stream, receive_stream:
722+
session = ClientSession(read_stream, write_stream)
723+
724+
with pytest.raises(RuntimeError, match="must be used within"):
725+
await session.initialize()

0 commit comments

Comments
 (0)