|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +from typing import Any |
| 6 | + |
5 | 7 | import pytest |
6 | 8 |
|
7 | 9 | from mcp import types |
8 | 10 | from mcp.client import BaseClientSession |
9 | 11 | from mcp.client.client import Client |
10 | 12 | from mcp.client.session import ClientSession |
11 | 13 | from mcp.server.mcpserver import MCPServer |
| 14 | +from mcp.shared.session import ProgressFnT |
| 15 | +from mcp.types._types import RequestParamsMeta |
12 | 16 |
|
13 | 17 | pytestmark = pytest.mark.anyio |
14 | 18 |
|
@@ -44,7 +48,9 @@ def echo(text: str) -> str: |
44 | 48 | result = await session.call_tool("echo", {"text": "hello"}) |
45 | 49 | assert result.is_error is False |
46 | 50 | assert len(result.content) == 1 |
47 | | - assert result.content[0].text == "hello" |
| 51 | + first_content = result.content[0] |
| 52 | + assert isinstance(first_content, types.TextContent) |
| 53 | + assert first_content.text == "hello" |
48 | 54 |
|
49 | 55 | # List tools through the Protocol interface |
50 | 56 | tools_result = await session.list_tools() |
@@ -78,56 +84,103 @@ class StubClientSession: |
78 | 84 | """Minimal stub that satisfies BaseClientSession protocol.""" |
79 | 85 |
|
80 | 86 | async def send_request( |
81 | | - self, request, result_type, request_read_timeout_seconds=None, metadata=None, progress_callback=None |
82 | | - ): |
| 87 | + self, |
| 88 | + request: types.ClientRequest, |
| 89 | + result_type: type[Any], |
| 90 | + request_read_timeout_seconds: float | None = None, |
| 91 | + metadata: Any = None, |
| 92 | + progress_callback: ProgressFnT | None = None, |
| 93 | + ) -> Any: |
83 | 94 | return types.EmptyResult() |
84 | 95 |
|
85 | | - async def send_notification(self, notification, related_request_id=None): |
| 96 | + async def send_notification( |
| 97 | + self, |
| 98 | + notification: types.ClientNotification, |
| 99 | + related_request_id: Any = None, |
| 100 | + ) -> None: |
86 | 101 | pass |
87 | 102 |
|
88 | | - async def send_progress_notification(self, progress_token, progress, total=None, message=None, *, meta=None): |
| 103 | + async def send_progress_notification( |
| 104 | + self, |
| 105 | + progress_token: types.ProgressToken, |
| 106 | + progress: float, |
| 107 | + total: float | None = None, |
| 108 | + message: str | None = None, |
| 109 | + *, |
| 110 | + meta: RequestParamsMeta | None = None, |
| 111 | + ) -> None: |
89 | 112 | pass |
90 | 113 |
|
91 | | - async def initialize(self): |
92 | | - return types.InitializeResult() |
| 114 | + async def initialize(self) -> types.InitializeResult: |
| 115 | + return types.InitializeResult( |
| 116 | + protocol_version="2024-11-05", |
| 117 | + capabilities=types.ServerCapabilities(), |
| 118 | + server_info=types.Implementation(name="stub", version="0"), |
| 119 | + ) |
93 | 120 |
|
94 | | - async def send_ping(self, *, meta=None): |
| 121 | + async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: |
95 | 122 | return types.EmptyResult() |
96 | 123 |
|
97 | | - async def list_resources(self, *, params=None): |
98 | | - return types.ListResourcesResult() |
| 124 | + async def list_resources(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListResourcesResult: |
| 125 | + return types.ListResourcesResult(resources=[]) |
99 | 126 |
|
100 | | - async def list_resource_templates(self, *, params=None): |
101 | | - return types.ListResourceTemplatesResult() |
| 127 | + async def list_resource_templates( |
| 128 | + self, *, params: types.PaginatedRequestParams | None = None |
| 129 | + ) -> types.ListResourceTemplatesResult: |
| 130 | + return types.ListResourceTemplatesResult(resource_templates=[]) |
102 | 131 |
|
103 | | - async def read_resource(self, uri, *, meta=None): |
104 | | - return types.ReadResourceResult() |
| 132 | + async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.ReadResourceResult: |
| 133 | + return types.ReadResourceResult(contents=[]) |
105 | 134 |
|
106 | | - async def subscribe_resource(self, uri, *, meta=None): |
| 135 | + async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: |
107 | 136 | return types.EmptyResult() |
108 | 137 |
|
109 | | - async def unsubscribe_resource(self, uri, *, meta=None): |
| 138 | + async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: |
110 | 139 | return types.EmptyResult() |
111 | 140 |
|
112 | | - async def call_tool(self, name, arguments=None, read_timeout_seconds=None, progress_callback=None, *, meta=None): |
113 | | - return types.CallToolResult() |
114 | | - |
115 | | - async def list_prompts(self, *, params=None): |
116 | | - return types.ListPromptsResult() |
117 | | - |
118 | | - async def get_prompt(self, name, arguments=None, *, meta=None): |
119 | | - return types.GetPromptResult() |
120 | | - |
121 | | - async def list_tools(self, *, params=None): |
122 | | - return types.ListToolsResult() |
123 | | - |
124 | | - async def complete(self, ref, argument, context_arguments=None): |
125 | | - return types.CompleteResult() |
126 | | - |
127 | | - async def set_logging_level(self, level, *, meta=None): |
| 141 | + async def call_tool( |
| 142 | + self, |
| 143 | + name: str, |
| 144 | + arguments: dict[str, Any] | None = None, |
| 145 | + read_timeout_seconds: float | None = None, |
| 146 | + progress_callback: ProgressFnT | None = None, |
| 147 | + *, |
| 148 | + meta: RequestParamsMeta | None = None, |
| 149 | + ) -> types.CallToolResult: |
| 150 | + return types.CallToolResult(content=[]) |
| 151 | + |
| 152 | + async def list_prompts(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListPromptsResult: |
| 153 | + return types.ListPromptsResult(prompts=[]) |
| 154 | + |
| 155 | + async def get_prompt( |
| 156 | + self, |
| 157 | + name: str, |
| 158 | + arguments: dict[str, str] | None = None, |
| 159 | + *, |
| 160 | + meta: RequestParamsMeta | None = None, |
| 161 | + ) -> types.GetPromptResult: |
| 162 | + return types.GetPromptResult(messages=[]) |
| 163 | + |
| 164 | + async def list_tools(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListToolsResult: |
| 165 | + return types.ListToolsResult(tools=[]) |
| 166 | + |
| 167 | + async def complete( |
| 168 | + self, |
| 169 | + ref: types.ResourceTemplateReference | types.PromptReference, |
| 170 | + argument: dict[str, str], |
| 171 | + context_arguments: dict[str, str] | None = None, |
| 172 | + ) -> types.CompleteResult: |
| 173 | + return types.CompleteResult(completion=types.Completion(values=[])) |
| 174 | + |
| 175 | + async def set_logging_level( |
| 176 | + self, |
| 177 | + level: types.LoggingLevel, |
| 178 | + *, |
| 179 | + meta: RequestParamsMeta | None = None, |
| 180 | + ) -> types.EmptyResult: |
128 | 181 | return types.EmptyResult() |
129 | 182 |
|
130 | | - async def send_roots_list_changed(self): |
| 183 | + async def send_roots_list_changed(self) -> None: |
131 | 184 | pass |
132 | 185 |
|
133 | 186 |
|
|
0 commit comments