Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ repos:
- id: ruff
name: mcp-lint
files: ^mcp/
args: [--config, mcp/pyproject.toml, --config, "src = ['mcp/src']", --fix]
args:
[--config, mcp/pyproject.toml, --config, "src = ['mcp/src']", --fix]
- id: ruff-format
name: mcp-format
files: ^mcp/
Expand Down Expand Up @@ -43,6 +44,7 @@ repos:
language: system
entry: make -C api generate-docs
pass_filenames: false
files: ^api/
types_or: [python, toml]
- id: api-typecheck
name: api-typecheck
Expand All @@ -69,7 +71,7 @@ repos:
- id: flagsmith-lint-tests

- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.11.18 # Ensure this matches the version in api/pyproject.toml
rev: 0.11.18 # Ensure this matches the version in api/pyproject.toml
hooks:
- id: uv-lock
name: api-lockcheck
Expand Down
2 changes: 2 additions & 0 deletions mcp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ requires-python = ">=3.11"
dependencies = [
"fastmcp>=3.3.1,<4.0.0", # Base MCP functionality
"flagsmith-common[otel]>=3.10.0,<4.0.0", # Logging and OTel export
"opentelemetry-instrumentation-httpx>=0.46b0,<1.0.0", # Trace upstream API calls
"prometheus-client>=0.21.0,<1.0.0", # Export Prometheus metrics
"pydantic-settings>=2.0.0,<3.0.0", # Environment-driven configuration
]
Expand All @@ -24,6 +25,7 @@ dev = [
"pytest-asyncio>=1.3.0,<2.0.0", # Run asynchronous tests
"pytest-cov>=7.0.0,<8.0.0", # Measure test coverage
"pytest-mock>=3.15.1,<4.0.0", # Mock via fixtures
"pytest-structlog>=1.1,<2.0.0", # Assert structlog events
"respx>=0.22,<1.0", # Mock HTTP interactions
"ruff>=0.15.12,<0.16.0", # Lint and format
]
Expand Down
3 changes: 3 additions & 0 deletions mcp/src/flagsmith_mcp/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# TODO: consume a version-controlled schema — https://github.com/Flagsmith/flagsmith/issues/7669
OPENAPI_SPEC_URL = "https://api.flagsmith.com/api/v1/swagger.json"
OAUTH_SCOPES = ["mcp"]

# How this service identifies itself as a client of the Flagsmith API.
FLAGSMITH_CLIENT_NAME = "flagsmith-mcp"
65 changes: 65 additions & 0 deletions mcp/src/flagsmith_mcp/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import mcp.types as mt
import structlog
from fastmcp.server.dependencies import get_context
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.base import ToolResult

logger = structlog.get_logger("mcp")


def get_client_info() -> mt.Implementation | None:
"""The connected client's self-declared identity, captured by the
session during initialize."""
try:
client_params = get_context().session.client_params
except RuntimeError:
return None
if client_params is None:
return None
return client_params.clientInfo


class EventLoggingMiddleware(Middleware):
"""Emit structured product events for MCP sessions and tool calls."""

async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: CallNext[mt.InitializeRequest, mt.InitializeResult | None],
) -> mt.InitializeResult | None:
result = await call_next(context)
client_info = context.message.params.clientInfo
logger.info(
"session.opened",
flagsmith__mcp__client__name=client_info.name,
flagsmith__mcp__client__version=client_info.version,
)
return result

async def on_call_tool(
self,
context: MiddlewareContext[mt.CallToolRequestParams],
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
) -> ToolResult:
client_info = get_client_info()
client_name = client_info.name if client_info else ""
client_version = client_info.version if client_info else ""
try:
result = await call_next(context)
except Exception:
logger.info(
"tool.called",
tool__name=context.message.name,
flagsmith__mcp__client__name=client_name,
flagsmith__mcp__client__version=client_version,
status="error",
)
raise
logger.info(
"tool.called",
tool__name=context.message.name,
flagsmith__mcp__client__name=client_name,
flagsmith__mcp__client__version=client_version,
status="success",
)
return result
19 changes: 14 additions & 5 deletions mcp/src/flagsmith_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
from fastmcp.utilities.components import FastMCPComponent
from fastmcp.utilities.openapi.models import HttpMethod, HTTPRoute
from mcp.types import ToolAnnotations
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from prometheus_client import start_http_server
from starlette.requests import Request
from starlette.responses import PlainTextResponse

