Skip to content

Commit cee78d3

Browse files
committed
Fix #883: AssertionError with Starlette middleware and Fix #156: Jupyter logging support
1 parent dda845a commit cee78d3

File tree

4 files changed

+101
-7
lines changed

4 files changed

+101
-7
lines changed

src/mcp/client/stdio.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
terminate_windows_process_tree,
2222
)
2323
from mcp.shared.message import SessionMessage
24+
from mcp.shared.jupyter import is_jupyter
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -118,12 +119,18 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
118119
try:
119120
command = _get_executable_command(server.command)
120121

122+
# In Jupyter, we pipe stderr to read it in the main process
123+
# because sys.stderr redirection might not be reliable for subprocesses.
124+
actual_errlog = errlog
125+
if is_jupyter():
126+
actual_errlog = anyio.lowlevel.PIPE
127+
121128
# Open process with stderr piped for capture
122129
process = await _create_platform_compatible_process(
123130
command=command,
124131
args=server.args,
125132
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
126-
errlog=errlog,
133+
errlog=actual_errlog,
127134
cwd=server.cwd,
128135
)
129136
except OSError:
@@ -177,9 +184,29 @@ async def stdin_writer():
177184
except anyio.ClosedResourceError: # pragma: no cover
178185
await anyio.lowlevel.checkpoint()
179186

187+
async def stderr_reader():
188+
if not process.stderr:
189+
return
190+
191+
try:
192+
async for chunk in TextReceiveStream(
193+
process.stderr,
194+
encoding=server.encoding,
195+
errors=server.encoding_error_handler,
196+
):
197+
if is_jupyter():
198+
# In Jupyter, we use print() for better output handling
199+
# Red text for stderr
200+
print(f"\033[91m{chunk}\033[0m", end="", flush=True)
201+
else:
202+
print(chunk, file=sys.stderr, end="", flush=True)
203+
except anyio.ClosedResourceError:
204+
await anyio.lowlevel.checkpoint()
205+
180206
async with anyio.create_task_group() as tg, process:
181207
tg.start_soon(stdout_reader)
182208
tg.start_soon(stdin_writer)
209+
tg.start_soon(stderr_reader)
183210
try:
184211
yield read_stream, write_stream
185212
finally:

src/mcp/server/mcpserver/server.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,6 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no
805805
await self._lowlevel_server.run(
806806
streams[0], streams[1], self._lowlevel_server.create_initialization_options()
807807
)
808-
return Response()
809808

810809
# Create routes
811810
routes: list[Route | Mount] = []
@@ -869,15 +868,18 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no
869868
)
870869
else:
871870
# Auth is disabled, no need for RequireAuthMiddleware
872-
# Since handle_sse is an ASGI app, we need to create a compatible endpoint
873-
async def sse_endpoint(request: Request) -> Response: # pragma: no cover
874-
# Convert the Starlette request to ASGI parameters
875-
return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage]
871+
872+
873+
# Use an ASGI-compatible wrapper to avoid Starlette's high-level route wrapping
874+
# which expects a Response object and causes double-sending.
875+
class HandleSseAsgi:
876+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
877+
await handle_sse(scope, receive, send)
876878

877879
routes.append(
878880
Route(
879881
sse_path,
880-
endpoint=sse_endpoint,
882+
endpoint=HandleSseAsgi(),
881883
methods=["GET"],
882884
)
883885
)

src/mcp/shared/jupyter.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
import sys
3+
4+
def is_jupyter() -> bool:
5+
"""Check if we are running in a Jupyter notebook environment."""
6+
try:
7+
shell = get_ipython().__class__.__name__ # type: ignore
8+
if shell == 'ZMQInteractiveShell':
9+
return True # Jupyter notebook or qtconsole
10+
elif shell == 'TerminalInteractiveShell':
11+
return False # Terminal running IPython
12+
else:
13+
return False # Other type (?)
14+
except NameError:
15+
return False # Probably standard Python interpreter
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
import pytest
3+
from starlette.applications import Starlette
4+
from starlette.middleware import Middleware
5+
from starlette.middleware.base import BaseHTTPMiddleware
6+
from starlette.testclient import TestClient
7+
from mcp.server.mcpserver import MCPServer
8+
from mcp.server.transport_security import TransportSecuritySettings
9+
10+
class MockMiddleware(BaseHTTPMiddleware):
11+
async def dispatch(self, request, call_next):
12+
return await call_next(request)
13+
14+
def test_883_middleware_sse_no_assertion_error():
15+
"""Test that using MCP SSE with Starlette middleware doesn't cause double-response error."""
16+
mcp_server = MCPServer("test-server")
17+
transport_security = TransportSecuritySettings(enable_dns_rebinding_protection=False)
18+
sse_app = mcp_server.sse_app(transport_security=transport_security)
19+
20+
app = Starlette(middleware=[Middleware(MockMiddleware)])
21+
app.mount("/", sse_app)
22+
23+
client = TestClient(app)
24+
25+
# We use a context manager to ensure the stream is closed quickly
26+
with client.stream("GET", "/sse") as response:
27+
assert response.status_code == 200
28+
# Just check headers are there
29+
assert "text/event-stream" in response.headers["content-type"]
30+
31+
def test_883_middleware_post_accepted():
32+
"""Test that POST messages work with middleware."""
33+
mcp_server = MCPServer("test-server")
34+
transport_security = TransportSecuritySettings(enable_dns_rebinding_protection=False)
35+
sse_app = mcp_server.sse_app(transport_security=transport_security)
36+
37+
app = Starlette(middleware=[Middleware(MockMiddleware)])
38+
app.mount("/", sse_app)
39+
40+
client = TestClient(app)
41+
42+
# POST to /messages/ (with invalid session, but should not AssertionError)
43+
response = client.post("/messages/?session_id=00000000000000000000000000000000", json={
44+
"jsonrpc": "2.0",
45+
"method": "notifications/initialized",
46+
"params": {}
47+
})
48+
49+
# 404 is expected here as we didn't establish a real session
50+
assert response.status_code == 404

0 commit comments

Comments
 (0)