|
13 | 13 | from pydantic import AnyHttpUrl, AnyUrl |
14 | 14 |
|
15 | 15 | from mcp.client.auth import OAuthClientProvider, PKCEParameters |
| 16 | +from mcp.client.auth.exceptions import OAuthFlowError |
16 | 17 | from mcp.client.auth.utils import ( |
17 | 18 | build_oauth_authorization_server_metadata_discovery_urls, |
18 | 19 | build_protected_resource_metadata_discovery_urls, |
@@ -965,7 +966,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide |
965 | 966 | # Send a successful discovery response with minimal protected resource metadata |
966 | 967 | discovery_response = httpx.Response( |
967 | 968 | 200, |
968 | | - content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', |
| 969 | + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', |
969 | 970 | request=discovery_request, |
970 | 971 | ) |
971 | 972 |
|
@@ -2030,3 +2031,105 @@ async def callback_handler() -> tuple[str, str | None]: |
2030 | 2031 | await auth_flow.asend(final_response) |
2031 | 2032 | except StopAsyncIteration: |
2032 | 2033 | pass |
| 2034 | + |
| 2035 | + |
| 2036 | +@pytest.mark.anyio |
| 2037 | +async def test_validate_resource_rejects_mismatched_resource( |
| 2038 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2039 | +) -> None: |
| 2040 | + """Client must reject PRM resource that doesn't match server URL.""" |
| 2041 | + provider = OAuthClientProvider( |
| 2042 | + server_url="https://api.example.com/v1/mcp", |
| 2043 | + client_metadata=client_metadata, |
| 2044 | + storage=mock_storage, |
| 2045 | + ) |
| 2046 | + provider._initialized = True |
| 2047 | + |
| 2048 | + prm = ProtectedResourceMetadata( |
| 2049 | + resource=AnyHttpUrl("https://evil.example.com/mcp"), |
| 2050 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2051 | + ) |
| 2052 | + with pytest.raises(OAuthFlowError, match="does not match expected"): |
| 2053 | + await provider._validate_resource_match(prm) |
| 2054 | + |
| 2055 | + |
| 2056 | +@pytest.mark.anyio |
| 2057 | +async def test_validate_resource_accepts_matching_resource( |
| 2058 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2059 | +) -> None: |
| 2060 | + """Client must accept PRM resource that matches server URL.""" |
| 2061 | + provider = OAuthClientProvider( |
| 2062 | + server_url="https://api.example.com/v1/mcp", |
| 2063 | + client_metadata=client_metadata, |
| 2064 | + storage=mock_storage, |
| 2065 | + ) |
| 2066 | + provider._initialized = True |
| 2067 | + |
| 2068 | + prm = ProtectedResourceMetadata( |
| 2069 | + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), |
| 2070 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2071 | + ) |
| 2072 | + # Should not raise |
| 2073 | + await provider._validate_resource_match(prm) |
| 2074 | + |
| 2075 | + |
| 2076 | +@pytest.mark.anyio |
| 2077 | +async def test_validate_resource_accepts_root_url_with_trailing_slash( |
| 2078 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2079 | +) -> None: |
| 2080 | + """Root URLs with trailing slash normalization should match.""" |
| 2081 | + provider = OAuthClientProvider( |
| 2082 | + server_url="https://api.example.com", |
| 2083 | + client_metadata=client_metadata, |
| 2084 | + storage=mock_storage, |
| 2085 | + ) |
| 2086 | + provider._initialized = True |
| 2087 | + |
| 2088 | + prm = ProtectedResourceMetadata( |
| 2089 | + resource=AnyHttpUrl("https://api.example.com/"), |
| 2090 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2091 | + ) |
| 2092 | + # Should not raise - both normalize to the same URL with trailing slash |
| 2093 | + await provider._validate_resource_match(prm) |
| 2094 | + |
| 2095 | + |
| 2096 | +@pytest.mark.anyio |
| 2097 | +async def test_validate_resource_match_when_resource_url_has_trailing_slash( |
| 2098 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2099 | +) -> None: |
| 2100 | + """Validation works when resource_url_from_server_url already returns a trailing slash.""" |
| 2101 | + provider = OAuthClientProvider( |
| 2102 | + server_url="https://api.example.com/", |
| 2103 | + client_metadata=client_metadata, |
| 2104 | + storage=mock_storage, |
| 2105 | + ) |
| 2106 | + provider._initialized = True |
| 2107 | + |
| 2108 | + prm = ProtectedResourceMetadata( |
| 2109 | + resource=AnyHttpUrl("https://api.example.com/"), |
| 2110 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2111 | + ) |
| 2112 | + # Should not raise - default_resource already ends with / |
| 2113 | + await provider._validate_resource_match(prm) |
| 2114 | + |
| 2115 | + |
| 2116 | +@pytest.mark.anyio |
| 2117 | +async def test_get_resource_url_falls_back_when_prm_mismatches( |
| 2118 | + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 2119 | +) -> None: |
| 2120 | + """get_resource_url returns canonical URL when PRM resource doesn't match.""" |
| 2121 | + provider = OAuthClientProvider( |
| 2122 | + server_url="https://api.example.com/v1/mcp", |
| 2123 | + client_metadata=client_metadata, |
| 2124 | + storage=mock_storage, |
| 2125 | + ) |
| 2126 | + provider._initialized = True |
| 2127 | + |
| 2128 | + # Set PRM with a resource that is NOT a parent of the server URL |
| 2129 | + provider.context.protected_resource_metadata = ProtectedResourceMetadata( |
| 2130 | + resource=AnyHttpUrl("https://other.example.com/mcp"), |
| 2131 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 2132 | + ) |
| 2133 | + |
| 2134 | + # get_resource_url should return the canonical server URL, not the PRM resource |
| 2135 | + assert provider.context.get_resource_url() == "https://api.example.com/v1/mcp" |
0 commit comments