Skip to content

Commit 9b06bce

Browse files
committed
refactor: enhance BaseClientSession with type hints and method signatures
- Added type hints to methods in StubClientSession for better clarity and type safety. - Updated the test for BaseClientSession to assert the type of content returned. - Improved method signatures to include specific types for parameters and return values, enhancing code readability and maintainability.
1 parent 6b6a73f commit 9b06bce

File tree

1 file changed

+86
-33
lines changed

1 file changed

+86
-33
lines changed

tests/client/test_base_client_session.py

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
from __future__ import annotations
44

5+
from typing import Any
6+
57
import pytest
68

79
from mcp import types
810
from mcp.client import BaseClientSession
911
from mcp.client.client import Client
1012
from mcp.client.session import ClientSession
1113
from mcp.server.mcpserver import MCPServer
14+
from mcp.shared.session import ProgressFnT
15+
from mcp.types._types import RequestParamsMeta
1216

1317
pytestmark = pytest.mark.anyio
1418

@@ -44,7 +48,9 @@ def echo(text: str) -> str:
4448
result = await session.call_tool("echo", {"text": "hello"})
4549
assert result.is_error is False
4650
assert len(result.content) == 1
47-
assert result.content[0].text == "hello"
51+
first_content = result.content[0]
52+
assert isinstance(first_content, types.TextContent)
53+
assert first_content.text == "hello"
4854

4955
# List tools through the Protocol interface
5056
tools_result = await session.list_tools()
@@ -78,56 +84,103 @@ class StubClientSession:
7884
"""Minimal stub that satisfies BaseClientSession protocol."""
7985

8086
async def send_request(
81-
self, request, result_type, request_read_timeout_seconds=None, metadata=None, progress_callback=None
82-
):
87+
self,
88+
request: types.ClientRequest,
89+
result_type: type[Any],
90+
request_read_timeout_seconds: float | None = None,
91+
metadata: Any = None,
92+
progress_callback: ProgressFnT | None = None,
93+
) -> Any:
8394
return types.EmptyResult()
8495

85-
async def send_notification(self, notification, related_request_id=None):
96+
async def send_notification(
97+
self,
98+
notification: types.ClientNotification,
99+
related_request_id: Any = None,
100+
) -> None:
86101
pass
87102

88-
async def send_progress_notification(self, progress_token, progress, total=None, message=None, *, meta=None):
103+
async def send_progress_notification(
104+
self,
105+
progress_token: types.ProgressToken,
106+
progress: float,
107+
total: float | None = None,
108+
message: str | None = None,
109+
*,
110+
meta: RequestParamsMeta | None = None,
111+
) -> None:
89112
pass
90113

91-
async def initialize(self):
92-
return types.InitializeResult()
114+
async def initialize(self) -> types.InitializeResult:
115+
return types.InitializeResult(
116+
protocol_version="2024-11-05",
117+
capabilities=types.ServerCapabilities(),
118+
server_info=types.Implementation(name="stub", version="0"),
119+
)
93120

94-
async def send_ping(self, *, meta=None):
121+
async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
95122
return types.EmptyResult()
96123

97-
async def list_resources(self, *, params=None):
98-
return types.ListResourcesResult()
124+
async def list_resources(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListResourcesResult:
125+
return types.ListResourcesResult(resources=[])
99126

100-
async def list_resource_templates(self, *, params=None):
101-
return types.ListResourceTemplatesResult()
127+
async def list_resource_templates(
128+
self, *, params: types.PaginatedRequestParams | None = None
129+
) -> types.ListResourceTemplatesResult:
130+
return types.ListResourceTemplatesResult(resource_templates=[])
102131

103-
async def read_resource(self, uri, *, meta=None):
104-
return types.ReadResourceResult()
132+
async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.ReadResourceResult:
133+
return types.ReadResourceResult(contents=[])
105134

106-
async def subscribe_resource(self, uri, *, meta=None):
135+
async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
107136
return types.EmptyResult()
108137

109-
async def unsubscribe_resource(self, uri, *, meta=None):
138+
async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
110139
return types.EmptyResult()
111140

112-
async def call_tool(self, name, arguments=None, read_timeout_seconds=None, progress_callback=None, *, meta=None):
113-
return types.CallToolResult()
114-
115-
async def list_prompts(self, *, params=None):
116-
return types.ListPromptsResult()
117-
118-
async def get_prompt(self, name, arguments=None, *, meta=None):
119-
return types.GetPromptResult()
120-
121-
async def list_tools(self, *, params=None):
122-
return types.ListToolsResult()
123-
124-
async def complete(self, ref, argument, context_arguments=None):
125-
return types.CompleteResult()
126-
127-
async def set_logging_level(self, level, *, meta=None):
141+
async def call_tool(
142+
self,
143+
name: str,
144+
arguments: dict[str, Any] | None = None,
145+
read_timeout_seconds: float | None = None,
146+
progress_callback: ProgressFnT | None = None,
147+
*,
148+
meta: RequestParamsMeta | None = None,
149+
) -> types.CallToolResult:
150+
return types.CallToolResult(content=[])
151+
152+
async def list_prompts(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListPromptsResult:
153+
return types.ListPromptsResult(prompts=[])
154+
155+
async def get_prompt(
156+
self,
157+
name: str,
158+
arguments: dict[str, str] | None = None,
159+
*,
160+
meta: RequestParamsMeta | None = None,
161+
) -> types.GetPromptResult:
162+
return types.GetPromptResult(messages=[])
163+
164+
async def list_tools(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListToolsResult:
165+
return types.ListToolsResult(tools=[])
166+
167+
async def complete(
168+
self,
169+
ref: types.ResourceTemplateReference | types.PromptReference,
170+
argument: dict[str, str],
171+
context_arguments: dict[str, str] | None = None,
172+
) -> types.CompleteResult:
173+
return types.CompleteResult(completion=types.Completion(values=[]))
174+
175+
async def set_logging_level(
176+
self,
177+
level: types.LoggingLevel,
178+
*,
179+
meta: RequestParamsMeta | None = None,
180+
) -> types.EmptyResult:
128181
return types.EmptyResult()
129182

130-
async def send_roots_list_changed(self):
183+
async def send_roots_list_changed(self) -> None:
131184
pass
132185

133186

0 commit comments

Comments
 (0)