|
| 1 | +import pytest |
| 2 | +from starlette.applications import Starlette |
| 3 | +from starlette.middleware import Middleware |
| 4 | +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint |
| 5 | +from starlette.requests import Request |
| 6 | +from starlette.responses import Response |
| 7 | +from httpx import AsyncClient, ASGITransport |
| 8 | + |
| 9 | +from mcp.server.mcpserver import MCPServer |
| 10 | +from mcp.server.transport_security import TransportSecuritySettings |
| 11 | + |
| 12 | + |
| 13 | +class MockMiddleware(BaseHTTPMiddleware): |
| 14 | + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: |
| 15 | + return await call_next(request) |
| 16 | + |
| 17 | + |
| 18 | +@pytest.mark.anyio |
| 19 | +async def test_883_middleware_sse_no_assertion_error(): |
| 20 | + """Test that using MCP SSE with Starlette middleware doesn't cause double-response error.""" |
| 21 | + mcp_server = MCPServer("test-server") |
| 22 | + transport_security = TransportSecuritySettings(enable_dns_rebinding_protection=False) |
| 23 | + # Using host="0.0.0.0" avoids auto-protection triggering logic for localhost |
| 24 | + sse_app = mcp_server.sse_app(transport_security=transport_security, host="0.0.0.0") |
| 25 | + |
| 26 | + app = Starlette(middleware=[Middleware(MockMiddleware)]) |
| 27 | + # Mount at root to simplify test paths |
| 28 | + app.mount("/", sse_app) |
| 29 | + |
| 30 | + # Use ASGITransport to properly test the ASGI app stack |
| 31 | + transport = ASGITransport(app=app) |
| 32 | + async with AsyncClient(transport=transport, base_url="http://testserver") as client: |
| 33 | + async with client.stream("GET", "/sse") as response: |
| 34 | + assert response.status_code == 200 |
| 35 | + assert "text/event-stream" in response.headers["content-type"] |
| 36 | + # Consume stream a bit or close immediately |
| 37 | + pass |
| 38 | + |
| 39 | + |
| 40 | +@pytest.mark.anyio |
| 41 | +async def test_883_middleware_post_accepted(): |
| 42 | + """Test that POST messages work with middleware.""" |
| 43 | + mcp_server = MCPServer("test-server") |
| 44 | + transport_security = TransportSecuritySettings(enable_dns_rebinding_protection=False) |
| 45 | + sse_app = mcp_server.sse_app(transport_security=transport_security, host="0.0.0.0") |
| 46 | + |
| 47 | + app = Starlette(middleware=[Middleware(MockMiddleware)]) |
| 48 | + app.mount("/", sse_app) |
| 49 | + |
| 50 | + transport = ASGITransport(app=app) |
| 51 | + async with AsyncClient(transport=transport, base_url="http://testserver") as client: |
| 52 | + response = await client.post( |
| 53 | + "/messages/?session_id=00000000000000000000000000000000", |
| 54 | + json={"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}, |
| 55 | + ) |
| 56 | + # 404 is expected here as we didn't establish a real session |
| 57 | + assert response.status_code == 404 |
0 commit comments