Skip to content

Commit 4564376

Browse files
committed
Add richer OTel MCP span attributes
1 parent d5b9155 commit 4564376

File tree

6 files changed

+174
-17
lines changed

6 files changed

+174
-17
lines changed

src/mcp/client/streamable_http.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections.abc import AsyncGenerator, Awaitable, Callable
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass
10+
from types import TracebackType
1011

1112
import anyio
1213
import httpx
@@ -17,6 +18,7 @@
1718
from mcp.client._transport import TransportStreams
1819
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
1920
from mcp.shared._httpx_utils import create_mcp_http_client
21+
from mcp.shared._stream_protocols import WriteStream
2022
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2123
from mcp.types import (
2224
INTERNAL_ERROR,
@@ -512,6 +514,35 @@ def get_session_id(self) -> str | None:
512514
return self.session_id # pragma: no cover
513515

514516

517+
class _SessionAwareWriteStream:
518+
"""Write-stream wrapper that exposes the transport session ID."""
519+
520+
def __init__(self, inner: WriteStream[SessionMessage], transport: StreamableHTTPTransport) -> None:
521+
self._inner = inner
522+
self._transport = transport
523+
524+
async def send(self, item: SessionMessage) -> None:
525+
await self._inner.send(item)
526+
527+
async def aclose(self) -> None:
528+
await self._inner.aclose()
529+
530+
def get_session_id(self) -> str | None:
531+
return self._transport.session_id
532+
533+
async def __aenter__(self) -> _SessionAwareWriteStream:
534+
await self._inner.__aenter__()
535+
return self
536+
537+
async def __aexit__(
538+
self,
539+
exc_type: type[BaseException] | None,
540+
exc_val: BaseException | None,
541+
exc_tb: TracebackType | None,
542+
) -> bool | None:
543+
return await self._inner.__aexit__(exc_type, exc_val, exc_tb)
544+
545+
515546
# TODO(Marcelo): I've dropped the `get_session_id` callback because it breaks the Transport protocol. Is that needed?
516547
# It's a completely wrong abstraction, so removal is a good idea. But if we need the client to find the session ID,
517548
# we should think about a better way to do it. I believe we can achieve it with other means.
@@ -581,7 +612,7 @@ def start_get_stream() -> None:
581612
)
582613

583614
try:
584-
yield read_stream, write_stream
615+
yield read_stream, _SessionAwareWriteStream(write_stream, transport)
585616
finally:
586617
if transport.session_id and terminate_on_close:
587618
await transport.terminate_session(client)

src/mcp/server/lowlevel/server.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,14 @@ async def main():
6666
from mcp.server.streamable_http import EventStore
6767
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
6868
from mcp.server.transport_security import TransportSecuritySettings
69-
from mcp.shared._otel import extract_trace_context, otel_span
69+
from mcp.shared._otel import build_server_span_attributes, extract_trace_context, otel_span
7070
from mcp.shared._stream_protocols import ReadStream, WriteStream
7171
from mcp.shared.exceptions import MCPError
7272
from mcp.shared.message import ServerMessageMetadata, SessionMessage
7373
from mcp.shared.session import RequestResponder
7474

7575
logger = logging.getLogger(__name__)
76+
MCP_SESSION_ID_HEADER = "mcp-session-id"
7677

7778
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7879

@@ -454,28 +455,32 @@ async def _handle_request(
454455
# Extract W3C trace context from _meta (SEP-414).
455456
meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None
456457
parent_context = extract_trace_context(meta) if meta is not None else None
458+
request_data = None
459+
close_sse_stream_cb = None
460+
close_standalone_sse_stream_cb = None
461+
if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
462+
request_data = message.message_metadata.request_context
463+
close_sse_stream_cb = message.message_metadata.close_sse_stream
464+
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream
465+
request_headers = getattr(request_data, "headers", None)
466+
session_id = request_headers.get(MCP_SESSION_ID_HEADER) if request_headers is not None else None
457467

458468
with otel_span(
459469
span_name,
460470
kind=SpanKind.SERVER,
461-
attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id},
471+
attributes=build_server_span_attributes(
472+
service_name=self.name,
473+
method=req.method,
474+
request_id=message.request_id,
475+
params=req.params,
476+
session_id=session_id,
477+
),
462478
context=parent_context,
463479
) as span:
464480
if handler := self._request_handlers.get(req.method):
465481
logger.debug("Dispatching request of type %s", type(req).__name__)
466482

