Skip to content

Commit e4ec882

Browse files
committed
chore: create private modules, and drop unnecessary modules
1 parent b43491a commit e4ec882

File tree

15 files changed

+116
-149
lines changed

15 files changed

+116
-149
lines changed

examples/servers/simple-auth/mcp_simple_auth/token_verifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
from mcp.server.auth.provider import AccessToken, TokenVerifier
7-
from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url
7+
from mcp.shared._auth_utils import check_resource_allowed, resource_url_from_server_url
88

99
logger = logging.getLogger(__name__)
1010

src/mcp/client/auth/oauth2.py

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
Implements authorization code flow with PKCE and automatic token refresh.
55
"""
66

7+
from __future__ import annotations as _annotations
8+
79
import base64
810
import hashlib
911
import logging
@@ -13,11 +15,11 @@
1315
from collections.abc import AsyncGenerator, Awaitable, Callable
1416
from dataclasses import dataclass, field
1517
from typing import Any, Protocol
16-
from urllib.parse import quote, urlencode, urljoin, urlparse
18+
from urllib.parse import quote, urlencode, urljoin, urlparse, urlsplit, urlunsplit
1719

1820
import anyio
1921
import httpx
20-
from pydantic import BaseModel, Field, ValidationError
22+
from pydantic import AnyUrl, BaseModel, Field, HttpUrl, ValidationError
2123

2224
from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
2325
from mcp.client.auth.utils import (
@@ -45,11 +47,6 @@
4547
OAuthToken,
4648
ProtectedResourceMetadata,
4749
)
48-
from mcp.shared.auth_utils import (
49-
calculate_token_expiry,
50-
check_resource_allowed,
51-
resource_url_from_server_url,
52-
)
5350

5451
logger = logging.getLogger(__name__)
5552

@@ -61,7 +58,7 @@ class PKCEParameters(BaseModel):
6158
code_challenge: str = Field(..., min_length=43, max_length=128)
6259

6360
@classmethod
64-
def generate(cls) -> "PKCEParameters":
61+
def generate(cls) -> PKCEParameters:
6562
"""Generate new PKCE parameters."""
6663
code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
6764
digest = hashlib.sha256(code_verifier.encode()).digest()
@@ -74,19 +71,15 @@ class TokenStorage(Protocol):
7471

7572
async def get_tokens(self) -> OAuthToken | None:
7673
"""Get stored tokens."""
77-
...
7874

7975
async def set_tokens(self, tokens: OAuthToken) -> None:
8076
"""Store tokens."""
81-
...
8277

8378
async def get_client_info(self) -> OAuthClientInformationFull | None:
8479
"""Get stored client information."""
85-
...
8680

8781
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
8882
"""Store client information."""
89-
...
9083

9184

9285
@dataclass
@@ -124,7 +117,7 @@ def get_authorization_base_url(self, server_url: str) -> str:
124117

125118
def update_token_expiry(self, token: OAuthToken) -> None:
126119
"""Update token expiry time using shared util function."""
127-
self.token_expiry_time = calculate_token_expiry(token.expires_in)
120+
self.token_expiry_time = _calculate_token_expiry(token.expires_in)
128121

129122
def is_token_valid(self) -> bool:
130123
"""Check if current token is valid."""
@@ -148,12 +141,12 @@ def get_resource_url(self) -> str:
148141
149142
Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
150143
"""
151-
resource = resource_url_from_server_url(self.server_url)
144+
resource = _resource_url_from_server_url(self.server_url)
152145

153146
# If PRM provides a resource that's a valid parent, use it
154147
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
155148
prm_resource = str(self.protected_resource_metadata.resource)
156-
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
149+
if _check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
157150
resource = prm_resource
158151

159152
return resource
@@ -614,3 +607,82 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
614607
# Retry with new tokens
615608
self._add_auth_header(request)
616609
yield request
610+
611+
612+
def _resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str:
613+
"""Convert server URL to canonical resource URL per RFC 8707.
614+
615+
RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component".
616+
Returns absolute URI with lowercase scheme/host for canonical form.
617+
618+
Args:
619+
url: Server URL to convert
620+
621+
Returns:
622+
Canonical resource URL string
623+
"""
624+
# Convert to string if needed
625+
url_str = str(url)
626+
627+
# Parse the URL and remove fragment, create canonical form
628+
parsed = urlsplit(url_str)
629+
canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment=""))
630+
631+
return canonical
632+
633+
634+
def _check_resource_allowed(requested_resource: str, configured_resource: str) -> bool:
635+
"""Check if a requested resource URL matches a configured resource URL.
636+
637+
A requested resource matches if it has the same scheme, domain, port,
638+
and its path starts with the configured resource's path. This allows
639+
hierarchical matching where a token for a parent resource can be used
640+
for child resources.
641+
642+
Args:
643+
requested_resource: The resource URL being requested
644+
configured_resource: The resource URL that has been configured
645+
646+
Returns:
647+
True if the requested resource matches the configured resource
648+
"""
649+
# Parse both URLs
650+
requested = urlparse(requested_resource)
651+
configured = urlparse(configured_resource)
652+
653+
# Compare scheme, host, and port (origin)
654+
if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower():
655+
return False
656+
657+
# Handle cases like requested=/foo and configured=/foo/
658+
requested_path = requested.path
659+
configured_path = configured.path
660+
661+
# If requested path is shorter, it cannot be a child
662+
if len(requested_path) < len(configured_path):
663+
return False
664+
665+
# Check if the requested path starts with the configured path
666+
# Ensure both paths end with / for proper comparison
667+
# This ensures that paths like "/api123" don't incorrectly match "/api"
668+
if not requested_path.endswith("/"):
669+
requested_path += "/"
670+
if not configured_path.endswith("/"):
671+
configured_path += "/"
672+
673+
return requested_path.startswith(configured_path)
674+
675+
676+
def _calculate_token_expiry(expires_in: int | str | None) -> float | None:
677+
"""Calculate token expiry timestamp from expires_in seconds.
678+
679+
Args:
680+
expires_in: Seconds until token expiration (may be string from some servers)
681+
682+
Returns:
683+
Unix timestamp when token expires, or None if no expiry specified
684+
"""
685+
if expires_in is None:
686+
return None # pragma: no cover
687+
# Defensive: handle servers that return expires_in as string
688+
return time.time() + int(expires_in)

