Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ dev = [
"coverage[toml]>=7.10.7,<=7.13",
"pillow>=12.0",
"strict-no-cover",
"pytest-cov>=7.0.0",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .server.session import ServerSession
from .server.stdio import stdio_server
from .shared.exceptions import MCPError, UrlElicitationRequiredError
from .shared.session_state import SessionState
from .types import (
CallToolRequest,
ClientCapabilities,
Expand Down Expand Up @@ -114,6 +115,7 @@
"SamplingMessageContentBlock",
"SamplingRole",
"SamplingToolsCapability",
"SessionState",
"ServerCapabilities",
"ServerNotification",
"ServerRequest",
Expand Down
96 changes: 96 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import uuid
from typing import Any, Protocol

import anyio.lowlevel
Expand All @@ -13,6 +14,7 @@
from mcp.shared._context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.session_state import SessionState
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types._types import RequestParamsMeta

Expand Down Expand Up @@ -132,6 +134,9 @@ def __init__(
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
self._server_info: types.Implementation | None = None
self._initialized_sent: bool = False
self._session_id: str = str(uuid.uuid4())
self._experimental_features: ExperimentalClientFeatures | None = None

# Experimental: Task handlers (use defaults if not provided)
Expand Down Expand Up @@ -186,8 +191,10 @@ async def initialize(self) -> types.InitializeResult:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}")

self._server_capabilities = result.capabilities
self._server_info = result.server_info

await self.send_notification(types.InitializedNotification())
self._initialized_sent = True

return result

Expand All @@ -198,6 +205,95 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None:
"""
return self._server_capabilities

def get_session_state(self) -> SessionState:
"""Extract a serializable snapshot of the current session state.

This allows the session state to be stored in external storage
(Redis, database, etc.) for distributed deployments.

Returns:
A SessionState object containing the serializable state
"""
from mcp.shared.version import LATEST_PROTOCOL_VERSION

return SessionState(
session_id=self._session_id,
protocol_version=LATEST_PROTOCOL_VERSION,
next_request_id=self._request_id,
server_capabilities=(
self._server_capabilities.model_dump(by_alias=True, mode="json", exclude_none=True)
if self._server_capabilities
else None
),
server_info=(
self._server_info.model_dump(by_alias=True, mode="json", exclude_none=True)
if self._server_info
else None
),
initialized_sent=self._initialized_sent,
)

@classmethod
def from_session_state(
cls,
state: SessionState,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
client_info: types.Implementation | None = None,
read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
) -> ClientSession:
"""Create a new ClientSession from a previously saved SessionState.

This restores session context from external storage, allowing
distributed instances to continue a session.

Args:
state: The SessionState to restore from
read_stream: The read stream for receiving messages
write_stream: The write stream for sending messages
client_info: Optional client info (defaults to DEFAULT_CLIENT_INFO)
read_timeout_seconds: Optional read timeout for this session

Returns:
A new ClientSession instance with the restored state
"""
# Create session with default initialization
session = cls(
read_stream=read_stream,
write_stream=write_stream,
client_info=client_info,
read_timeout_seconds=read_timeout_seconds,
sampling_callback=sampling_callback,
elicitation_callback=elicitation_callback,
list_roots_callback=list_roots_callback,
logging_callback=logging_callback,
message_handler=message_handler,
sampling_capabilities=sampling_capabilities,
experimental_task_handlers=experimental_task_handlers,
)

# Restore the state
session._session_id = state.session_id
session._request_id = state.next_request_id

if state.server_capabilities:
session._server_capabilities = types.ServerCapabilities.model_validate(state.server_capabilities)

if state.server_info:
session._server_info = types.Implementation.model_validate(state.server_info)

session._initialized_sent = state.initialized_sent

return session

@property
def experimental(self) -> ExperimentalClientFeatures:
"""Experimental APIs for tasks and other features.
Expand Down
3 changes: 3 additions & 0 deletions src/mcp/shared/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .session_state import SessionState

__all__ = ["SessionState"]
51 changes: 51 additions & 0 deletions src/mcp/shared/session_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Serializable session state for distributed deployments.

This module provides a SessionState dataclass that can be serialized to JSON
and stored in external storage (Redis, database, etc.) for distributed deployments.

This enables session state to be shared across multiple server instances,
allowing MCP services to run behind load balancers or in horizontally-scaled
deployments.
"""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field


class SessionState(BaseModel):
"""A serializable snapshot of MCP session state.

This contains the minimal state needed to reconstruct a session context
across process boundaries. Runtime objects (streams, callbacks) are NOT
included as they cannot be serialized and must be recreated.

Attributes:
session_id: Unique identifier for this session
protocol_version: MCP protocol version being used
next_request_id: The next request ID to use (continues sequence)
server_capabilities: Server capabilities from initialization (as dict)
server_info: Server metadata from initialization (as dict)
initialized_sent: Whether the initialized notification was sent
"""

session_id: str = Field(description="Unique identifier for this session")
protocol_version: str = Field(description="MCP protocol version being used")
next_request_id: int = Field(
description="Next request ID to use",
ge=0,
)
server_capabilities: dict[str, Any] | None = Field(
default=None,
description="Server capabilities received during initialization",
)
server_info: dict[str, Any] | None = Field(
default=None,
description="Server information metadata",
)
initialized_sent: bool = Field(
default=False,
description="Whether the initialized notification was sent",
)
Loading