diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index f26944a2d..cd459589a 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -1163,6 +1163,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes progress=progress, total=total, message=message, + related_request_id=self.request_id, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index fb4bb0101..5d5f8b8fc 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -35,6 +35,12 @@ async def test_progress_token_zero_first_call(): # Verify progress notifications assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent" - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=0.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=5.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=10.0, total=10.0, message=None, related_request_id="test-request" + ) diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3f253baa8..cfbe6587b 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,7 +1,7 @@ import base64 from pathlib import Path from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from inline_snapshot import snapshot @@ -10,6 +10,8 @@ from starlette.routing import Mount, Route from mcp.client import Client +from mcp.server.context import ServerRequestContext +from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage @@ -1436,3 +1438,34 @@ def test_streamable_http_no_redirect() -> None: # Verify path values assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" + + +async def test_report_progress_passes_related_request_id(): + """Test that report_progress passes the request_id as related_request_id. + + Without related_request_id, the streamable HTTP transport cannot route + progress notifications to the correct SSE stream, causing them to be + silently dropped. See #953 and #2001. + """ + mock_session = AsyncMock() + mock_session.send_progress_notification = AsyncMock() + + request_context = ServerRequestContext( + request_id="req-abc-123", + session=mock_session, + meta={"progress_token": "tok-1"}, + lifespan_context=None, + experimental=Experimental(), + ) + + ctx = Context(request_context=request_context, mcp_server=MagicMock()) + + await ctx.report_progress(50, 100, message="halfway") + + mock_session.send_progress_notification.assert_awaited_once_with( + progress_token="tok-1", + progress=50, + total=100, + message="halfway", + related_request_id="req-abc-123", + )