@@ -28,7 +28,9 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
2828be instantiated directly by users of the MCP framework.
2929"""
3030
31+ import logging
3132from enum import Enum
33+ from types import TracebackType
3234from typing import Any , TypeVar , overload
3335
3436import anyio
@@ -50,11 +52,59 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
5052)
5153from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
5254
55+ logger = logging .getLogger (__name__ )
56+
5357
5458class InitializationState (Enum ):
59+ """Represents the lifecycle states of a server session.
60+
61+ State transitions:
62+ NotInitialized -> Initializing -> Initialized -> Closing -> Closed
63+ Stateless -> Closing -> Closed
64+ Any state -> Error (on unrecoverable failure)
65+ """
66+
5567 NotInitialized = 1
5668 Initializing = 2
5769 Initialized = 3
70+ Stateless = 4
71+ Closing = 5
72+ Closed = 6
73+ Error = 7
74+
75+
76+ # Valid state transitions: maps each state to the set of states it can transition to.
77+ _VALID_TRANSITIONS : dict [InitializationState , set [InitializationState ]] = {
78+ InitializationState .NotInitialized : {
79+ InitializationState .Initializing ,
80+ InitializationState .Initialized , # client may send notification without prior request
81+ InitializationState .Closing ,
82+ InitializationState .Error ,
83+ },
84+ InitializationState .Initializing : {
85+ InitializationState .Initialized ,
86+ InitializationState .Closing ,
87+ InitializationState .Error ,
88+ },
89+ InitializationState .Initialized : {
90+ InitializationState .Initializing , # re-initialization
91+ InitializationState .Closing ,
92+ InitializationState .Error ,
93+ },
94+ InitializationState .Stateless : {
95+ InitializationState .Closing ,
96+ InitializationState .Error ,
97+ },
98+ InitializationState .Closing : {
99+ InitializationState .Closed ,
100+ InitializationState .Error ,
101+ },
102+ InitializationState .Closed : set (),
103+ InitializationState .Error : {
104+ InitializationState .Closing ,
105+ InitializationState .Closed ,
106+ },
107+ }
58108
59109
60110ServerSessionT = TypeVar ("ServerSessionT" , bound = "ServerSession" )
@@ -73,7 +123,7 @@ class ServerSession(
73123 types .ClientNotification ,
74124 ]
75125):
76- _initialized : InitializationState = InitializationState .NotInitialized
126+ _initialization_state : InitializationState = InitializationState .NotInitialized
77127 _client_params : types .InitializeRequestParams | None = None
78128 _experimental_features : ExperimentalServerSessionFeatures | None = None
79129
@@ -86,9 +136,7 @@ def __init__(
86136 ) -> None :
87137 super ().__init__ (read_stream , write_stream )
88138 self ._stateless = stateless
89- self ._initialization_state = (
90- InitializationState .Initialized if stateless else InitializationState .NotInitialized
91- )
139+ self ._initialization_state = InitializationState .Stateless if stateless else InitializationState .NotInitialized
92140
93141 self ._init_options = init_options
94142 self ._incoming_message_stream_writer , self ._incoming_message_stream_reader = anyio .create_memory_object_stream [
@@ -104,6 +152,39 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]:
104152 def _receive_notification_adapter (self ) -> TypeAdapter [types .ClientNotification ]:
105153 return types .client_notification_adapter
106154
155+ @property
156+ def initialization_state (self ) -> InitializationState :
157+ """Return the current initialization state of the session."""
158+ return self ._initialization_state
159+
160+ @property
161+ def is_initialized (self ) -> bool :
162+ """Check whether the session is ready to process requests.
163+
164+ Returns True when the session has completed initialization handshake
165+ (Initialized) or is operating in stateless mode (Stateless).
166+ """
167+ return self ._initialization_state in (
168+ InitializationState .Initialized ,
169+ InitializationState .Stateless ,
170+ )
171+
172+ def _transition_state (self , new_state : InitializationState ) -> None :
173+ """Transition the session to a new state, validating the transition.
174+
175+ Args:
176+ new_state: The target state to transition to.
177+
178+ Raises:
179+ RuntimeError: If the transition is not valid from the current state.
180+ """
181+ current = self ._initialization_state
182+ valid_targets = _VALID_TRANSITIONS .get (current , set ())
183+ if new_state not in valid_targets :
184+ raise RuntimeError (f"Invalid session state transition: { current .name } -> { new_state .name } " )
185+ logger .debug ("Session state transition: %s -> %s" , current .name , new_state .name )
186+ self ._initialization_state = new_state
187+
107188 @property
108189 def client_params (self ) -> types .InitializeRequestParams | None :
109190 return self ._client_params
@@ -161,11 +242,34 @@ async def _receive_loop(self) -> None:
161242 async with self ._incoming_message_stream_writer :
162243 await super ()._receive_loop ()
163244
245+ async def __aexit__ (
246+ self ,
247+ exc_type : type [BaseException ] | None ,
248+ exc_val : BaseException | None ,
249+ exc_tb : TracebackType | None ,
250+ ) -> bool | None :
251+ """Clean up the session with proper state transitions."""
252+ try :
253+ if self ._initialization_state not in (
254+ InitializationState .Closed ,
255+ InitializationState .Closing ,
256+ ):
257+ self ._transition_state (InitializationState .Closing )
258+ except RuntimeError :
259+ logger .debug ("Could not transition to Closing from %s" , self ._initialization_state .name )
260+ try :
261+ return await super ().__aexit__ (exc_type , exc_val , exc_tb )
262+ finally :
263+ try :
264+ self ._transition_state (InitializationState .Closed )
265+ except RuntimeError :
266+ logger .debug ("Could not transition to Closed from %s" , self ._initialization_state .name )
267+
164268 async def _received_request (self , responder : RequestResponder [types .ClientRequest , types .ServerResult ]):
165269 match responder .request :
166270 case types .InitializeRequest (params = params ):
167271 requested_version = params .protocol_version
168- self ._initialization_state = InitializationState .Initializing
272+ self ._transition_state ( InitializationState .Initializing )
169273 self ._client_params = params
170274 with responder :
171275 await responder .respond (
@@ -185,22 +289,27 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques
185289 instructions = self ._init_options .instructions ,
186290 )
187291 )
188- self ._initialization_state = InitializationState .Initialized
292+ self ._transition_state ( InitializationState .Initialized )
189293 case types .PingRequest ():
190294 # Ping requests are allowed at any time
191295 pass
192296 case _:
193- if self ._initialization_state != InitializationState . Initialized :
297+ if not self .is_initialized :
194298 raise RuntimeError ("Received request before initialization was complete" )
195299
196300 async def _received_notification (self , notification : types .ClientNotification ) -> None :
197301 # Need this to avoid ASYNC910
198302 await anyio .lowlevel .checkpoint ()
199303 match notification :
200304 case types .InitializedNotification ():
201- self ._initialization_state = InitializationState .Initialized
305+ # Transition to Initialized if not already there (e.g. stateless mode)
306+ if self ._initialization_state in (
307+ InitializationState .NotInitialized ,
308+ InitializationState .Initializing ,
309+ ):
310+ self ._transition_state (InitializationState .Initialized )
202311 case _:
203- if self ._initialization_state != InitializationState . Initialized : # pragma: no cover
312+ if not self .is_initialized : # pragma: no cover
204313 raise RuntimeError ("Received notification before initialization was complete" )
205314
206315 async def send_log_message (
0 commit comments