|
1 | 1 | """Tests for StreamableHTTPSessionManager.""" |
2 | 2 |
|
3 | 3 | import json |
| 4 | +from http import HTTPStatus |
4 | 5 | from typing import Any |
5 | 6 | from unittest.mock import AsyncMock, patch |
6 | 7 |
|
7 | 8 | import anyio |
8 | 9 | import pytest |
| 10 | +from starlette.requests import Request |
9 | 11 | from starlette.types import Message, Scope |
10 | 12 |
|
11 | 13 | from mcp.server import streamable_http_manager |
12 | 14 | from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser |
13 | 15 | from mcp.server.auth.provider import AccessToken |
14 | 16 | from mcp.server.lowlevel import Server |
15 | | -from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport |
| 17 | +from mcp.server.streamable_http import ( |
| 18 | + MCP_SESSION_ID_HEADER, |
| 19 | + EventMessage, |
| 20 | + StreamableHTTPServerTransport, |
| 21 | +) |
16 | 22 | from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
| 23 | +from mcp.shared.message import SessionMessage |
17 | 24 | from mcp.types import INVALID_REQUEST |
18 | 25 |
|
19 | 26 |
|
@@ -555,3 +562,78 @@ async def test_anonymous_session_accepts_anonymous_requests( |
555 | 562 | session_id = await _open_session(manager, None) |
556 | 563 |
|
557 | 564 | assert await _request_session(manager, session_id, None) != 404 |
| 565 | + |
| 566 | + |
| 567 | +@pytest.mark.anyio |
| 568 | +async def test_handle_post_rejects_duplicate_request_id(): |
| 569 | + """Reject a POST whose JSON-RPC id matches an in-flight request on the same session. |
| 570 | +
|
| 571 | + The MCP base protocol forbids reusing a request ID within a session. Prior to the |
| 572 | + fix, the second POST silently overwrote the prior ``_request_streams`` entry, |
| 573 | + leaving the first request hanging forever. Now the duplicate is rejected with |
| 574 | + 409 Conflict and the prior in-flight stream is preserved untouched. |
| 575 | + """ |
| 576 | + transport = StreamableHTTPServerTransport(mcp_session_id=None) |
| 577 | + |
| 578 | + # The early ``writer is None`` guard reads this; the duplicate-id branch never |
| 579 | + # actually sends on it, so a real stream is sufficient. |
| 580 | + read_writer, read_reader = anyio.create_memory_object_stream[SessionMessage | Exception](0) |
| 581 | + transport._read_stream_writer = read_writer |
| 582 | + |
| 583 | + # Seed an in-flight request with id "1". The duplicate-id check must leave this |
| 584 | + # pair in place. |
| 585 | + in_flight_pair = anyio.create_memory_object_stream[EventMessage](0) |
| 586 | + transport._request_streams["1"] = in_flight_pair |
| 587 | + |
| 588 | + body = json.dumps({"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}}).encode() |
| 589 | + |
| 590 | + body_sent = False |
| 591 | + |
| 592 | + async def mock_receive() -> Message: |
| 593 | + nonlocal body_sent |
| 594 | + if body_sent: # pragma: no cover |
| 595 | + await anyio.sleep_forever() |
| 596 | + body_sent = True |
| 597 | + return {"type": "http.request", "body": body, "more_body": False} |
| 598 | + |
| 599 | + sent_messages: list[Message] = [] |
| 600 | + response_body = b"" |
| 601 | + |
| 602 | + async def mock_send(message: Message) -> None: |
| 603 | + nonlocal response_body |
| 604 | + sent_messages.append(message) |
| 605 | + if message["type"] == "http.response.body": |
| 606 | + response_body += message.get("body", b"") |
| 607 | + |
| 608 | + scope = { |
| 609 | + "type": "http", |
| 610 | + "method": "POST", |
| 611 | + "path": "/mcp", |
| 612 | + "headers": [ |
| 613 | + (b"content-type", b"application/json"), |
| 614 | + (b"accept", b"application/json, text/event-stream"), |
| 615 | + ], |
| 616 | + } |
| 617 | + |
| 618 | + request = Request(scope, mock_receive) |
| 619 | + await transport._handle_post_request(scope, request, mock_receive, mock_send) |
| 620 | + |
| 621 | + response_start = next( |
| 622 | + (msg for msg in sent_messages if msg["type"] == "http.response.start"), |
| 623 | + None, |
| 624 | + ) |
| 625 | + assert response_start is not None, "Should have sent a response" |
| 626 | + assert response_start["status"] == HTTPStatus.CONFLICT |
| 627 | + |
| 628 | + error = json.loads(response_body) |
| 629 | + assert error["jsonrpc"] == "2.0" |
| 630 | + assert error["error"]["code"] == INVALID_REQUEST |
| 631 | + assert "already in flight" in error["error"]["message"] |
| 632 | + |
| 633 | + # The pre-existing in-flight stream must remain untouched. |
| 634 | + assert transport._request_streams["1"] is in_flight_pair |
| 635 | + |
| 636 | + in_flight_pair[0].close() |
| 637 | + in_flight_pair[1].close() |
| 638 | + read_writer.close() |
| 639 | + read_reader.close() |
0 commit comments