467483
try:
468-
# Extract request context and close_sse_stream from message metadata
469-
request_data = None
470-
close_sse_stream_cb = None
471-
close_standalone_sse_stream_cb = None
472-
if message.message_metadata is not None and isinstance(
473-
message.message_metadata, ServerMessageMetadata
474-
):
475-
request_data = message.message_metadata.request_context
476-
close_sse_stream_cb = message.message_metadata.close_sse_stream
477-
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream
478-
479484
client_capabilities = session.client_params.capabilities if session.client_params else None
480485
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
481486
# Get task metadata from request params if present

src/mcp/shared/_otel.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from opentelemetry.trace import SpanKind, get_tracer
1212

1313
_tracer = get_tracer("mcp-python-sdk")
14+
MCP_RPC_SYSTEM = "mcp"
1415

1516

1617
@contextmanager
@@ -34,3 +35,54 @@ def inject_trace_context(meta: dict[str, Any]) -> None:
3435
def extract_trace_context(meta: dict[str, Any]) -> Context:
3536
"""Extract W3C trace context from a `_meta` dict."""
3637
return extract(meta)
38+
39+
40+
def build_client_span_attributes(
41+
*,
42+
method: str,
43+
request_id: str | int,
44+
params: dict[str, Any] | None = None,
45+
session_id: str | None = None,
46+
) -> dict[str, Any]:
47+
"""Build OTel attributes for an MCP client request span."""
48+
attributes: dict[str, Any] = {
49+
"rpc.system": MCP_RPC_SYSTEM,
50+
"rpc.method": method,
51+
"mcp.method.name": method,
52+
"jsonrpc.request.id": request_id,
53+
}
54+
55+
if params is not None and (resource_uri := params.get("uri")) is not None:
56+
attributes["mcp.resource.uri"] = resource_uri
57+
58+
if session_id is not None:
59+
attributes["mcp.session.id"] = session_id
60+
61+
return attributes
62+
63+
64+
def build_server_span_attributes(
65+
*,
66+
service_name: str,
67+
method: str,
68+
request_id: str | int,
69+
params: Any = None,
70+
session_id: str | None = None,
71+
) -> dict[str, Any]:
72+
"""Build OTel attributes for an MCP server request span."""
73+
attributes: dict[str, Any] = {
74+
"rpc.system": MCP_RPC_SYSTEM,
75+
"rpc.service": service_name,
76+
"rpc.method": method,
77+
"mcp.method.name": method,
78+
"jsonrpc.request.id": request_id,
79+
}
80+
81+
resource_uri = getattr(params, "uri", None)
82+
if resource_uri is not None:
83+
attributes["mcp.resource.uri"] = str(resource_uri)
84+
85+
if session_id is not None:
86+
attributes["mcp.session.id"] = session_id
87+
88+
return attributes

src/mcp/shared/session.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from collections.abc import Callable
66
from contextlib import AsyncExitStack
77
from types import TracebackType
8-
from typing import Any, Generic, Protocol, TypeVar
8+
from typing import Any, Generic, Protocol, TypeVar, cast
99

1010
import anyio
1111
from anyio.streams.memory import MemoryObjectSendStream
1212
from opentelemetry.trace import SpanKind
1313
from pydantic import BaseModel, TypeAdapter
1414
from typing_extensions import Self
1515

