@@ -478,23 +478,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
478478 await response (scope , receive , send )
479479 return
480480
481- # Check if this is an initialization request
482- is_initialization_request = isinstance (message , JSONRPCRequest ) and message .method == "initialize"
483-
484- if is_initialization_request :
485- # Check if the server already has an established session
486- if self .mcp_session_id :
487- # Check if request has a session ID
488- request_session_id = self ._get_session_id (request )
489-
490- # If request has a session ID but doesn't match, return 404
491- if request_session_id and request_session_id != self .mcp_session_id : # pragma: no cover
492- response = self ._create_error_response (
493- "Not Found: Invalid or expired session ID" ,
494- HTTPStatus .NOT_FOUND ,
495- )
496- await response (scope , receive , send )
497- return
481+ is_initialization_request = False
482+ if isinstance (message , JSONRPCRequest ) and message .method == "initialize" :
483+ is_initialization_request = True
484+ if not await self ._validate_initialization_request (message , request , send ):
485+ return
498486 elif not await self ._validate_request_headers (request , send ): # pragma: no cover
499487 return
500488
@@ -867,6 +855,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
867855
868856 return True
869857
858+ async def _validate_initialization_request (self , message : JSONRPCRequest , request : Request , send : Send ) -> bool :
859+ if not await self ._validate_initialization_protocol_version (message , request , send ):
860+ return False
861+
862+ if not self .mcp_session_id :
863+ return True
864+
865+ request_session_id = self ._get_session_id (request )
866+ if request_session_id and request_session_id != self .mcp_session_id : # pragma: no cover
867+ response = self ._create_error_response (
868+ "Not Found: Invalid or expired session ID" ,
869+ HTTPStatus .NOT_FOUND ,
870+ )
871+ await response (request .scope , request .receive , send )
872+ return False
873+
874+ return True
875+
876+ async def _validate_initialization_protocol_version (
877+ self , message : JSONRPCRequest , request : Request , send : Send
878+ ) -> bool :
879+ header_protocol_version = request .headers .get (MCP_PROTOCOL_VERSION_HEADER )
880+ body_protocol_version = str (message .params .get ("protocolVersion" )) if message .params else None
881+ if (
882+ header_protocol_version is not None
883+ and body_protocol_version is not None
884+ and header_protocol_version != body_protocol_version
885+ ):
886+ response = self ._create_error_response (
887+ f"Bad Request: { MCP_PROTOCOL_VERSION_HEADER } header does not match initialize.params.protocolVersion" ,
888+ HTTPStatus .BAD_REQUEST ,
889+ INVALID_REQUEST ,
890+ )
891+ await response (request .scope , request .receive , send )
892+ return False
893+
894+ return True
895+
870896 async def _replay_events (self , last_event_id : str , request : Request , send : Send ) -> None : # pragma: no cover
871897 """Replays events that would have been sent after the specified event ID.
872898
0 commit comments