Skip to content

Commit d0debb6

Browse files
Theodor N. EngøyTheodor N. Engøy
authored andcommitted
auth: restrict CORS to loopback by default
1 parent dda845a commit d0debb6

File tree

5 files changed

+60
-1
lines changed

5 files changed

+60
-1
lines changed

src/mcp/server/auth/routes.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
2121
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
2222

23+
DEFAULT_AUTH_CORS_ORIGIN_REGEX = r"^https?://(localhost|127\.0\.0\.1|\[::1\])(?::\d+)?$"
24+
2325

2426
def validate_issuer_url(url: AnyHttpUrl):
2527
"""Validate that the issuer URL meets OAuth 2.0 requirements.
@@ -55,10 +57,17 @@ def validate_issuer_url(url: AnyHttpUrl):
5557
def cors_middleware(
5658
handler: Callable[[Request], Response | Awaitable[Response]],
5759
allow_methods: list[str],
60+
*,
61+
allow_origin_regex: str | None = None,
5862
) -> ASGIApp:
63+
# Default: allow loopback browser clients (e.g., MCP Inspector) without allowing arbitrary sites.
64+
if allow_origin_regex is None:
65+
allow_origin_regex = DEFAULT_AUTH_CORS_ORIGIN_REGEX
66+
5967
cors_app = CORSMiddleware(
6068
app=request_response(handler),
61-
allow_origins="*",
69+
allow_origins=[],
70+
allow_origin_regex=allow_origin_regex,
6271
allow_methods=allow_methods,
6372
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
6473
)
@@ -71,6 +80,7 @@ def create_auth_routes(
7180
service_documentation_url: AnyHttpUrl | None = None,
7281
client_registration_options: ClientRegistrationOptions | None = None,
7382
revocation_options: RevocationOptions | None = None,
83+
cors_origin_regex: str | None = None,
7484
) -> list[Route]:
7585
validate_issuer_url(issuer_url)
7686

@@ -94,6 +104,7 @@ def create_auth_routes(
94104
endpoint=cors_middleware(
95105
MetadataHandler(metadata).handle,
96106
["GET", "OPTIONS"],
107+
allow_origin_regex=cors_origin_regex,
97108
),
98109
methods=["GET", "OPTIONS"],
99110
),
@@ -109,6 +120,7 @@ def create_auth_routes(
109120
endpoint=cors_middleware(
110121
TokenHandler(provider, client_authenticator).handle,
111122
["POST", "OPTIONS"],
123+
allow_origin_regex=cors_origin_regex,
112124
),
113125
methods=["POST", "OPTIONS"],
114126
),
@@ -125,6 +137,7 @@ def create_auth_routes(
125137
endpoint=cors_middleware(
126138
registration_handler.handle,
127139
["POST", "OPTIONS"],
140+
allow_origin_regex=cors_origin_regex,
128141
),
129142
methods=["POST", "OPTIONS"],
130143
)
@@ -138,6 +151,7 @@ def create_auth_routes(
138151
endpoint=cors_middleware(
139152
revocation_handler.handle,
140153
["POST", "OPTIONS"],
154+
allow_origin_regex=cors_origin_regex,
141155
),
142156
methods=["POST", "OPTIONS"],
143157
)

src/mcp/server/auth/settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ class AuthSettings(BaseModel):
2121
client_registration_options: ClientRegistrationOptions | None = None
2222
revocation_options: RevocationOptions | None = None
2323
required_scopes: list[str] | None = None
24+
cors_origin_regex: str | None = Field(
25+
default=None,
26+
description=(
27+
"Regex for allowed browser Origin values on the authorization server endpoints "
28+
"(/token, /register, /.well-known/oauth-authorization-server, etc). "
29+
"If unset, a safe default allows only loopback origins (localhost/127.0.0.1/[::1])."
30+
),
31+
)
2432

2533
# Resource Server settings (when operating as RS only)
2634
resource_server_url: AnyHttpUrl | None = Field(

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,7 @@ def streamable_http_app(
867867
service_documentation_url=auth.service_documentation_url,
868868
client_registration_options=auth.client_registration_options,
869869
revocation_options=auth.revocation_options,
870+
cors_origin_regex=auth.cors_origin_regex,
870871
)
871872
)
872873

src/mcp/server/mcpserver/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no
840840
service_documentation_url=self.settings.auth.service_documentation_url,
841841
client_registration_options=self.settings.auth.client_registration_options,
842842
revocation_options=self.settings.auth.revocation_options,
843+
cors_origin_regex=self.settings.auth.cors_origin_regex,
843844
)
844845
)
845846

tests/server/mcpserver/auth/test_auth_integration.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
construct_redirect_uri,
2323
)
2424
from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes
25+
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
2526
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
2627

2728

@@ -325,6 +326,40 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
325326
]
326327
assert metadata["service_documentation"] == "https://docs.example.com/"
327328

329+
@pytest.mark.anyio
330+
async def test_cors_allows_loopback_origin_by_default(self, test_client: httpx.AsyncClient):
331+
origin = "http://localhost:5173"
332+
response = await test_client.get(
333+
"/.well-known/oauth-authorization-server",
334+
headers={"Origin": origin},
335+
)
336+
assert response.status_code == 200
337+
assert response.headers.get("access-control-allow-origin") == origin
338+
339+
@pytest.mark.anyio
340+
async def test_cors_blocks_non_loopback_origin_by_default(self, test_client: httpx.AsyncClient):
341+
origin = "https://evil.example"
342+
response = await test_client.get(
343+
"/.well-known/oauth-authorization-server",
344+
headers={"Origin": origin},
345+
)
346+
assert response.status_code == 200
347+
assert "access-control-allow-origin" not in response.headers
348+
349+
@pytest.mark.anyio
350+
async def test_cors_preflight_allows_loopback_origin_by_default(self, test_client: httpx.AsyncClient):
351+
origin = "http://127.0.0.1:3000"
352+
response = await test_client.options(
353+
"/token",
354+
headers={
355+
"Origin": origin,
356+
"Access-Control-Request-Method": "POST",
357+
"Access-Control-Request-Headers": MCP_PROTOCOL_VERSION_HEADER,
358+
},
359+
)
360+
assert response.status_code == 200
361+
assert response.headers.get("access-control-allow-origin") == origin
362+
328363
@pytest.mark.anyio
329364
async def test_token_validation_error(self, test_client: httpx.AsyncClient):
330365
"""Test token endpoint error - validation error."""

0 commit comments

Comments
 (0)