Skip to content

Commit cf1bbd2

Browse files
committed
feat: expand InitializationState with explicit lifecycle state machine
Expand the InitializationState enum with new states (Stateless, Closing, Closed, Error) and add centralized transition validation via a _VALID_TRANSITIONS table and _transition_state() method. Key changes: - Add Stateless state for sessions that skip the initialization handshake - Add Closing/Closed states for orderly shutdown tracking - Add Error state for unrecoverable failures with recovery paths - Add _transition_state() method that validates transitions against a table - Add initialization_state property (read-only) and is_initialized property - Override __aexit__ to transition through Closing -> Closed on exit - Update _received_request and _received_notification to use new APIs - Add comprehensive test suite (20 tests) covering all state transitions Github-Issue: #1691
1 parent c032854 commit cf1bbd2

File tree

2 files changed

+421
-9
lines changed

2 files changed

+421
-9
lines changed

src/mcp/server/session.py

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
2828
be instantiated directly by users of the MCP framework.
2929
"""
3030

31+
import logging
3132
from enum import Enum
33+
from types import TracebackType
3234
from typing import Any, TypeVar, overload
3335

3436
import anyio
@@ -50,11 +52,59 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
5052
)
5153
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
5254

55+
logger = logging.getLogger(__name__)
56+
5357

5458
class InitializationState(Enum):
59+
"""Represents the lifecycle states of a server session.
60+
61+
State transitions:
62+
NotInitialized -> Initializing -> Initialized -> Closing -> Closed
63+
Stateless -> Closing -> Closed
64+
Any state -> Error (on unrecoverable failure)
65+
"""
66+
5567
NotInitialized = 1
5668
Initializing = 2
5769
Initialized = 3
70+
Stateless = 4
71+
Closing = 5
72+
Closed = 6
73+
Error = 7
74+
75+
76+
# Valid state transitions: maps each state to the set of states it can transition to.
77+
_VALID_TRANSITIONS: dict[InitializationState, set[InitializationState]] = {
78+
InitializationState.NotInitialized: {
79+
InitializationState.Initializing,
80+
InitializationState.Initialized, # client may send notification without prior request
81+
InitializationState.Closing,
82+
InitializationState.Error,
83+
},
84+
InitializationState.Initializing: {
85+
InitializationState.Initialized,
86+
InitializationState.Closing,
87+
InitializationState.Error,
88+
},
89+
InitializationState.Initialized: {
90+
InitializationState.Initializing, # re-initialization
91+
InitializationState.Closing,
92+
InitializationState.Error,
93+
},
94+
InitializationState.Stateless: {
95+
InitializationState.Closing,
96+
InitializationState.Error,
97+
},
98+
InitializationState.Closing: {
99+
InitializationState.Closed,
100+
InitializationState.Error,
101+
},
102+
InitializationState.Closed: set(),
103+
InitializationState.Error: {
104+
InitializationState.Closing,
105+
InitializationState.Closed,
106+
},
107+
}
58108

59109

60110
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
@@ -73,7 +123,7 @@ class ServerSession(
73123
types.ClientNotification,
74124
]
75125
):
76-
_initialized: InitializationState = InitializationState.NotInitialized
126+
_initialization_state: InitializationState = InitializationState.NotInitialized
77127
_client_params: types.InitializeRequestParams | None = None
78128
_experimental_features: ExperimentalServerSessionFeatures | None = None
79129

@@ -86,9 +136,7 @@ def __init__(
86136
) -> None:
87137
super().__init__(read_stream, write_stream)
88138
self._stateless = stateless
89-
self._initialization_state = (
90-
InitializationState.Initialized if stateless else InitializationState.NotInitialized
91-
)
139+
self._initialization_state = InitializationState.Stateless if stateless else InitializationState.NotInitialized
92140

93141
self._init_options = init_options
94142
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
@@ -104,6 +152,39 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]:
104152
def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]:
105153
return types.client_notification_adapter
106154

155+
@property
156+
def initialization_state(self) -> InitializationState:
157+
"""Return the current initialization state of the session."""
158+
return self._initialization_state
159+
160+
@property
161+
def is_initialized(self) -> bool:
162+
"""Check whether the session is ready to process requests.
163+
164+
Returns True when the session has completed initialization handshake
165+
(Initialized) or is operating in stateless mode (Stateless).
166+
"""
167+
return self._initialization_state in (
168+
InitializationState.Initialized,
169+
InitializationState.Stateless,
170+
)
171+
172+
def _transition_state(self, new_state: InitializationState) -> None:
173+
"""Transition the session to a new state, validating the transition.
174+
175+
Args:
176+
new_state: The target state to transition to.
177+
178+
Raises:
179+
RuntimeError: If the transition is not valid from the current state.
180+
"""
181+
current = self._initialization_state
182+
valid_targets = _VALID_TRANSITIONS.get(current, set())
183+
if new_state not in valid_targets:
184+
raise RuntimeError(f"Invalid session state transition: {current.name} -> {new_state.name}")
185+
logger.debug("Session state transition: %s -> %s", current.name, new_state.name)
186+
self._initialization_state = new_state
187+
107188
@property
108189
def client_params(self) -> types.InitializeRequestParams | None:
109190
return self._client_params
@@ -161,11 +242,34 @@ async def _receive_loop(self) -> None:
161242
async with self._incoming_message_stream_writer:
162243
await super()._receive_loop()
163244

245+
async def __aexit__(
246+
self,
247+
exc_type: type[BaseException] | None,
248+
exc_val: BaseException | None,
249+
exc_tb: TracebackType | None,
250+
) -> bool | None:
251+
"""Clean up the session with proper state transitions."""
252+
try:
253+
if self._initialization_state not in (
254+
InitializationState.Closed,
255+
InitializationState.Closing,
256+
):
257+
self._transition_state(InitializationState.Closing)
258+
except RuntimeError:
259+
logger.debug("Could not transition to Closing from %s", self._initialization_state.name)
260+
try:
261+
return await super().__aexit__(exc_type, exc_val, exc_tb)
262+
finally:
263+
try:
264+
self._transition_state(InitializationState.Closed)
265+
except RuntimeError:
266+
logger.debug("Could not transition to Closed from %s", self._initialization_state.name)
267+
164268
async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
165269
match responder.request:
166270
case types.InitializeRequest(params=params):
167271
requested_version = params.protocol_version
168-
self._initialization_state = InitializationState.Initializing
272+
self._transition_state(InitializationState.Initializing)
169273
self._client_params = params
170274
with responder:
171275
await responder.respond(
@@ -185,22 +289,27 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques
185289
instructions=self._init_options.instructions,
186290
)
187291
)
188-
self._initialization_state = InitializationState.Initialized
292+
self._transition_state(InitializationState.Initialized)
189293
case types.PingRequest():
190294
# Ping requests are allowed at any time
191295
pass
192296
case _:
193-
if self._initialization_state != InitializationState.Initialized:
297+
if not self.is_initialized:
194298
raise RuntimeError("Received request before initialization was complete")
195299

196300
async def _received_notification(self, notification: types.ClientNotification) -> None:
197301
# Need this to avoid ASYNC910
198302
await anyio.lowlevel.checkpoint()
199303
match notification:
200304
case types.InitializedNotification():
201-
self._initialization_state = InitializationState.Initialized
305+
# Transition to Initialized if not already there (e.g. stateless mode)
306+
if self._initialization_state in (
307+
InitializationState.NotInitialized,
308+
InitializationState.Initializing,
309+
):
310+
self._transition_state(InitializationState.Initialized)
202311
case _:
203-
if self._initialization_state != InitializationState.Initialized: # pragma: no cover
312+
if not self.is_initialized: # pragma: no cover
204313
raise RuntimeError("Received notification before initialization was complete")
205314

206315
async def send_log_message(

0 commit comments

Comments
 (0)