Skip to content

Commit b603626

Browse files
committed
backport: phase-4 wave 1 (proof) — 25/25 pass across lowlevel/mcpserver/transports/auth
- transports/test_bridge (4, no edits), test_stdio (2 + _stdio_server rewrite) - lowlevel/test_completion (5), test_logging (3), test_tools b1/3 (5) - mcpserver/test_completion (1) - auth/test_flow (5) - _requirements.py: tools:call:unknown-name + protocol:error:internal-error divergences updated for v1
1 parent 7d9881d commit b603626

8 files changed

Lines changed: 241 additions & 211 deletions

File tree

tests/interaction/_requirements.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,11 @@ def __post_init__(self) -> None:
349349
),
350350
divergence=Divergence(
351351
note=(
352-
"The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and "
353-
"leaks str(exc) as the error message."
352+
"For tools/call the lowlevel @server.call_tool() decorator wraps the handler in a broad "
353+
"try/except that converts every Exception to CallToolResult(isError=True, "
354+
"content=[TextContent(text=str(exc))]), so the dispatcher's JSON-RPC error path is never "
355+
"reached for tool calls and the test pins the isError=True result. For other request "
356+
"handlers the dispatcher returns code 0 (not -32603) with str(exc) as the message."
354357
),
355358
),
356359
),
@@ -559,6 +562,14 @@ def __post_init__(self) -> None:
559562
"tools:call:unknown-name": Requirement(
560563
source=f"{SPEC_BASE_URL}/server/tools#error-handling",
561564
behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.",
565+
divergence=Divergence(
566+
note=(
567+
"The lowlevel @server.call_tool() decorator catches every handler exception (including "
568+
"McpError) and converts it to CallToolResult(isError=True, content=[TextContent(text=str(exc))]), "
569+
"so a handler cannot produce a protocol-level JSON-RPC error for tools/call; the test pins "
570+
"the isError=True result instead."
571+
),
572+
),
562573
),
563574
"tools:capability:declared": Requirement(
564575
source=f"{SPEC_BASE_URL}/server/tools#capabilities",

tests/interaction/auth/test_flow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from mcp.server.auth.middleware.auth_context import get_access_token
2424
from mcp.shared.auth import OAuthClientInformationFull
2525
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
26-
from tests.interaction._connect import BASE_URL
26+
from tests.interaction._connect import BASE_URL, build_streamable_http_app
2727
from tests.interaction._requirements import requirement
2828
from tests.interaction.auth._harness import (
2929
REDIRECT_URI,
@@ -229,11 +229,11 @@ async def test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_w
229229
own routing; provided here so the discovery tests can rely on the shim without each adding
230230
their own contract test.
231231
"""
232-
server = Server("bare")
232+
server: Server[object] = Server("bare")
233233
provider = InMemoryAuthorizationServerProvider()
234-
real_app = server.streamable_http_app(auth=auth_settings(), auth_server_provider=provider)
234+
real_app, manager = build_streamable_http_app(server, auth=auth_settings(), auth_server_provider=provider)
235235
app = shimmed_app(real_app, not_found=frozenset({"/missing"}), serve={"/override": b'{"shimmed": true}'})
236-
async with server.session_manager.run():
236+
async with manager.run():
237237
async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http:
238238
served = await http.get("/override")
239239
assert served.status_code == 200

tests/interaction/lowlevel/test_completion.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
"""Completion interactions against the low-level Server, driven through the public Client API."""
1+
"""Completion interactions against the low-level Server, driven through the public client API."""
22

33
import pytest
44
from inline_snapshot import snapshot
55

6-
from mcp import MCPError, types
7-
from mcp.server import Server, ServerRequestContext
6+
from mcp import McpError
7+
from mcp.server.lowlevel import Server
88
from mcp.types import (
99
INVALID_PARAMS,
1010
METHOD_NOT_FOUND,
1111
CompleteResult,
1212
Completion,
13+
CompletionArgument,
14+
CompletionContext,
1315
ErrorData,
1416
PromptReference,
1517
ResourceTemplateReference,
@@ -27,16 +29,20 @@ async def test_complete_prompt_argument(connect: Connect) -> None:
2729
2830
The returned values are filtered by the argument's value, proving the value reached the handler.
2931
"""
30-
31-
async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult:
32-
assert isinstance(params.ref, PromptReference)
33-
assert params.ref.name == "code_review"
34-
assert params.argument.name == "language"
32+
server = Server("completer")
33+
34+
@server.completion()
35+
async def completion(
36+
ref: PromptReference | ResourceTemplateReference,
37+
argument: CompletionArgument,
38+
context: CompletionContext | None,
39+
) -> Completion | None:
40+
assert isinstance(ref, PromptReference)
41+
assert ref.name == "code_review"
42+
assert argument.name == "language"
3543
candidates = ["python", "pytorch", "ruby"]
36-
matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)]
37-
return CompleteResult(completion=Completion(values=matches, total=len(matches), hasMore=False))
38-
39-
server = Server("completer", on_completion=completion)
44+
matches = [candidate for candidate in candidates if candidate.startswith(argument.value)]
45+
return Completion(values=matches, total=len(matches), hasMore=False)
4046

