|
26 | 26 | from starlette.applications import Starlette |
27 | 27 | from starlette.requests import Request |
28 | 28 | from starlette.routing import Mount |
| 29 | +from starlette.types import Message |
29 | 30 |
|
30 | 31 | from mcp import MCPError, types |
31 | 32 | from mcp.client.session import ClientSession |
@@ -1755,6 +1756,67 @@ def test_server_rejects_initialize_protocol_version_mismatch( |
1755 | 1756 | assert "protocolVersion" in response.text |
1756 | 1757 |
|
1757 | 1758 |
|
| 1759 | +@pytest.mark.anyio |
| 1760 | +async def test_server_rejects_initialize_protocol_version_mismatch_in_process(): |
| 1761 | + transport = StreamableHTTPServerTransport("/mcp") |
| 1762 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) |
| 1763 | + transport._read_stream_writer = write_stream |
| 1764 | + body = json.dumps( |
| 1765 | + { |
| 1766 | + "jsonrpc": "2.0", |
| 1767 | + "method": "initialize", |
| 1768 | + "params": { |
| 1769 | + "clientInfo": {"name": "test-client", "version": "1.0"}, |
| 1770 | + "protocolVersion": "2025-06-18", |
| 1771 | + "capabilities": {}, |
| 1772 | + }, |
| 1773 | + "id": "init-1", |
| 1774 | + } |
| 1775 | + ).encode() |
| 1776 | + sent: list[Message] = [] |
| 1777 | + received_body = False |
| 1778 | + |
| 1779 | + async def receive() -> Message: |
| 1780 | + nonlocal received_body |
| 1781 | + if received_body: |
| 1782 | + return {"type": "http.disconnect"} |
| 1783 | + |
| 1784 | + received_body = True |
| 1785 | + return {"type": "http.request", "body": body, "more_body": False} |
| 1786 | + |
| 1787 | + async def send(message: Message) -> None: |
| 1788 | + sent.append(message) |
| 1789 | + |
| 1790 | + scope = { |
| 1791 | + "type": "http", |
| 1792 | + "asgi": {"version": "3.0"}, |
| 1793 | + "method": "POST", |
| 1794 | + "path": "/mcp", |
| 1795 | + "raw_path": b"/mcp", |
| 1796 | + "query_string": b"", |
| 1797 | + "headers": [ |
| 1798 | + (b"accept", b"application/json, text/event-stream"), |
| 1799 | + (b"content-type", b"application/json"), |
| 1800 | + (MCP_PROTOCOL_VERSION_HEADER.encode(), b"2025-03-26"), |
| 1801 | + ], |
| 1802 | + "client": ("testclient", 50000), |
| 1803 | + "server": ("testserver", 80), |
| 1804 | + "scheme": "http", |
| 1805 | + } |
| 1806 | + |
| 1807 | + try: |
| 1808 | + await transport.handle_request(scope, receive, send) |
| 1809 | + assert await receive() == {"type": "http.disconnect"} |
| 1810 | + finally: |
| 1811 | + await write_stream.aclose() |
| 1812 | + await read_stream.aclose() |
| 1813 | + |
| 1814 | + assert any(message["type"] == "http.response.start" and message["status"] == 400 for message in sent) |
| 1815 | + response_body = b"".join(message.get("body", b"") for message in sent if message["type"] == "http.response.body") |
| 1816 | + assert MCP_PROTOCOL_VERSION_HEADER.encode() in response_body |
| 1817 | + assert b"protocolVersion" in response_body |
| 1818 | + |
| 1819 | + |
1758 | 1820 | def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): |
1759 | 1821 | """Test server accepts requests without protocol version header.""" |
1760 | 1822 | # First initialize a session to get a valid session ID |
|
0 commit comments