Skip to content

Commit 045462e

Browse files
feat: add from_session_state() classmethod to ClientSession
This classmethod creates a new ClientSession from a previously saved SessionState, enabling distributed deployments. Related: #2111
1 parent ddb29f5 commit 045462e

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

src/mcp/client/session.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,67 @@ def get_session_state(self) -> SessionState:
233233
initialized_sent=self._initialized_sent,
234234
)
235235

236+
@classmethod
237+
def from_session_state(
238+
cls,
239+
state: SessionState,
240+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
241+
write_stream: MemoryObjectSendStream[SessionMessage],
242+
client_info: types.Implementation | None = None,
243+
read_timeout_seconds: float | None = None,
244+
sampling_callback: SamplingFnT | None = None,
245+
elicitation_callback: ElicitationFnT | None = None,
246+
list_roots_callback: ListRootsFnT | None = None,
247+
logging_callback: LoggingFnT | None = None,
248+
message_handler: MessageHandlerFnT | None = None,
249+
*,
250+
sampling_capabilities: types.SamplingCapability | None = None,
251+
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
252+
) -> ClientSession:
253+
"""Create a new ClientSession from a previously saved SessionState.
254+
255+
This restores session context from external storage, allowing
256+
distributed instances to continue a session.
257+
258+
Args:
259+
state: The SessionState to restore from
260+
read_stream: The read stream for receiving messages
261+
write_stream: The write stream for sending messages
262+
client_info: Optional client info (defaults to DEFAULT_CLIENT_INFO)
263+
read_timeout_seconds: Optional read timeout for this session
264+
265+
Returns:
266+
A new ClientSession instance with the restored state
267+
"""
268+
# Create session with default initialization
269+
session = cls(
270+
read_stream=read_stream,
271+
write_stream=write_stream,
272+
client_info=client_info,
273+
read_timeout_seconds=read_timeout_seconds,
274+
sampling_callback=sampling_callback,
275+
elicitation_callback=elicitation_callback,
276+
list_roots_callback=list_roots_callback,
277+
logging_callback=logging_callback,
278+
message_handler=message_handler,
279+
sampling_capabilities=sampling_capabilities,
280+
experimental_task_handlers=experimental_task_handlers,
281+
)
282+
283+
# Restore the state
284+
session._session_id = state.session_id
285+
session._request_id = state.next_request_id
286+
287+
if state.server_capabilities:
288+
session._server_capabilities = types.ServerCapabilities.model_validate(state.server_capabilities)
289+
290+
if state.server_info:
291+
session._server_info = types.Implementation.model_validate(state.server_info)
292+
293+
session._initialized_sent = state.initialized_sent
294+
295+
return session
296+
236297
@property
237298
def experimental(self) -> ExperimentalClientFeatures:
238299
"""Experimental APIs for tasks and other features.

tests/client/test_session.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,3 +777,55 @@ async def mock_server():
777777
json_str = state.model_dump_json()
778778
assert json_str is not None
779779

780+
781+
@pytest.mark.anyio
782+
async def test_client_session_from_state():
783+
"""Test that from_session_state() creates a valid session."""
784+
from mcp.shared.session_state import SessionState
785+
786+
# Create a session state
787+
state = SessionState(
788+
session_id="test-session-from-state",
789+
protocol_version=LATEST_PROTOCOL_VERSION,
790+
next_request_id=5,
791+
server_capabilities={
792+
"tools": {},
793+
"resources": {},
794+
"prompts": None,
795+
"logging": None,
796+
"experimental": None,
797+
},
798+
server_info={"name": "test-server", "version": "1.0.0"},
799+
initialized_sent=True,
800+
)
801+
802+
# Create streams
803+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
804+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
805+
806+
# Create session from state
807+
session = ClientSession.from_session_state(
808+
state,
809+
server_to_client_receive,
810+
client_to_server_send,
811+
)
812+
813+
# Verify the session was created with the correct state
814+
assert session._session_id == "test-session-from-state"
815+
assert session._request_id == 5 # Continues from saved state
816+
assert session._server_capabilities is not None
817+
assert session._initialized_sent is True
818+
assert session._server_info is not None
819+
assert session._server_info.name == "test-server"
820+
assert session._server_info.version == "1.0.0"
821+
822+
# Clean up streams
823+
async with (
824+
client_to_server_send,
825+
client_to_server_receive,
826+
server_to_client_send,
827+
server_to_client_receive,
828+
):
829+
pass
830+
831+

0 commit comments

Comments
 (0)