|
13 | 13 | from pydantic import AnyHttpUrl, AnyUrl |
14 | 14 |
|
15 | 15 | from mcp.client.auth import OAuthClientProvider, PKCEParameters |
| 16 | +from mcp.client.auth.exceptions import OAuthFlowError |
16 | 17 | from mcp.client.auth.utils import ( |
17 | 18 | build_oauth_authorization_server_metadata_discovery_urls, |
18 | 19 | build_protected_resource_metadata_discovery_urls, |
@@ -965,7 +966,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide |
965 | 966 | # Send a successful discovery response with minimal protected resource metadata |
966 | 967 | discovery_response = httpx.Response( |
967 | 968 | 200, |
968 | | - content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', |
| 969 | + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', |
969 | 970 | request=discovery_request, |
970 | 971 | ) |
971 | 972 |
|
@@ -2030,3 +2031,63 @@ async def callback_handler() -> tuple[str, str | None]: |
2030 | 2031 | await auth_flow.asend(final_response) |
2031 | 2032 | except StopAsyncIteration: |
2032 | 2033 | pass |
| 2034 | + |
| 2035 | + |
| 2036 | +@pytest.mark.anyio |
| 2037 | +async def test_validate_resource_rejects_mismatched_resource( |
| 2038 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2039 | +) -> None: |
| 2040 | + """Client must reject PRM resource that doesn't match server URL.""" |
| 2041 | + provider = OAuthClientProvider( |
| 2042 | + server_url="https://api.example.com/v1/mcp", |
| 2043 | + client_metadata=client_metadata, |
| 2044 | + storage=mock_storage, |
| 2045 | + ) |
| 2046 | + provider._initialized = True |
| 2047 | + |
| 2048 | + prm = ProtectedResourceMetadata( |
| 2049 | + resource=AnyHttpUrl("https://evil.example.com/mcp"), |
| 2050 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2051 | + ) |
| 2052 | + with pytest.raises(OAuthFlowError, match="does not match expected"): |
| 2053 | + await provider._validate_resource_match(prm) |
| 2054 | + |
| 2055 | + |
| 2056 | +@pytest.mark.anyio |
| 2057 | +async def test_validate_resource_accepts_matching_resource( |
| 2058 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2059 | +) -> None: |
| 2060 | + """Client must accept PRM resource that matches server URL.""" |
| 2061 | + provider = OAuthClientProvider( |
| 2062 | + server_url="https://api.example.com/v1/mcp", |
| 2063 | + client_metadata=client_metadata, |
| 2064 | + storage=mock_storage, |
| 2065 | + ) |
| 2066 | + provider._initialized = True |
| 2067 | + |
| 2068 | + prm = ProtectedResourceMetadata( |
| 2069 | + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), |
| 2070 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2071 | + ) |
| 2072 | + # Should not raise |
| 2073 | + await provider._validate_resource_match(prm) |
| 2074 | + |
| 2075 | + |
| 2076 | +@pytest.mark.anyio |
| 2077 | +async def test_validate_resource_accepts_root_url_with_trailing_slash( |
| 2078 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2079 | +) -> None: |
| 2080 | + """Root URLs with trailing slash normalization should match.""" |
| 2081 | + provider = OAuthClientProvider( |
| 2082 | + server_url="https://api.example.com", |
| 2083 | + client_metadata=client_metadata, |
| 2084 | + storage=mock_storage, |
| 2085 | + ) |
| 2086 | + provider._initialized = True |
| 2087 | + |
| 2088 | + prm = ProtectedResourceMetadata( |
| 2089 | + resource=AnyHttpUrl("https://api.example.com/"), |
| 2090 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2091 | + ) |
| 2092 | + # Should not raise - both normalize to the same URL with trailing slash |
| 2093 | + await provider._validate_resource_match(prm) |
0 commit comments