Skip to content

Commit 6844f7a

Browse files
feat(server): add FastAPI-style dependency injection system
ROOT CAUSE: MCPServer lacked a built-in mechanism for injecting user-defined dependencies (database connections, auth services, configs, etc.) into tool handlers. Developers had to rely on global variables or lifespan context, making testing difficult and limiting code organization. CHANGES: 1. Created Depends() class for declaring dependencies in tool parameters 2. Implemented DependencyResolver for automatic dependency graph resolution 3. Added find_dependency_parameters() to detect Depends() in function signatures 4. Extended Tool class to support dependency_kwarg_names field 5. Modified Tool.from_function() to skip Depends() parameters from arg_model 6. Modified Tool.run() to resolve and inject dependencies before execution 7. Added dependency_overrides to ToolManager and PromptManager 8. Implemented MCPServer.override_dependency() for testing 9. Exported Depends class from mcp.server public API IMPACT: - Tools can now declare dependencies via Depends(get_dependency) - Nested dependencies (dependencies of dependencies) are automatically resolved - Dependencies can be overridden for easy testing - Backward compatible - existing tools without Depends() work unchanged - Per-request caching prevents redundant dependency instantiation TECHNICAL NOTES: - Fixed dict reference bug: dependency_overrides or {} created new dict on empty - Used "is not None" check instead of "or {}" to preserve dict reference - Both sync and async dependency functions are supported - Caching is opt-in via use_cache parameter (default: True) FILES MODIFIED: - src/mcp/server/__init__.py - src/mcp/server/mcpserver/utilities/dependencies.py (new) - src/mcp/server/mcpserver/utilities/dependency_resolver.py (new) - src/mcp/server/mcpserver/tools/base.py - src/mcp/server/mcpserver/tools/tool_manager.py - src/mcp/server/mcpserver/prompts/base.py - src/mcp/server/mcpserver/prompts/manager.py - src/mcp/server/mcpserver/server.py - tests/server/mcpserver/utilities/test_dependencies.py (new) - tests/server/mcpserver/test_dependency_injection.py (new) - docs/dependency_injection.md (new) Refs: #1254
1 parent be5bb7c commit 6844f7a

File tree

12 files changed

+1076
-10
lines changed

12 files changed

+1076
-10
lines changed

docs/dependency_injection.md

Lines changed: 503 additions & 0 deletions
Large diffs are not rendered by default.

src/mcp/server/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .context import ServerRequestContext
22
from .lowlevel import NotificationOptions, Server
33
from .mcpserver import MCPServer
4+
from .mcpserver.utilities.dependencies import Depends
45
from .models import InitializationOptions
56

6-
__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions"]
7+
__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions", "Depends"]

src/mcp/server/mcpserver/prompts/base.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pydantic import BaseModel, Field, TypeAdapter, validate_call
1111

1212
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context
13+
from mcp.server.mcpserver.utilities.dependencies import find_dependency_parameters
1314
from mcp.server.mcpserver.utilities.func_metadata import func_metadata
1415
from mcp.types import ContentBlock, Icon, TextContent
1516

@@ -72,6 +73,11 @@ class Prompt(BaseModel):
7273
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
7374
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt")
7475
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)
76+
dependency_kwarg_names: list[str] = Field(
77+
default_factory=list,
78+
description="Names of kwargs that receive dependencies",
79+
exclude=True,
80+
)
7581

