Skip to content

Commit cef101e

Browse files
committed
fix: reject initialize protocol version conflicts
1 parent f475344 commit cef101e

2 files changed

Lines changed: 80 additions & 17 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -478,23 +478,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
478478
await response(scope, receive, send)
479479
return
480480

481-
# Check if this is an initialization request
482-
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"
483-
484-
if is_initialization_request:
485-
# Check if the server already has an established session
486-
if self.mcp_session_id:
487-
# Check if request has a session ID
488-
request_session_id = self._get_session_id(request)
489-
490-
# If request has a session ID but doesn't match, return 404
491-
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
492-
response = self._create_error_response(
493-
"Not Found: Invalid or expired session ID",
494-
HTTPStatus.NOT_FOUND,
495-
)
496-
await response(scope, receive, send)
497-
return
481+
is_initialization_request = False
482+
if isinstance(message, JSONRPCRequest) and message.method == "initialize":
483+
is_initialization_request = True
484+
if not await self._validate_initialization_request(message, request, send):
485+
return
498486
elif not await self._validate_request_headers(request, send): # pragma: no cover
499487
return
500488

@@ -867,6 +855,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
867855

868856
return True
869857

858+
async def _validate_initialization_request(self, message: JSONRPCRequest, request: Request, send: Send) -> bool:
859+
if not await self._validate_initialization_protocol_version(message, request, send):
860+
return False
861+
862+
if not self.mcp_session_id:
863+
return True
864+
865+
request_session_id = self._get_session_id(request)
866+
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
867+
response = self._create_error_response(
868+
"Not Found: Invalid or expired session ID",
869+
HTTPStatus.NOT_FOUND,
870+
)
871+
await response(request.scope, request.receive, send)
872+
return False
873+
874+
return True
875+
876+
async def _validate_initialization_protocol_version(
877+
self, message: JSONRPCRequest, request: Request, send: Send
878+
) -> bool:
879+
header_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
880+
body_protocol_version = str(message.params.get("protocolVersion")) if message.params else None
881+
if (
882+
header_protocol_version is not None
883+
and body_protocol_version is not None
884+
and header_protocol_version != body_protocol_version
885+
):
886+
response = self._create_error_response(
887+
f"Bad Request: {MCP_PROTOCOL_VERSION_HEADER} header does not match initialize.params.protocolVersion",
888+
HTTPStatus.BAD_REQUEST,
889+
INVALID_REQUEST,
890+
)
891+
await response(request.scope, request.receive, send)
892+
return False
893+
894+
return True
895+
870896
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover
871897
"""Replays events that would have been sent after the specified event ID.
872898

tests/shared/test_streamable_http.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,43 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv
17181718
assert response.status_code == 200
17191719

17201720

1721+
@pytest.mark.parametrize(
1722+
("header_version", "body_version"),
1723+
[
1724+
("2025-03-26", "2025-06-18"),
1725+
("2025-06-18", "2025-03-26"),
1726+
],
1727+
)
1728+
def test_server_rejects_initialize_protocol_version_mismatch(
1729+
basic_server: None, basic_server_url: str, header_version: str, body_version: str
1730+
):
1731+
"""Test initialize rejects conflicting protocol versions in header and body."""
1732+
init_request: dict[str, Any] = {
1733+
"jsonrpc": "2.0",
1734+
"method": "initialize",
1735+
"params": {
1736+
"clientInfo": {"name": "test-client", "version": "1.0"},
1737+
"protocolVersion": body_version,
1738+
"capabilities": {},
1739+
},
1740+
"id": "init-1",
1741+
}
1742+
1743+
response = requests.post(
1744+
f"{basic_server_url}/mcp",
1745+
headers={
1746+
"Accept": "application/json, text/event-stream",
1747+
"Content-Type": "application/json",
1748+
MCP_PROTOCOL_VERSION_HEADER: header_version,
1749+
},
1750+
json=init_request,
1751+
)
1752+
1753+
assert response.status_code == 400
1754+
assert MCP_PROTOCOL_VERSION_HEADER in response.text
1755+
assert "protocolVersion" in response.text
1756+
1757+
17211758
def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str):
17221759
"""Test server accepts requests without protocol version header."""
17231760
# First initialize a session to get a valid session ID

0 commit comments

Comments
 (0)