src/mcp/client/session.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations as _annotations
2+
13
import logging
24
from typing import Any, Protocol, overload
35

@@ -22,22 +24,22 @@
2224
class SamplingFnT(Protocol):
2325
async def __call__(
2426
self,
25-
context: RequestContext["ClientSession", Any],
27+
context: RequestContext[ClientSession, Any],
2628
params: types.CreateMessageRequestParams,
2729
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
2830

2931

3032
class ElicitationFnT(Protocol):
3133
async def __call__(
3234
self,
33-
context: RequestContext["ClientSession", Any],
35+
context: RequestContext[ClientSession, Any],
3436
params: types.ElicitRequestParams,
3537
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
3638

3739

3840
class ListRootsFnT(Protocol):
3941
async def __call__(
40-
self, context: RequestContext["ClientSession", Any]
42+
self, context: RequestContext[ClientSession, Any]
4143
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
4244

4345

@@ -62,7 +64,7 @@ async def _default_message_handler(
6264

6365

6466
async def _default_sampling_callback(
65-
context: RequestContext["ClientSession", Any],
67+
context: RequestContext[ClientSession, Any],
6668
params: types.CreateMessageRequestParams,
6769
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
6870
return types.ErrorData(
@@ -72,7 +74,7 @@ async def _default_sampling_callback(
7274

7375

7476
async def _default_elicitation_callback(
75-
context: RequestContext["ClientSession", Any],
77+
context: RequestContext[ClientSession, Any],
7678
params: types.ElicitRequestParams,
7779
) -> types.ElicitResult | types.ErrorData:
7880
return types.ErrorData( # pragma: no cover
@@ -82,7 +84,7 @@ async def _default_elicitation_callback(
8284

8385

8486
async def _default_list_roots_callback(
85-
context: RequestContext["ClientSession", Any],
87+
context: RequestContext[ClientSession, Any],
8688
) -> types.ListRootsResult | types.ErrorData:
8789
return types.ErrorData(
8890
code=types.INVALID_REQUEST,

src/mcp/client/session_group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
class SseServerParameters(BaseModel):
34-
"""Parameters for intializing a sse_client."""
34+
"""Parameters for initializing a sse_client."""
3535

3636
# The endpoint URL.
3737
url: str

src/mcp/client/websocket.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from __future__ import annotations as _annotations
2+
13
import json
2-
import logging
34
from collections.abc import AsyncGenerator
45
from contextlib import asynccontextmanager
56

@@ -12,8 +13,6 @@
1213
import mcp.types as types
1314
from mcp.shared.message import SessionMessage
1415

15-
logger = logging.getLogger(__name__)
16-
1716

1817
@asynccontextmanager
1918
async def websocket_client(
@@ -64,10 +63,7 @@ async def ws_reader():
6463
await read_stream_writer.send(exc)
6564

6665
async def ws_writer():
67-
"""
68-
Reads JSON-RPC messages from write_stream_reader and
69-
sends them to the server.
70-
"""
66+
"""Reads JSON-RPC messages from write_stream_reader and sends them to the server."""
7167
async with write_stream_reader:
7268
async for session_message in write_stream_reader:
7369
# Convert to a dict, then to JSON

src/mcp/server/fastmcp/tools/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
1313
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
1414
from mcp.shared.exceptions import UrlElicitationRequiredError
15-
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
15+
from mcp.shared._tool_name_validation import validate_and_warn_tool_name
1616
from mcp.types import Icon, ToolAnnotations
1717

1818
if TYPE_CHECKING:

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ async def main():
8989
from mcp.server.lowlevel.helper_types import ReadResourceContents
9090
from mcp.server.models import InitializationOptions
9191
from mcp.server.session import ServerSession
92+
from mcp.shared._tool_name_validation import validate_and_warn_tool_name
9293
from mcp.shared.context import RequestContext
9394
from mcp.shared.exceptions import McpError, UrlElicitationRequiredError
9495
from mcp.shared.message import ServerMessageMetadata, SessionMessage
9596
from mcp.shared.session import RequestResponder
96-
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
9797

9898
logger = logging.getLogger(__name__)
9999

src/mcp/server/models.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55

66
from pydantic import BaseModel
77

8-
from mcp.types import (
9-
Icon,
10-
ServerCapabilities,
11-
)
8+
from mcp.types import Icon, ServerCapabilities
129

1310

1411
class InitializationOptions(BaseModel):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names
1010
"""
1111

12-
from __future__ import annotations
12+
from __future__ import annotations as _annotations
1313

1414
import logging
1515
import re

0 commit comments

Comments
 (0)