Skip to content

Commit ddb29f5

Browse files
feat: add get_session_state() method to ClientSession
This method extracts serializable session state that can be stored in external storage for distributed deployments. Related: #2111
1 parent 7cf3cc9 commit ddb29f5

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

src/mcp/client/session.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import uuid
45
from typing import Any, Protocol
56

67
import anyio.lowlevel
@@ -13,6 +14,7 @@
1314
from mcp.shared._context import RequestContext
1415
from mcp.shared.message import SessionMessage
1516
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
17+
from mcp.shared.session_state import SessionState
1618
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1719
from mcp.types._types import RequestParamsMeta
1820

@@ -132,6 +134,9 @@ def __init__(
132134
self._message_handler = message_handler or _default_message_handler
133135
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
134136
self._server_capabilities: types.ServerCapabilities | None = None
137+
self._server_info: types.Implementation | None = None
138+
self._initialized_sent: bool = False
139+
self._session_id: str = str(uuid.uuid4())
135140
self._experimental_features: ExperimentalClientFeatures | None = None
136141

137142
# Experimental: Task handlers (use defaults if not provided)
@@ -186,8 +191,10 @@ async def initialize(self) -> types.InitializeResult:
186191
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}")
187192

188193
self._server_capabilities = result.capabilities
194+
self._server_info = result.server_info
189195

190196
await self.send_notification(types.InitializedNotification())
197+
self._initialized_sent = True
191198

192199
return result
193200

@@ -198,6 +205,34 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None:
198205
"""
199206
return self._server_capabilities
200207

208+
def get_session_state(self) -> SessionState:
209+
"""Extract a serializable snapshot of the current session state.
210+
211+
This allows the session state to be stored in external storage
212+
(Redis, database, etc.) for distributed deployments.
213+
214+
Returns:
215+
A SessionState object containing the serializable state
216+
"""
217+
from mcp.shared.version import LATEST_PROTOCOL_VERSION
218+
219+
return SessionState(
220+
session_id=self._session_id,
221+
protocol_version=LATEST_PROTOCOL_VERSION,
222+
next_request_id=self._request_id,
223+
server_capabilities=(
224+
self._server_capabilities.model_dump(by_alias=True, mode="json", exclude_none=True)
225+
if self._server_capabilities
226+
else None
227+
),
228+
server_info=(
229+
self._server_info.model_dump(by_alias=True, mode="json", exclude_none=True)
230+
if self._server_info
231+
else None
232+
),
233+
initialized_sent=self._initialized_sent,
234+
)
235+
201236
@property
202237
def experimental(self) -> ExperimentalClientFeatures:
203238
"""Experimental APIs for tasks and other features.

tests/client/test_session.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,74 @@ async def mock_server():
706706
await session.initialize()
707707

708708
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)
709+
710+
711+
@pytest.mark.anyio
712+
async def test_client_session_get_state():
713+
"""Test that get_session_state() returns a valid SessionState."""
714+
from mcp.shared.session_state import SessionState
715+
716+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
717+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
718+
719+
async def mock_server():
720+
session_message = await client_to_server_receive.receive()
721+
jsonrpc_request = session_message.message
722+
assert isinstance(jsonrpc_request, JSONRPCRequest)
723+
724+
result = InitializeResult(
725+
protocol_version=LATEST_PROTOCOL_VERSION,
726+
capabilities=ServerCapabilities(
727+
logging=None,
728+
resources=None,
729+
tools=None,
730+
experimental=None,
731+
prompts=None,
732+
),
733+
server_info=Implementation(name="mock-server", version="0.1.0"),
734+
)
735+
736+
async with server_to_client_send:
737+
await server_to_client_send.send(
738+
SessionMessage(
739+
JSONRPCResponse(
740+
jsonrpc="2.0",
741+
id=jsonrpc_request.id,
742+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
743+
)
744+
)
745+
)
746+
await client_to_server_receive.receive()
747+
748+
async with (
749+
ClientSession(
750+
server_to_client_receive,
751+
client_to_server_send,
752+
) as session,
753+
anyio.create_task_group() as tg,
754+
client_to_server_send,
755+
client_to_server_receive,
756+
server_to_client_send,
757+
server_to_client_receive,
758+
):
759+
tg.start_soon(mock_server)
760+
761+
# Initialize the session
762+
await session.initialize()
763+
764+
# Get session state
765+
state = session.get_session_state()
766+
767+
# Verify the state
768+
assert isinstance(state, SessionState)
769+
assert state.session_id is not None
770+
assert state.protocol_version == LATEST_PROTOCOL_VERSION
771+
assert state.next_request_id == 1 # After initialize request
772+
assert state.server_capabilities is not None
773+
assert state.server_info is not None
774+
assert state.initialized_sent is True
775+
776+
# Verify it's serializable
777+
json_str = state.model_dump_json()
778+
assert json_str is not None
779+

0 commit comments

Comments
 (0)