7682
@classmethod
7783
def from_function(
@@ -100,10 +106,19 @@ def from_function(
100106
if context_kwarg is None: # pragma: no branch
101107
context_kwarg = find_context_parameter(fn)
102108

103-
# Get schema from func_metadata, excluding context parameter
109+
# Find dependency parameters
110+
dependency_params = find_dependency_parameters(fn)
111+
dependency_kwarg_names = list(dependency_params.keys())
112+
113+
# Get schema from func_metadata, excluding context and dependency parameters
114+
skip_names = []
115+
if context_kwarg:
116+
skip_names.append(context_kwarg)
117+
skip_names.extend(dependency_kwarg_names)
118+
104119
func_arg_metadata = func_metadata(
105120
fn,
106-
skip_names=[context_kwarg] if context_kwarg is not None else [],
121+
skip_names=skip_names,
107122
)
108123
parameters = func_arg_metadata.arg_model.model_json_schema()
109124

@@ -131,12 +146,14 @@ def from_function(
131146
fn=fn,
132147
icons=icons,
133148
context_kwarg=context_kwarg,
149+
dependency_kwarg_names=dependency_kwarg_names,
134150
)
135151

136152
async def render(
137153
self,
138154
arguments: dict[str, Any] | None = None,
139155
context: Context[LifespanContextT, RequestT] | None = None,
156+
dependency_resolver: Any = None,
140157
) -> list[Message]:
141158
"""Render the prompt with arguments."""
142159
# Validate required arguments
@@ -151,6 +168,13 @@ async def render(
151168
# Add context to arguments if needed
152169
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)
153170

171+
# Resolve dependencies if a resolver is provided
172+
if self.dependency_kwarg_names and dependency_resolver:
173+
deps = find_dependency_parameters(self.fn)
174+
for dep_name in self.dependency_kwarg_names:
175+
if dep_name in deps:
176+
call_args[dep_name] = await dependency_resolver.resolve(dep_name, deps[dep_name])
177+
154178
# Call function and check if result is a coroutine
155179
result = self.fn(**call_args)
156180
if inspect.iscoroutine(result):

src/mcp/server/mcpserver/prompts/manager.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Callable
56
from typing import TYPE_CHECKING, Any
67

78
from mcp.server.mcpserver.prompts.base import Message, Prompt
@@ -17,9 +18,14 @@
1718
class PromptManager:
1819
"""Manages MCPServer prompts."""
1920

20-
def __init__(self, warn_on_duplicate_prompts: bool = True):
21+
def __init__(
22+
self,
23+
warn_on_duplicate_prompts: bool = True,
24+
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
25+
):
2126
self._prompts: dict[str, Prompt] = {}
2227
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
28+
self.dependency_overrides = dependency_overrides if dependency_overrides is not None else {}
2329

2430
def get_prompt(self, name: str) -> Prompt | None:
2531
"""Get prompt by name."""
@@ -56,4 +62,11 @@ async def render_prompt(
5662
if not prompt:
5763
raise ValueError(f"Unknown prompt: {name}")
5864

59-
return await prompt.render(arguments, context=context)
65+
# Create dependency resolver if prompt has dependencies
66+
dependency_resolver = None
67+
if prompt.dependency_kwarg_names:
68+
from mcp.server.mcpserver.utilities.dependency_resolver import DependencyResolver
69+
70+
dependency_resolver = DependencyResolver(context=context, overrides=self.dependency_overrides)
71+
72+
return await prompt.render(arguments, context=context, dependency_resolver=dependency_resolver)

src/mcp/server/mcpserver/server.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,19 @@ def __init__(
157157
auth=auth,
158158
)
159159

160-
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
160+
# Initialize dependency overrides
161+
self._dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] = {}
162+
163+
self._tool_manager = ToolManager(
164+
tools=tools,
165+
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools,
166+
dependency_overrides=self._dependency_overrides,
167+
)
161168
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
162-
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
169+
self._prompt_manager = PromptManager(
170+
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts,
171+
dependency_overrides=self._dependency_overrides,
172+
)
163173
self._lowlevel_server = Server(
164174
name=name or "mcp-server",
165175
title=title,
@@ -502,6 +512,31 @@ def remove_tool(self, name: str) -> None:
502512
"""
503513
self._tool_manager.remove_tool(name)
504514

515+
def override_dependency(
516+
self,
517+
original: Callable[..., Any],
518+
override: Callable[..., Any],
519+
) -> None:
520+
"""Override a dependency for testing.
521+
522+
This allows you to replace a dependency function with an alternative implementation,
523+
typically used in testing to provide mock dependencies.
524+
525+
Usage:
526+
def get_db() -> Database:
527+
return Database()
528+
529+
def get_test_db() -> Database:
530+
return MockDatabase([...])
531+
532+
server.override_dependency(get_db, get_test_db)
533+
534+
Args:
535+
original: The original dependency function to override
536+
override: The override function to use instead
537+
"""
538+
self._dependency_overrides[original] = override
539+
505540
def tool(
506541
self,
507542
name: str | None = None,

src/mcp/server/mcpserver/tools/base.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from mcp.server.mcpserver.exceptions import ToolError
1212
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
13+
from mcp.server.mcpserver.utilities.dependencies import find_dependency_parameters
1314
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
1415
from mcp.shared.exceptions import UrlElicitationRequiredError
1516
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
@@ -33,6 +34,10 @@ class Tool(BaseModel):
3334
)
3435
is_async: bool = Field(description="Whether the tool is async")
3536
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
37+
dependency_kwarg_names: list[str] = Field(
38+
default_factory=list,
39+
description="Names of kwargs that receive dependencies",
40+
)
3641
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
3742
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool")
3843
meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool")
@@ -68,9 +73,19 @@ def from_function(
6873
if context_kwarg is None: # pragma: no branch
6974
context_kwarg = find_context_parameter(fn)
7075

76+
# Find dependency parameters
77+
dependency_params = find_dependency_parameters(fn)
78+
dependency_kwarg_names = list(dependency_params.keys())
79+
80+
# Skip both context and dependency params from arg_model
81+
skip_names = []
82+
if context_kwarg:
83+
skip_names.append(context_kwarg)
84+
skip_names.extend(dependency_kwarg_names)
85+
7186
func_arg_metadata = func_metadata(
7287
fn,
73-
skip_names=[context_kwarg] if context_kwarg is not None else [],
88+
skip_names=skip_names,
7489
structured_output=structured_output,
7590
)
7691
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
@@ -84,6 +99,7 @@ def from_function(
8499
fn_metadata=func_arg_metadata,
85100
is_async=is_async,
86101
context_kwarg=context_kwarg,
102+
dependency_kwarg_names=dependency_kwarg_names,
87103
annotations=annotations,
88104
icons=icons,
89105
meta=meta,
@@ -94,14 +110,29 @@ async def run(
94110
arguments: dict[str, Any],
95111
context: Context[LifespanContextT, RequestT] | None = None,
96112
convert_result: bool = False,
113+
dependency_resolver: Any = None,
97114
) -> Any:
98115
"""Run the tool with arguments."""
99116
try:
117+
# Build direct args (context and dependencies)
118+
direct_args = {}
119+
if self.context_kwarg is not None and context is not None:
120+
direct_args[self.context_kwarg] = context
121+
122+
# Resolve dependencies if a resolver is provided
123+
if self.dependency_kwarg_names and dependency_resolver:
124+
from mcp.server.mcpserver.utilities.dependencies import find_dependency_parameters
125+
126+
deps = find_dependency_parameters(self.fn)
127+
for dep_name in self.dependency_kwarg_names:
128+
if dep_name in deps:
129+
direct_args[dep_name] = await dependency_resolver.resolve(dep_name, deps[dep_name])
130+
100131
result = await self.fn_metadata.call_fn_with_arg_validation(
101132
self.fn,
102133
self.is_async,
103134
arguments,
104-
{self.context_kwarg: context} if self.context_kwarg is not None else None,
135+
direct_args if direct_args else None,
105136
)
106137

107138
if convert_result:

src/mcp/server/mcpserver/tools/tool_manager.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
warn_on_duplicate_tools: bool = True,
2424
*,
2525
tools: list[Tool] | None = None,
26+
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
2627
):
2728
self._tools: dict[str, Tool] = {}
2829
if tools is not None:
@@ -32,6 +33,7 @@ def __init__(
3233
self._tools[tool.name] = tool
3334

3435
self.warn_on_duplicate_tools = warn_on_duplicate_tools
36+
self.dependency_overrides = dependency_overrides if dependency_overrides is not None else {}
3537

3638
def get_tool(self, name: str) -> Tool | None:
3739
"""Get tool by name."""
@@ -89,4 +91,13 @@ async def call_tool(
8991
if not tool:
9092
raise ToolError(f"Unknown tool: {name}")
9193

92-
return await tool.run(arguments, context=context, convert_result=convert_result)
94+
# Create dependency resolver if tool has dependencies
95+
dependency_resolver = None
96+
if tool.dependency_kwarg_names:
97+
from mcp.server.mcpserver.utilities.dependency_resolver import DependencyResolver
98+
99+
dependency_resolver = DependencyResolver(context=context, overrides=self.dependency_overrides)
100+
101+
return await tool.run(
102+
arguments, context=context, convert_result=convert_result, dependency_resolver=dependency_resolver
103+
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Dependency injection system for MCPServer."""
2+
3+
from __future__ import annotations
4+
5+
import inspect
6+
from collections.abc import Callable
7+
from typing import Any, Generic, TypeVar
8+
9+
T = TypeVar("T")
10+
11+
12+
class Depends(Generic[T]):
13+
"""Marker class for dependency injection.
14+
15+
Usage:
16+
def get_db() -> Database:
17+
return Database()
18+
19+
@server.tool()
20+
def my_tool(db: Database = Depends(get_db)):
21+
return db.query(...)
22+
23+
Args:
24+
dependency: A callable that provides the dependency
25+
scope: The scope of the dependency (for future use)
26+
use_cache: Whether to cache the dependency result
27+
28+
"""
29+
30+
def __init__(
31+
self,
32+
dependency: Callable[..., T],
33+
*,
34+
use_cache: bool = True,
35+
) -> None:
36+
self.dependency = dependency
37+
self.use_cache = use_cache
38+
39+
def __repr__(self) -> str:
40+
return f"Depends({self.dependency.__name__})"
41+
42+
43+
def find_dependency_parameters(
44+
fn: Callable[..., Any],
45+
) -> dict[str, Depends[Any]]:
46+
"""Find all parameters with Depends() default values.
47+
48+
Args:
49+
fn: Function to inspect
50+
51+
Returns:
52+
Dict mapping parameter names to Depends instances
53+
"""
54+
deps: dict[str, Depends[Any]] = {}
55+
try:
56+
sig = inspect.signature(fn, eval_str=True)
57+
except (ValueError, TypeError):
58+
return deps
59+
60+
for param_name, param in sig.parameters.items():
61+
if param.default is inspect.Parameter.empty:
62+
continue
63+
64+
# Check if default is Depends instance
65+
if isinstance(param.default, Depends):
66+
deps[param_name] = param.default
67+
68+
return deps

0 commit comments

Comments
 (0)