Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ venv = ".venv"
executionEnvironments = [
{ root = "tests", extraPaths = [
".",
], reportUnusedFunction = false, reportPrivateUsage = false },
], reportUnusedFunction = false, reportPrivateUsage = false, reportUnknownMemberType = false, reportArgumentType = false, reportUnknownVariableType = false, reportAttributeAccessIssue = false },
{ root = "examples/servers", reportUnusedFunction = false },
]

Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .context import ServerRequestContext
from .lowlevel import NotificationOptions, Server
from .mcpserver import MCPServer
from .mcpserver.utilities.dependencies import Depends
from .models import InitializationOptions

__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions"]
__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions", "Depends"]
28 changes: 26 additions & 2 deletions src/mcp/server/mcpserver/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import BaseModel, Field, TypeAdapter, validate_call

from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context
from mcp.server.mcpserver.utilities.dependencies import find_dependency_parameters
from mcp.server.mcpserver.utilities.func_metadata import func_metadata
from mcp.types import ContentBlock, Icon, TextContent

Expand Down Expand Up @@ -72,6 +73,11 @@ class Prompt(BaseModel):
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt")
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)
dependency_kwarg_names: list[str] = Field(
default_factory=list,
description="Names of kwargs that receive dependencies",
exclude=True,
)

@classmethod
def from_function(
Expand Down Expand Up @@ -100,10 +106,19 @@ def from_function(
if context_kwarg is None: # pragma: no branch
context_kwarg = find_context_parameter(fn)

# Get schema from func_metadata, excluding context parameter
# Find dependency parameters
dependency_params = find_dependency_parameters(fn)
dependency_kwarg_names = list(dependency_params.keys())

# Get schema from func_metadata, excluding context and dependency parameters
skip_names: list[str] = []
if context_kwarg:
skip_names.append(context_kwarg)
skip_names.extend(dependency_kwarg_names)

func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
skip_names=skip_names,
)
parameters = func_arg_metadata.arg_model.model_json_schema()

Expand Down Expand Up @@ -131,12 +146,14 @@ def from_function(
fn=fn,
icons=icons,
context_kwarg=context_kwarg,
dependency_kwarg_names=dependency_kwarg_names,
)

async def render(
self,
arguments: dict[str, Any] | None = None,
context: Context[LifespanContextT, RequestT] | None = None,
dependency_resolver: Any = None,
) -> list[Message]:
"""Render the prompt with arguments."""
# Validate required arguments
Expand All @@ -151,6 +168,13 @@ async def render(
# Add context to arguments if needed
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)

# Resolve dependencies if a resolver is provided
if self.dependency_kwarg_names and dependency_resolver: # pragma: no cover
deps = find_dependency_parameters(self.fn)
for dep_name in self.dependency_kwarg_names:
if dep_name in deps:
call_args[dep_name] = await dependency_resolver.resolve(dep_name, deps[dep_name])

# Call function and check if result is a coroutine
result = self.fn(**call_args)
if inspect.iscoroutine(result):
Expand Down
17 changes: 15 additions & 2 deletions src/mcp/server/mcpserver/prompts/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from mcp.server.mcpserver.prompts.base import Message, Prompt
Expand All @@ -17,9 +18,14 @@
class PromptManager:
"""Manages MCPServer prompts."""

def __init__(self, warn_on_duplicate_prompts: bool = True):
def __init__(
self,
warn_on_duplicate_prompts: bool = True,
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
):
self._prompts: dict[str, Prompt] = {}
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
self.dependency_overrides = dependency_overrides if dependency_overrides is not None else {}

def get_prompt(self, name: str) -> Prompt | None:
"""Get prompt by name."""
Expand Down Expand Up @@ -56,4 +62,11 @@ async def render_prompt(
if not prompt:
raise ValueError(f"Unknown prompt: {name}")

return await prompt.render(arguments, context=context)
# Create dependency resolver if prompt has dependencies
dependency_resolver = None
if prompt.dependency_kwarg_names: # pragma: no cover
from mcp.server.mcpserver.utilities.dependency_resolver import DependencyResolver

dependency_resolver = DependencyResolver(context=context, overrides=self.dependency_overrides)

return await prompt.render(arguments, context=context, dependency_resolver=dependency_resolver)
39 changes: 37 additions & 2 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,19 @@ def __init__(
auth=auth,
)

self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
# Initialize dependency overrides
self._dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] = {}

self._tool_manager = ToolManager(
tools=tools,
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools,
dependency_overrides=self._dependency_overrides,
)
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
self._prompt_manager = PromptManager(
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts,
dependency_overrides=self._dependency_overrides,
)
self._lowlevel_server = Server(
name=name or "mcp-server",
title=title,
Expand Down Expand Up @@ -502,6 +512,31 @@ def remove_tool(self, name: str) -> None:
"""
self._tool_manager.remove_tool(name)

def override_dependency(
self,
original: Callable[..., Any],
override: Callable[..., Any],
) -> None:
"""Override a dependency for testing.

This allows you to replace a dependency function with an alternative implementation,
typically used in testing to provide mock dependencies.

Usage:
def get_db() -> Database:
return Database()

def get_test_db() -> Database:
return MockDatabase([...])

server.override_dependency(get_db, get_test_db)

Args:
original: The original dependency function to override
override: The override function to use instead
"""
self._dependency_overrides[original] = override

def tool(
self,
name: str | None = None,
Expand Down
35 changes: 33 additions & 2 deletions src/mcp/server/mcpserver/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from mcp.server.mcpserver.exceptions import ToolError
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
from mcp.server.mcpserver.utilities.dependencies import find_dependency_parameters
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.shared.exceptions import UrlElicitationRequiredError
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
Expand All @@ -33,6 +34,10 @@ class Tool(BaseModel):
)
is_async: bool = Field(description="Whether the tool is async")
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
dependency_kwarg_names: list[str] = Field(
default_factory=list,
description="Names of kwargs that receive dependencies",
)
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool")
meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool")
Expand Down Expand Up @@ -68,9 +73,19 @@ def from_function(
if context_kwarg is None: # pragma: no branch
context_kwarg = find_context_parameter(fn)

# Find dependency parameters
dependency_params = find_dependency_parameters(fn)
dependency_kwarg_names = list(dependency_params.keys())

# Skip both context and dependency params from arg_model
skip_names: list[str] = []
if context_kwarg:
skip_names.append(context_kwarg)
skip_names.extend(dependency_kwarg_names)

func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
skip_names=skip_names,
structured_output=structured_output,
)
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
Expand All @@ -84,6 +99,7 @@ def from_function(
fn_metadata=func_arg_metadata,
is_async=is_async,
context_kwarg=context_kwarg,
dependency_kwarg_names=dependency_kwarg_names,
annotations=annotations,
icons=icons,
meta=meta,
Expand All @@ -94,14 +110,29 @@ async def run(
arguments: dict[str, Any],
context: Context[LifespanContextT, RequestT] | None = None,
convert_result: bool = False,
dependency_resolver: Any = None,
) -> Any:
"""Run the tool with arguments."""
try:
# Build direct args (context and dependencies)
direct_args: dict[str, Any] = {}
if self.context_kwarg is not None:
direct_args[self.context_kwarg] = context

# Resolve dependencies if a resolver is provided
if self.dependency_kwarg_names and dependency_resolver:
from mcp.server.mcpserver.utilities.dependencies import find_dependency_parameters

deps = find_dependency_parameters(self.fn)
for dep_name in self.dependency_kwarg_names:
if dep_name in deps:
direct_args[dep_name] = await dependency_resolver.resolve(dep_name, deps[dep_name])

result = await self.fn_metadata.call_fn_with_arg_validation(
self.fn,
self.is_async,
arguments,
{self.context_kwarg: context} if self.context_kwarg is not None else None,
direct_args if direct_args else None,
)

if convert_result:
Expand Down
13 changes: 12 additions & 1 deletion src/mcp/server/mcpserver/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
warn_on_duplicate_tools: bool = True,
*,
tools: list[Tool] | None = None,
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
):
self._tools: dict[str, Tool] = {}
if tools is not None:
Expand All @@ -32,6 +33,7 @@ def __init__(
self._tools[tool.name] = tool

self.warn_on_duplicate_tools = warn_on_duplicate_tools
self.dependency_overrides = dependency_overrides if dependency_overrides is not None else {}

def get_tool(self, name: str) -> Tool | None:
"""Get tool by name."""
Expand Down Expand Up @@ -89,4 +91,13 @@ async def call_tool(
if not tool:
raise ToolError(f"Unknown tool: {name}")

return await tool.run(arguments, context=context, convert_result=convert_result)
# Create dependency resolver if tool has dependencies
dependency_resolver = None
if tool.dependency_kwarg_names:
from mcp.server.mcpserver.utilities.dependency_resolver import DependencyResolver

dependency_resolver = DependencyResolver(context=context, overrides=self.dependency_overrides)

return await tool.run(
arguments, context=context, convert_result=convert_result, dependency_resolver=dependency_resolver
)
68 changes: 68 additions & 0 deletions src/mcp/server/mcpserver/utilities/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Dependency injection system for MCPServer."""

from __future__ import annotations

import inspect
from collections.abc import Callable
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class Depends(Generic[T]):
"""Marker class for dependency injection.

Usage:
def get_db() -> Database:
return Database()

@server.tool()
def my_tool(db: Database = Depends(get_db)):
return db.query(...)

Args:
dependency: A callable that provides the dependency
scope: The scope of the dependency (for future use)
use_cache: Whether to cache the dependency result

"""

def __init__(
self,
dependency: Callable[..., T],
*,
use_cache: bool = True,
) -> None:
self.dependency = dependency
self.use_cache = use_cache

def __repr__(self) -> str:
return f"Depends({self.dependency.__name__})"


def find_dependency_parameters(
fn: Callable[..., Any],
) -> dict[str, Depends[Any]]:
"""Find all parameters with Depends() default values.

Args:
fn: Function to inspect

Returns:
Dict mapping parameter names to Depends instances
"""
deps: dict[str, Depends[Any]] = {}
try:
sig = inspect.signature(fn, eval_str=True)
except (ValueError, TypeError): # pragma: no cover (defensive)
return deps

for param_name, param in sig.parameters.items():
if param.default is inspect.Parameter.empty:
continue

# Check if default is Depends instance
if isinstance(param.default, Depends):
deps[param_name] = param.default # type: ignore[assignment]

return deps
Loading
Loading