from flagsmith_mcp import config, constants
from flagsmith_mcp.auth import FlagsmithAuth
from flagsmith_mcp.events import EventLoggingMiddleware
from flagsmith_mcp.metrics import PrometheusMiddleware
from flagsmith_mcp.oauth import FlagsmithResourceAuth
from flagsmith_mcp.telemetry import setup_telemetry
from flagsmith_mcp.telemetry import propagate_span_attributes, setup_telemetry

ROUTE_MAPS = [
RouteMap(tags={"mcp"}, mcp_type=MCPType.TOOL),
Expand Down Expand Up @@ -56,12 +58,18 @@ def create_server(settings: config.Settings) -> FastMCP[None]:
resource_url=settings.mcp_server_url,
authorization_server=settings.flagsmith_api_url,
)
api_client = httpx.AsyncClient(
base_url=settings.flagsmith_api_url,
auth=FlagsmithAuth(settings.flagsmith_api_token),
event_hooks={"request": [propagate_span_attributes]},
)
# Instrument only the Flagsmith API client: emit a span per upstream
# call and propagate W3C trace context; the event hook passes the MCP
# call context to the API as W3C Baggage.
HTTPXClientInstrumentor().instrument_client(api_client)
server = FastMCP.from_openapi(
openapi_spec=_fetch_spec(),
client=httpx.AsyncClient(
base_url=settings.flagsmith_api_url,
auth=FlagsmithAuth(settings.flagsmith_api_token),
),
client=api_client,
name="Flagsmith",
route_maps=ROUTE_MAPS,
mcp_component_fn=_customise,
Expand All @@ -70,6 +78,7 @@ def create_server(settings: config.Settings) -> FastMCP[None]:
)

server.add_middleware(PrometheusMiddleware())
server.add_middleware(EventLoggingMiddleware())

@server.custom_route("/health", methods=["GET"])
async def health(request: Request) -> PlainTextResponse:
Expand Down
47 changes: 38 additions & 9 deletions mcp/src/flagsmith_mcp/telemetry.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
import httpx
from common.core.logging import setup_logging
from common.core.otel import (
add_otel_trace_context,
build_otel_log_provider,
build_tracer_provider,
make_structlog_otel_processor,
)
from opentelemetry import trace
from opentelemetry import baggage, trace
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.context import Context
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor, TracerProvider
from structlog.typing import Processor

from flagsmith_mcp import config
from flagsmith_mcp import config, constants
from flagsmith_mcp.events import get_client_info

APPLICATION_LOGGERS = ["flagsmith_mcp", "fastmcp", "mcp"]


class ClientInfoSpanProcessor(SpanProcessor):
"""Annotate started spans with this service's identity and the MCP
client identity."""

def on_start(self, span: Span, parent_context: Context | None = None) -> None:
span.set_attribute("flagsmith.client.name", constants.FLAGSMITH_CLIENT_NAME)
if (client_info := get_client_info()) is not None:
span.set_attribute("flagsmith.mcp.client.name", client_info.name)
span.set_attribute("flagsmith.mcp.client.version", client_info.version)


async def propagate_span_attributes(request: httpx.Request) -> None:
span = trace.get_current_span()
if not isinstance(span, ReadableSpan):
return
ctx: Context | None = None
for key, value in (span.attributes or {}).items():
ctx = baggage.set_baggage(key, str(value), context=ctx)
if ctx is not None:
W3CBaggagePropagator().inject(request.headers, context=ctx)


