Skip to content

Commit 1107118

Browse files
committed
fix: send Origin header for streamable HTTP
1 parent 616476f commit 1107118

2 files changed

Lines changed: 46 additions & 6 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,12 @@ def __init__(self, url: str) -> None:
7979
url: The endpoint URL.
8080
"""
8181
self.url = url
82+
parsed_url = httpx.URL(url)
83+
self.origin = f"{parsed_url.scheme}://{parsed_url.netloc.decode()}" if parsed_url.netloc else None
8284
self.session_id: str | None = None
8385
self.protocol_version: str | None = None
8486

85-
def _prepare_headers(self) -> dict[str, str]:
87+
def _prepare_headers(self, client: httpx.AsyncClient | None = None) -> dict[str, str]:
8688
"""Build MCP-specific request headers.
8789
8890
These headers will be merged with the httpx.AsyncClient's default headers,
@@ -92,6 +94,8 @@ def _prepare_headers(self) -> dict[str, str]:
9294
"accept": "application/json, text/event-stream",
9395
"content-type": "application/json",
9496
}
97+
if self.origin and (client is None or "origin" not in client.headers):
98+
headers["origin"] = self.origin
9599
# Add session headers if available
96100
if self.session_id:
97101
headers[MCP_SESSION_ID] = self.session_id
@@ -189,7 +193,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
189193
if not self.session_id:
190194
return
191195

192-
headers = self._prepare_headers()
196+
headers = self._prepare_headers(client)
193197
if last_event_id:
194198
headers[LAST_EVENT_ID] = last_event_id
195199

@@ -225,7 +229,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
225229

226230
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
227231
"""Handle a resumption request using GET with SSE."""
228-
headers = self._prepare_headers()
232+
headers = self._prepare_headers(ctx.client)
229233
if ctx.metadata and ctx.metadata.resumption_token:
230234
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
231235
else:
@@ -253,7 +257,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
253257

254258
async def _handle_post_request(self, ctx: RequestContext) -> None:
255259
"""Handle a POST request with response processing."""
256-
headers = self._prepare_headers()
260+
headers = self._prepare_headers(ctx.client)
257261
message = ctx.session_message.message
258262
is_initialization = self._is_initialization_request(message)
259263

@@ -388,7 +392,7 @@ async def _handle_reconnection(
388392
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
389393
await anyio.sleep(delay_ms / 1000.0)
390394

391-
headers = self._prepare_headers()
395+
headers = self._prepare_headers(ctx.client)
392396
headers[LAST_EVENT_ID] = last_event_id
393397

394398
# Extract original request ID to map responses
@@ -496,7 +500,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
496500
return # pragma: no cover
497501

498502
try:
499-
headers = self._prepare_headers()
503+
headers = self._prepare_headers(client)
500504
response = await client.delete(self.url, headers=headers)
501505

502506
if response.status_code == 405:

tests/shared/test_streamable_http.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,3 +2318,39 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(
23182318

23192319
assert "content-type" in headers_data
23202320
assert headers_data["content-type"] == "application/json"
2321+
2322+
2323+
@pytest.mark.anyio
2324+
async def test_streamable_http_client_adds_origin_header(context_aware_server: None, basic_server_url: str) -> None:
2325+
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
2326+
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
2327+
await session.initialize()
2328+
2329+
tool_result = await session.call_tool("echo_headers", {})
2330+
assert len(tool_result.content) == 1
2331+
assert isinstance(tool_result.content[0], TextContent)
2332+
headers_data = json.loads(tool_result.content[0].text)
2333+
2334+
assert headers_data["origin"] == basic_server_url
2335+
2336+
2337+
@pytest.mark.anyio
2338+
async def test_streamable_http_client_preserves_custom_origin_header(
2339+
context_aware_server: None, basic_server_url: str
2340+
) -> None:
2341+
custom_origin = "https://proxy.example"
2342+
2343+
async with create_mcp_http_client(headers={"Origin": custom_origin}) as httpx_client:
2344+
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as (
2345+
read_stream,
2346+
write_stream,
2347+
):
2348+
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
2349+
await session.initialize()
2350+
2351+
tool_result = await session.call_tool("echo_headers", {})
2352+
assert len(tool_result.content) == 1
2353+
assert isinstance(tool_result.content[0], TextContent)
2354+
headers_data = json.loads(tool_result.content[0].text)
2355+
2356+
assert headers_data["origin"] == custom_origin

0 commit comments

Comments
 (0)