16-
from mcp.shared._otel import inject_trace_context, otel_span
16+
from mcp.shared._otel import build_client_span_attributes, inject_trace_context, otel_span
1717
from mcp.shared._stream_protocols import ReadStream, WriteStream
1818
from mcp.shared.exceptions import MCPError
1919
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
@@ -236,6 +236,13 @@ async def __aexit__(
236236
self._task_group.cancel_scope.cancel()
237237
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
238238

239+
def _get_transport_session_id(self) -> str | None:
240+
"""Return the transport session ID when the write stream exposes it."""
241+
get_session_id = getattr(self._write_stream, "get_session_id", None)
242+
if callable(get_session_id):
243+
return cast("str | None", get_session_id())
244+
return None
245+
239246
async def send_request(
240247
self,
241248
request: SendRequestT,
@@ -276,7 +283,12 @@ async def send_request(
276283
with otel_span(
277284
span_name,
278285
kind=SpanKind.CLIENT,
279-
attributes={"mcp.method.name": request.method, "jsonrpc.request.id": request_id},
286+
attributes=build_client_span_attributes(
287+
method=request.method,
288+
request_id=request_id,
289+
params=request_data.get("params"),
290+
session_id=self._get_transport_session_id(),
291+
),
280292
):
281293
# Inject W3C trace context into _meta (SEP-414).
282294
meta: dict[str, Any] = request_data.setdefault("params", {}).setdefault("_meta", {})

tests/shared/test_otel.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,44 @@ def greet(name: str) -> str:
3737
client_span = next(s for s in spans if s["name"] == "MCP send tools/call greet")
3838
server_span = next(s for s in spans if s["name"] == "MCP handle tools/call greet")
3939

40+
assert client_span["attributes"]["rpc.system"] == "mcp"
41+
assert client_span["attributes"]["rpc.method"] == "tools/call"
4042
assert client_span["attributes"]["mcp.method.name"] == "tools/call"
43+
assert server_span["attributes"]["rpc.system"] == "mcp"
44+
assert server_span["attributes"]["rpc.service"] == "test"
45+
assert server_span["attributes"]["rpc.method"] == "tools/call"
4146
assert server_span["attributes"]["mcp.method.name"] == "tools/call"
4247

4348
# Server span should be in the same trace as the client span (context propagation).
4449
assert server_span["context"]["trace_id"] == client_span["context"]["trace_id"]
50+
51+
52+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
53+
async def test_resource_read_spans_include_resource_uri(capfire: CaptureLogfire):
54+
"""Verify that resource reads include MCP resource and RPC attributes."""
55+
server = MCPServer("test")
56+
57+
@server.resource("test://resource")
58+
def test_resource() -> str:
59+
return "hello"
60+
61+
async with Client(server) as client:
62+
result = await client.read_resource("test://resource")
63+
64+
assert result.contents[0].uri == "test://resource"
65+
66+
spans = capfire.exporter.exported_spans_as_dict()
67+
68+
client_span = next(s for s in spans if s["name"] == "MCP send resources/read")
69+
server_span = next(s for s in spans if s["name"] == "MCP handle resources/read")
70+
71+
assert client_span["attributes"]["rpc.system"] == "mcp"
72+
assert client_span["attributes"]["rpc.method"] == "resources/read"
73+
assert client_span["attributes"]["mcp.method.name"] == "resources/read"
74+
assert client_span["attributes"]["mcp.resource.uri"] == "test://resource"
75+
76+
assert server_span["attributes"]["rpc.system"] == "mcp"
77+
assert server_span["attributes"]["rpc.service"] == "test"
78+
assert server_span["attributes"]["rpc.method"] == "resources/read"
79+
assert server_span["attributes"]["mcp.method.name"] == "resources/read"
80+
assert server_span["attributes"]["mcp.resource.uri"] == "test://resource"

tests/shared/test_streamable_http.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import requests
2424
import uvicorn
2525
from httpx_sse import ServerSentEvent
26+
from logfire.testing import CaptureLogfire
2627
from starlette.applications import Starlette
2728
from starlette.requests import Request
2829
from starlette.routing import Mount
@@ -1081,6 +1082,26 @@ async def test_streamable_http_client_resource_read(initialized_client_session:
10811082
assert response.contents[0].text == "Read test-resource"
10821083

10831084

1085+
@pytest.mark.anyio
1086+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
1087+
async def test_streamable_http_resource_read_spans_include_session_id(
1088+
capfire: CaptureLogfire, basic_server: None, basic_server_url: str
1089+
):
1090+
"""Verify streamable HTTP spans include the negotiated MCP session ID."""
1091+
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
1092+
async with ClientSession(read_stream, write_stream) as session:
1093+
await session.initialize()
1094+
response = await session.read_resource(uri="foobar://test-resource")
1095+
1096+
assert response.contents[0].uri == "foobar://test-resource"
1097+
1098+
spans = capfire.exporter.exported_spans_as_dict()
1099+
client_span = next(s for s in spans if s["name"] == "MCP send resources/read")
1100+
1101+
assert client_span["attributes"]["mcp.session.id"]
1102+
assert client_span["attributes"]["mcp.resource.uri"] == "foobar://test-resource"
1103+
1104+
10841105
@pytest.mark.anyio
10851106
async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession):
10861107
"""Test client tool invocation."""

0 commit comments

Comments
 (0)