|
11 | 11 | import anyio |
12 | 12 | import httpx |
13 | 13 | from anyio.abc import TaskGroup |
14 | | -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
15 | 14 | from httpx_sse import EventSource, ServerSentEvent, aconnect_sse |
16 | 15 | from pydantic import ValidationError |
17 | 16 |
|
18 | 17 | from mcp.client._transport import TransportStreams |
| 18 | +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams |
19 | 19 | from mcp.shared._httpx_utils import create_mcp_http_client |
20 | 20 | from mcp.shared.message import ClientMessageMetadata, SessionMessage |
21 | 21 | from mcp.types import ( |
|
38 | 38 |
|
39 | 39 | # TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here. |
40 | 40 | SessionMessageOrError = SessionMessage | Exception |
41 | | -StreamWriter = MemoryObjectSendStream[SessionMessageOrError] |
42 | | -StreamReader = MemoryObjectReceiveStream[SessionMessage] |
| 41 | +StreamWriter = ContextSendStream[SessionMessageOrError] |
| 42 | +StreamReader = ContextReceiveStream[SessionMessage] |
43 | 43 |
|
44 | 44 | MCP_SESSION_ID = "mcp-session-id" |
45 | 45 | MCP_PROTOCOL_VERSION = "mcp-protocol-version" |
@@ -434,14 +434,15 @@ async def post_writer( |
434 | 434 | client: httpx.AsyncClient, |
435 | 435 | write_stream_reader: StreamReader, |
436 | 436 | read_stream_writer: StreamWriter, |
437 | | - write_stream: MemoryObjectSendStream[SessionMessage], |
| 437 | + write_stream: ContextSendStream[SessionMessage], |
438 | 438 | start_get_stream: Callable[[], None], |
439 | 439 | tg: TaskGroup, |
440 | 440 | ) -> None: |
441 | 441 | """Handle writing requests to the server.""" |
442 | 442 | try: |
443 | 443 | async with write_stream_reader, read_stream_writer, write_stream: |
444 | | - async for session_message in write_stream_reader: |
| 444 | + |
| 445 | + async def _handle_message(session_message: SessionMessage) -> None: |
445 | 446 | message = session_message.message |
446 | 447 | metadata = ( |
447 | 448 | session_message.metadata |
@@ -478,6 +479,14 @@ async def handle_request_async(): |
478 | 479 | else: |
479 | 480 | await handle_request_async() |
480 | 481 |
|
| 482 | + async for session_message in write_stream_reader: |
| 483 | + sender_ctx = write_stream_reader.last_context |
| 484 | + if sender_ctx is not None: |
| 485 | + async with anyio.create_task_group() as tg_local: |
| 486 | + sender_ctx.run(tg_local.start_soon, _handle_message, session_message) |
| 487 | + else: |
| 488 | + await _handle_message(session_message) # pragma: no cover |
| 489 | + |
481 | 490 | except Exception: # pragma: lax no cover |
482 | 491 | logger.exception("Error in post_writer") |
483 | 492 |
|
@@ -547,8 +556,8 @@ async def streamable_http_client( |
547 | 556 | if not client_provided: |
548 | 557 | await stack.enter_async_context(client) |
549 | 558 |
|
550 | | - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) |
551 | | - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) |
| 559 | + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) |
| 560 | + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) |
552 | 561 |
|
553 | 562 | async with ( |
554 | 563 | read_stream_writer, |
|
0 commit comments