Skip to content

Commit d443932

Browse files
committed
feat: expand InitializationState with explicit lifecycle state machine
Expand the InitializationState enum with new states (Stateless, Closing, Closed, Error) and add centralized transition validation via a _VALID_TRANSITIONS table and _transition_state() method. Key changes: - Add Stateless state for sessions that skip the initialization handshake - Add Closing/Closed states for orderly shutdown tracking - Add Error state for unrecoverable failures with recovery paths - Add _transition_state() method that validates transitions against a table - Add initialization_state property (read-only) and is_initialized property - Override __aexit__ to transition through Closing -> Closed on exit - Update _received_request and _received_notification to use new APIs - Add comprehensive test suite (20 tests) covering all state transitions Github-Issue: #1691
1 parent e5bdd4c commit d443932

File tree

2 files changed

+420
-9
lines changed

2 files changed

+420
-9
lines changed

src/mcp/server/session.py

Lines changed: 114 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
3737
be instantiated directly by users of the MCP framework.
3838
"""
3939

40+
import logging
4041
from enum import Enum
42+
from types import TracebackType
4143
from typing import Any, TypeVar
4244

4345
import anyio
@@ -58,11 +60,58 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
5860
)
5961
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
6062

63+
logger = logging.getLogger(__name__)
64+
6165

6266
class InitializationState(Enum):
67+
"""Represents the lifecycle states of a server session.
68+
69+
State transitions:
70+
NotInitialized -> Initializing -> Initialized -> Closing -> Closed
71+
Stateless -> Closing -> Closed
72+
Any state -> Error (on unrecoverable failure)
73+
"""
74+
6375
NotInitialized = 1
6476
Initializing = 2
6577
Initialized = 3
78+
Stateless = 4
79+
Closing = 5
80+
Closed = 6
81+
Error = 7
82+
83+
84+
# Valid state transitions: maps each state to the set of states it can transition to.
85+
_VALID_TRANSITIONS: dict[InitializationState, set[InitializationState]] = {
86+
InitializationState.NotInitialized: {
87+
InitializationState.Initializing,
88+
InitializationState.Closing,
89+
InitializationState.Error,
90+
},
91+
InitializationState.Initializing: {
92+
InitializationState.Initialized,
93+
InitializationState.Closing,
94+
InitializationState.Error,
95+
},
96+
InitializationState.Initialized: {
97+
InitializationState.Initializing, # re-initialization
98+
InitializationState.Closing,
99+
InitializationState.Error,
100+
},
101+
InitializationState.Stateless: {
102+
InitializationState.Closing,
103+
InitializationState.Error,
104+
},
105+
InitializationState.Closing: {
106+
InitializationState.Closed,
107+
InitializationState.Error,
108+
},
109+
InitializationState.Closed: set(),
110+
InitializationState.Error: {
111+
InitializationState.Closing,
112+
InitializationState.Closed,
113+
},
114+
}
66115

67116

68117
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
@@ -81,7 +130,7 @@ class ServerSession(
81130
types.ClientNotification,
82131
]
83132
):
84-
_initialized: InitializationState = InitializationState.NotInitialized
133+
_initialization_state: InitializationState = InitializationState.NotInitialized
85134
_client_params: types.InitializeRequestParams | None = None
86135
_experimental_features: ExperimentalServerSessionFeatures | None = None
87136

@@ -93,16 +142,47 @@ def __init__(
93142
stateless: bool = False,
94143
) -> None:
95144
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
96-
self._initialization_state = (
97-
InitializationState.Initialized if stateless else InitializationState.NotInitialized
98-
)
145+
self._initialization_state = InitializationState.Stateless if stateless else InitializationState.NotInitialized
99146

100147
self._init_options = init_options
101148
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
102149
ServerRequestResponder
103150
](0)
104151
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
105152

153+
@property
154+
def initialization_state(self) -> InitializationState:
155+
"""Return the current initialization state of the session."""
156+
return self._initialization_state
157+
158+
@property
159+
def is_initialized(self) -> bool:
160+
"""Check whether the session is ready to process requests.
161+
162+
Returns True when the session has completed initialization handshake
163+
(Initialized) or is operating in stateless mode (Stateless).
164+
"""
165+
return self._initialization_state in (
166+
InitializationState.Initialized,
167+
InitializationState.Stateless,
168+
)
169+
170+
def _transition_state(self, new_state: InitializationState) -> None:
171+
"""Transition the session to a new state, validating the transition.
172+
173+
Args:
174+
new_state: The target state to transition to.
175+
176+
Raises:
177+
RuntimeError: If the transition is not valid from the current state.
178+
"""
179+
current = self._initialization_state
180+
valid_targets = _VALID_TRANSITIONS.get(current, set())
181+
if new_state not in valid_targets:
182+
raise RuntimeError(f"Invalid session state transition: {current.name} -> {new_state.name}")
183+
logger.debug("Session state transition: %s -> %s", current.name, new_state.name)
184+
self._initialization_state = new_state
185+
106186
@property
107187
def client_params(self) -> types.InitializeRequestParams | None:
108188
return self._client_params # pragma: no cover
@@ -160,11 +240,34 @@ async def _receive_loop(self) -> None:
160240
async with self._incoming_message_stream_writer:
161241
await super()._receive_loop()
162242

243+
async def __aexit__(
244+
self,
245+
exc_type: type[BaseException] | None,
246+
exc_val: BaseException | None,
247+
exc_tb: TracebackType | None,
248+
) -> bool | None:
249+
"""Clean up the session with proper state transitions."""
250+
try:
251+
if self._initialization_state not in (
252+
InitializationState.Closed,
253+
InitializationState.Closing,
254+
):
255+
self._transition_state(InitializationState.Closing)
256+
except RuntimeError:
257+
logger.debug("Could not transition to Closing from %s", self._initialization_state.name)
258+
try:
259+
return await super().__aexit__(exc_type, exc_val, exc_tb)
260+
finally:
261+
try:
262+
self._transition_state(InitializationState.Closed)
263+
except RuntimeError:
264+
logger.debug("Could not transition to Closed from %s", self._initialization_state.name)
265+
163266
async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
164267
match responder.request.root:
165268
case types.InitializeRequest(params=params):
166269
requested_version = params.protocolVersion
167-
self._initialization_state = InitializationState.Initializing
270+
self._transition_state(InitializationState.Initializing)
168271
self._client_params = params
169272
with responder:
170273
await responder.respond(
@@ -184,22 +287,24 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques
184287
)
185288
)
186289
)
187-
self._initialization_state = InitializationState.Initialized
290+
self._transition_state(InitializationState.Initialized)
188291
case types.PingRequest():
189292
# Ping requests are allowed at any time
190293
pass
191294
case _:
192-
if self._initialization_state != InitializationState.Initialized:
295+
if not self.is_initialized:
193296
raise RuntimeError("Received request before initialization was complete")
194297

195298
async def _received_notification(self, notification: types.ClientNotification) -> None:
196299
# Need this to avoid ASYNC910
197300
await anyio.lowlevel.checkpoint()
198301
match notification.root:
199302
case types.InitializedNotification():
200-
self._initialization_state = InitializationState.Initialized
303+
# Only transition if not already initialized (e.g. stateless mode)
304+
if self._initialization_state == InitializationState.Initializing:
305+
self._transition_state(InitializationState.Initialized)
201306
case _:
202-
if self._initialization_state != InitializationState.Initialized: # pragma: no cover
307+
if not self.is_initialized: # pragma: no cover
203308
raise RuntimeError("Received notification before initialization was complete")
204309

205310
async def send_log_message(

0 commit comments

Comments
 (0)