Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ This document contains critical information about working with this codebase. Fo
- Bug fixes require regression tests
- IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns.
- IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible.
- IMPORTANT: Do NOT test private functions (prefixed with `_`). Test them indirectly through the public API.

Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py`
Add tests to the existing file for that module.
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"pyjwt[crypto]>=2.10.1",
"typing-extensions>=4.13.0",
"typing-inspection>=0.4.1",
"opentelemetry-api>=1.28.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -71,6 +72,7 @@ dev = [
"coverage[toml]>=7.10.7,<=7.13",
"pillow>=12.0",
"strict-no-cover",
"opentelemetry-sdk>=1.28.0",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
33 changes: 32 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from opentelemetry import trace
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

from mcp.shared.exceptions import MCPError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.response_router import ResponseRouter
from mcp.shared.tracing import end_span_error, end_span_ok, start_client_span, start_server_span
from mcp.types import (
CONNECTION_CLOSED,
INVALID_PARAMS,
Expand Down Expand Up @@ -77,6 +79,7 @@ def __init__(
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
message_metadata: MessageMetadata = None,
span: trace.Span | None = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
Expand All @@ -87,6 +90,7 @@ def __init__(
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
self._span = span

def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]:
"""Enter the context manager, enabling request cancellation tracking."""
Expand Down Expand Up @@ -126,6 +130,12 @@ async def respond(self, response: SendResultT | ErrorData) -> None:
if not self.cancelled: # pragma: no branch
self._completed = True

if self._span is not None:
if isinstance(response, ErrorData):
end_span_error(self._span, MCPError(code=response.code, message=response.message))
else:
end_span_ok(self._span)

await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id, response=response
)
Expand All @@ -139,6 +149,10 @@ async def cancel(self) -> None:

self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight

if self._span is not None:
end_span_error(self._span, MCPError(code=0, message="Request cancelled"))

# Send an error response to indicate cancellation
await self._session._send_response( # type: ignore[reportPrivateUsage]
request_id=self.request_id,
Expand Down Expand Up @@ -260,6 +274,9 @@ async def send_request(
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback

method: str = request_data["method"]
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the cast needed?

span = start_client_span(method, request_data.get("params"))

try:
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
Expand All @@ -278,7 +295,15 @@ async def send_request(
if isinstance(response_or_error, JSONRPCError):
raise MCPError.from_jsonrpc_error(response_or_error)
else:
return result_type.model_validate(response_or_error.result, by_name=False)
result = result_type.model_validate(response_or_error.result, by_name=False)
if span is not None:
end_span_ok(span)
return result

except BaseException as exc:
if span is not None:
end_span_error(span, exc)
raise

finally:
self._response_streams.pop(request_id, None)
Expand Down Expand Up @@ -339,13 +364,19 @@ async def _receive_loop(self) -> None:
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
by_name=False,
)
request_data = message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
server_span = start_server_span(
request_data.get("method", ""),
request_data.get("params"),
)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single line please. Is the method not always available in request_data?

responder = RequestResponder(
request_id=message.message.id,
request_meta=validated_request.params.meta if validated_request.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
span=server_span,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
Expand Down
81 changes: 81 additions & 0 deletions src/mcp/shared/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import Any

from opentelemetry import trace
from opentelemetry.trace import StatusCode

_tracer = trace.get_tracer("mcp")

_EXCLUDED_METHODS: frozenset[str] = frozenset({"notifications/message"})

# Semantic convention attribute keys
ATTR_MCP_METHOD_NAME = "mcp.method.name"
ATTR_ERROR_TYPE = "error.type"

# Methods that have a meaningful target name in params
_TARGET_PARAM_KEY: dict[str, str] = {
"tools/call": "name",
"prompts/get": "name",
"resources/read": "uri",
}


def _extract_target(method: str, params: dict[str, Any] | None) -> str | None:
"""Extract the target (e.g. tool name, prompt name) from request params."""
key = _TARGET_PARAM_KEY.get(method)
if key is None or params is None:
return None
value = params.get(key)
if isinstance(value, str):
return value
return None


def start_client_span(method: str, params: dict[str, Any] | None) -> trace.Span | None:
"""Start a CLIENT span for an outgoing MCP request.

Returns None if the method is excluded from tracing.
"""
if method in _EXCLUDED_METHODS:
return None

target = _extract_target(method, params)
span_name = f"{method} {target}" if target else method
span = _tracer.start_span(
span_name,
kind=trace.SpanKind.CLIENT,
attributes={ATTR_MCP_METHOD_NAME: method},
)
return span


def start_server_span(method: str, params: dict[str, Any] | None) -> trace.Span | None:
"""Start a SERVER span for an incoming MCP request.

Returns None if the method is excluded from tracing.
"""
if method in _EXCLUDED_METHODS:
return None

target = _extract_target(method, params)
span_name = f"{method} {target}" if target else method
span = _tracer.start_span(
span_name,
kind=trace.SpanKind.SERVER,
attributes={ATTR_MCP_METHOD_NAME: method},
)
return span


def end_span_ok(span: trace.Span) -> None:
"""Mark a span as successful and end it."""
span.set_status(StatusCode.OK)
span.end()


def end_span_error(span: trace.Span, error: BaseException) -> None:
"""Mark a span as errored and end it."""
span.set_status(StatusCode.ERROR, str(error))
span.set_attribute(ATTR_ERROR_TYPE, type(error).__qualname__)
span.end()
Loading
Loading