def setup_telemetry(settings: config.Settings) -> None:
"""Set up logging, exporting structlog events and traces to OpenTelemetry
when an OTLP endpoint is configured."""
Expand All @@ -27,14 +54,16 @@ def setup_telemetry(settings: config.Settings) -> None:
add_otel_trace_context,
make_structlog_otel_processor(log_provider),
]
# Setting a global tracer provider also activates FastMCP's built-in
# per-request server spans.
trace.set_tracer_provider(
build_tracer_provider(
endpoint=f"{endpoint}/v1/traces",
service_name=settings.otel_service_name,
)
tracer_provider = build_tracer_provider(
endpoint=f"{endpoint}/v1/traces",
service_name=settings.otel_service_name,
)
else:
# No exporter: spans stay in-process, but still feed the API
# baggage propagation.
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(ClientInfoSpanProcessor())
trace.set_tracer_provider(tracer_provider)
setup_logging(
log_level=settings.log_level,
log_format=settings.log_format,
Expand Down
25 changes: 25 additions & 0 deletions mcp/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,39 @@
import pytest
from fastmcp import Client, FastMCP
from fastmcp.client.transports import FastMCPTransport
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from respx import MockRouter

from flagsmith_mcp import config, constants
from flagsmith_mcp import server as server_module
from flagsmith_mcp.telemetry import ClientInfoSpanProcessor

HTTPClientFactoryFixture = Callable[[FastMCP], AsyncIterator[httpx.AsyncClient]]


@pytest.fixture(scope="session")
def span_exporter() -> InMemorySpanExporter:
# The global tracer provider can only be set once per process, hence
# the session scope.
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
provider.add_span_processor(ClientInfoSpanProcessor())
trace.set_tracer_provider(provider)
return exporter


@pytest.fixture
def finished_spans(span_exporter: InMemorySpanExporter) -> InMemorySpanExporter:
span_exporter.clear()
return span_exporter


@pytest.fixture
def openapi_spec() -> openapi.OpenAPI:
ok = openapi.Response(description="OK")
Expand Down
69 changes: 69 additions & 0 deletions mcp/tests/integration/test_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
from fastmcp import Client
from fastmcp.client.transports import FastMCPTransport
from fastmcp.exceptions import ToolError
from pytest_structlog import StructuredLogCapture
from respx import MockRouter


async def test_events__session_initialised__emits_session_opened(
log: StructuredLogCapture,
client: Client[FastMCPTransport],
) -> None:
# Given the server started via the client fixture
# When the session is initialised by the fixture
# Then the client's self-declared identity is reported
assert log.has(
"session.opened",
flagsmith__mcp__client__name="mcp",
flagsmith__mcp__client__version="0.1.0",
)


async def test_events__successful_tool_call__emits_tool_called(
log: StructuredLogCapture,
client: Client[FastMCPTransport],
respx_mock: MockRouter,
) -> None:
# Given
respx_mock.get("https://api.flagsmith.com/environments/").respond(
json={"results": []}
)

# When
await client.call_tool("list_environments", {})

# Then
[event] = [e for e in log.events if e["event"] == "tool.called"]
assert event == {
"event": "tool.called",
"level": "info",
"tool__name": "list_environments",
"flagsmith__mcp__client__name": "mcp",
"flagsmith__mcp__client__version": "0.1.0",
"status": "success",
}


async def test_events__failing_tool_call__emits_tool_called_with_error_status(
log: StructuredLogCapture,
client: Client[FastMCPTransport],
respx_mock: MockRouter,
) -> None:
# Given
respx_mock.get("https://api.flagsmith.com/environments/").respond(status_code=502)

# When
with pytest.raises(ToolError):
await client.call_tool("list_environments", {})

# Then
[event] = [e for e in log.events if e["event"] == "tool.called"]
assert event == {
"event": "tool.called",
"level": "info",
"tool__name": "list_environments",
"flagsmith__mcp__client__name": "mcp",
"flagsmith__mcp__client__version": "0.1.0",
"status": "error",
}
Loading
Loading