4147
async with connect(server) as client:
4248
result = await client.complete(
@@ -51,14 +57,18 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar
5157
@requirement("completion:resource-template-arg")
5258
async def test_complete_resource_template_variable(connect: Connect) -> None:
5359
"""Completing a URI template variable delivers the template URI and variable name to the handler."""
54-
55-
async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult:
56-
assert isinstance(params.ref, ResourceTemplateReference)
57-
assert params.ref.uri == "github://repos/{owner}/{repo}"
58-
assert params.argument.name == "owner"
59-
return CompleteResult(completion=Completion(values=[f"{params.argument.value}contextprotocol"]))
60-
61-
server = Server("completer", on_completion=completion)
60+
server = Server("completer")
61+
62+
@server.completion()
63+
async def completion(
64+
ref: PromptReference | ResourceTemplateReference,
65+
argument: CompletionArgument,
66+
context: CompletionContext | None,
67+
) -> Completion | None:
68+
assert isinstance(ref, ResourceTemplateReference)
69+
assert ref.uri == "github://repos/{owner}/{repo}"
70+
assert argument.name == "owner"
71+
return Completion(values=[f"{argument.value}contextprotocol"])
6272

6373
async with connect(server) as client:
6474
result = await client.complete(
@@ -75,14 +85,18 @@ async def test_complete_receives_context_arguments(connect: Connect) -> None:
7585
7686
The returned value is derived from the context, proving it arrived.
7787
"""
78-
79-
async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult:
80-
assert params.argument.name == "repo"
81-
assert params.context is not None
82-
assert params.context.arguments is not None
83-
return CompleteResult(completion=Completion(values=[f"{params.context.arguments['owner']}/python-sdk"]))
84-
85-
server = Server("completer", on_completion=completion)
88+
server = Server("completer")
89+
90+
@server.completion()
91+
async def completion(
92+
ref: PromptReference | ResourceTemplateReference,
93+
argument: CompletionArgument,
94+
context: CompletionContext | None,
95+
) -> Completion | None:
96+
assert argument.name == "repo"
97+
assert context is not None
98+
assert context.arguments is not None
99+
return Completion(values=[f"{context.arguments['owner']}/python-sdk"])
86100

87101
async with connect(server) as client:
88102
result = await client.complete(
@@ -102,15 +116,19 @@ async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params
102116
against); rejecting an unknown ref is the handler's job, and this test pins the spec-recommended
103117
way to do it.
104118
"""
119+
server = Server("completer")
105120

106-
async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult:
107-
assert isinstance(params.ref, PromptReference)
108-
raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.ref.name!r}")
109-
110-
server = Server("completer", on_completion=completion)
121+
@server.completion()
122+
async def completion(
123+
ref: PromptReference | ResourceTemplateReference,
124+
argument: CompletionArgument,
125+
context: CompletionContext | None,
126+
) -> Completion | None:
127+
assert isinstance(ref, PromptReference)
128+
raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Unknown prompt: {ref.name!r}"))
111129

112130
async with connect(server) as client:
113-
with pytest.raises(MCPError) as exc_info:
131+
with pytest.raises(McpError) as exc_info:
114132
await client.complete(PromptReference(type="ref/prompt", name="ghost"), argument={"name": "x", "value": ""})
115133

116134
assert exc_info.value.error.code == INVALID_PARAMS
@@ -123,9 +141,11 @@ async def test_complete_without_handler_is_method_not_found(connect: Connect) ->
123141
server = Server("incomplete")
124142

125143
async with connect(server) as client:
126-
assert client.initialize_result.capabilities.completions is None
144+
capabilities = client.get_server_capabilities()
145+
assert capabilities is not None
146+
assert capabilities.completions is None
127147

128-
with pytest.raises(MCPError) as exc_info:
148+
with pytest.raises(McpError) as exc_info:
129149
await client.complete(
130150
PromptReference(type="ref/prompt", name="anything"), argument={"name": "topic", "value": ""}
131151
)

tests/interaction/lowlevel/test_logging.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
assert after the request completes on every transport leg -- no events, no waiting.
1010
"""
1111

12+
from typing import Any
13+
1214
import pytest
1315
from inline_snapshot import snapshot
1416

1517
from mcp import types
16-
from mcp.server import Server, ServerRequestContext
18+
from mcp.server import Server
1719
from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent
1820
from tests.interaction._connect import Connect
1921
from tests.interaction._requirements import requirement
@@ -35,12 +37,11 @@
3537
@requirement("logging:set-level")
3638
async def test_set_logging_level_reaches_handler(connect: Connect) -> None:
3739
"""The level requested by the client is delivered to the server's handler verbatim."""
40+
server = Server("logger")
3841

39-
async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult:
40-
assert params.level == "warning"
41-
return EmptyResult()
42-
43-
server = Server("logger", on_set_logging_level=set_logging_level)
42+
@server.set_logging_level()
43+
async def set_logging_level(level: types.LoggingLevel) -> None:
44+
assert level == "warning"
4445

4546
async with connect(server) as client:
4647
result = await client.set_logging_level("warning")
@@ -61,27 +62,29 @@ async def test_log_messages_reach_logging_callback_in_order(connect: Connect) ->
6162
async def collect(params: LoggingMessageNotificationParams) -> None:
6263
received.append(params)
6364

64-
async def list_tools(
65-
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
66-
) -> types.ListToolsResult:
67-
return types.ListToolsResult(tools=[types.Tool(name="chatty", inputSchema={"type": "object"})])
65+
server = Server("logger")
66+
67+
@server.list_tools()
68+
async def list_tools() -> list[types.Tool]:
69+
return [types.Tool(name="chatty", inputSchema={"type": "object"})]
6870

69-
async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
70-
assert params.name == "chatty"
71+
@server.call_tool()
72+
async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]:
73+
assert name == "chatty"
74+
ctx = server.request_context
7175
await ctx.session.send_log_message(
7276
level="info", data="starting up", logger="app.lifecycle", related_request_id=ctx.request_id
7377
)
7478
await ctx.session.send_log_message(
7579
level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id
7680
)
77-
return CallToolResult(content=[TextContent(type="text", text="done")])
81+
return [TextContent(type="text", text="done")]
7882

79-
async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult:
83+
@server.set_logging_level()
84+
async def set_logging_level(level: types.LoggingLevel) -> None:
8085
"""Registered so the logging capability is advertised; the client never sets a level."""
8186
raise NotImplementedError
8287

83-
server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level)
84-
8588
async with connect(server, logging_callback=collect) as client:
8689
result = await client.call_tool("chatty", {})
8790

@@ -102,25 +105,27 @@ async def test_log_messages_at_every_severity_level(connect: Connect) -> None:
102105
async def collect(params: LoggingMessageNotificationParams) -> None:
103106
received.append(params)
104107

105-
async def list_tools(
106-
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
107-
) -> types.ListToolsResult:
108-
return types.ListToolsResult(tools=[types.Tool(name="siren", inputSchema={"type": "object"})])
108+
server = Server("logger")
109+
110+
@server.list_tools()
111+
async def list_tools() -> list[types.Tool]:
112+
return [types.Tool(name="siren", inputSchema={"type": "object"})]
109113

110-
async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
111-
assert params.name == "siren"
114+
@server.call_tool()
115+
async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]:
116+
assert name == "siren"
117+
ctx = server.request_context
112118
for level in ALL_LEVELS:
113119
await ctx.session.send_log_message(
114120
level=level, data=f"a {level} message", related_request_id=ctx.request_id
115121
)
116-
return CallToolResult(content=[TextContent(type="text", text="logged")])
122+
return [TextContent(type="text", text="logged")]
117123

118-
async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult:
124+
@server.set_logging_level()
125+
async def set_logging_level(level: types.LoggingLevel) -> None:
119126
"""Registered so the logging capability is advertised; the client never sets a level."""
120127
raise NotImplementedError
121128

122-
server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level)
123-
124129
async with connect(server, logging_callback=collect) as client:
125130
await client.call_tool("siren", {})
126131

0 commit comments

Comments
 (0)