Skip to content

Commit 3c0aaff

Browse files
committed
fix(auth): OAuthClientProvider async_auth_flow lock 버그 수정
async with self.context.lock: 블록이 모든 yield 문을 감싸고 있어 generator suspend/resume 시 RuntimeError 발생. 해결: 각 yield 지점 전후로 lock acquire/release. 기존 기능 유지, 테스트 통과 (166 passed).
1 parent 239d682 commit 3c0aaff

File tree

1 file changed

+96
-56
lines changed

1 file changed

+96
-56
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 96 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,14 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
503503
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
504504

505505
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
506-
"""HTTPX auth flow integration."""
506+
"""HTTPX auth flow integration.
507+
508+
Note: We acquire/release the lock around each yield point to avoid
509+
holding the lock across generator suspensions, which can cause
510+
"current task is not holding this lock" errors when the generator
511+
is resumed in a different task context.
512+
"""
513+
# Phase 1: Initialize and check token validity (with lock)
507514
async with self.context.lock:
508515
if not self._initialized:
509516
await self._initialize() # pragma: no cover
@@ -514,33 +521,46 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
514521
if not self.context.is_token_valid() and self.context.can_refresh_token():
515522
# Try to refresh token
516523
refresh_request = await self._refresh_token() # pragma: no cover
517-
refresh_response = yield refresh_request # pragma: no cover
524+
else:
525+
refresh_request = None # pragma: no cover
518526

519-
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
527+
if self.context.is_token_valid():
528+
self._add_auth_header(request)
529+
530+
# Phase 2: Refresh token if needed (yield WITHOUT lock held)
531+
if refresh_request is not None: # pragma: no cover
532+
refresh_response = yield refresh_request
533+
534+
async with self.context.lock:
535+
if not await self._handle_refresh_response(refresh_response):
520536
# Refresh failed, need full re-authentication
521537
self._initialized = False
522538

523-
if self.context.is_token_valid():
524-
self._add_auth_header(request)
539+
async with self.context.lock:
540+
if self.context.is_token_valid():
541+
self._add_auth_header(request)
525542

526543
response = yield request
544+
else:
545+
response = yield request
527546

528-
if response.status_code == 401:
529-
# Perform full OAuth flow
530-
try:
531-
# OAuth flow must be inline due to generator constraints
532-
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
547+
if response.status_code == 401:
548+
# Perform full OAuth flow (each step releases lock around yield)
549+
try:
550+
# OAuth flow must be inline due to generator constraints
551+
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
533552

534-
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
535-
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
536-
www_auth_resource_metadata_url, self.context.server_url
537-
)
553+
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
554+
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
555+
www_auth_resource_metadata_url, self.context.server_url
556+
)
538557

539-
for url in prm_discovery_urls: # pragma: no branch
540-
discovery_request = create_oauth_metadata_request(url)
558+
for url in prm_discovery_urls: # pragma: no branch
559+
discovery_request = create_oauth_metadata_request(url)
541560

542-
discovery_response = yield discovery_request # sending request
561+
discovery_response = yield discovery_request # sending request
543562

563+
async with self.context.lock:
544564
prm = await handle_protected_resource_response(discovery_response)
545565
if prm:
546566
# Validate PRM resource matches server URL (RFC 8707)
@@ -553,19 +573,22 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
553573
) # this is always true as authorization_servers has a min length of 1
554574

555575
self.context.auth_server_url = str(prm.authorization_servers[0])
556-
break
557-
else:
558-
logger.debug(f"Protected resource metadata discovery failed: {url}")
576+
if prm:
577+
break
578+
else:
579+
logger.debug(f"Protected resource metadata discovery failed: {url}")
559580

581+
async with self.context.lock:
560582
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
561583
self.context.auth_server_url, self.context.server_url
562584
)
563585

564-
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
565-
for url in asm_discovery_urls: # pragma: no branch
566-
oauth_metadata_request = create_oauth_metadata_request(url)
567-
oauth_metadata_response = yield oauth_metadata_request
586+
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
587+
for url in asm_discovery_urls: # pragma: no branch
588+
oauth_metadata_request = create_oauth_metadata_request(url)
589+
oauth_metadata_response = yield oauth_metadata_request
568590

591+
async with self.context.lock:
569592
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
570593
if not ok:
571594
break
@@ -575,67 +598,84 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
575598
else:
576599
logger.debug(f"OAuth metadata discovery failed: {url}")
577600

578-
# Step 3: Apply scope selection strategy
601+
# Step 3: Apply scope selection strategy
602+
async with self.context.lock:
579603
self.context.client_metadata.scope = get_client_metadata_scopes(
580604
extract_scope_from_www_auth(response),
581605
self.context.protected_resource_metadata,
582606
self.context.oauth_metadata,
583607
)
584608

585-
# Step 4: Register client or use URL-based client ID (CIMD)
586-
if not self.context.client_info:
587-
if should_use_client_metadata_url(
588-
self.context.oauth_metadata, self.context.client_metadata_url
589-
):
590-
# Use URL-based client ID (CIMD)
609+
# Step 4: Register client or use URL-based client ID (CIMD)
610+
async with self.context.lock:
611+
need_registration = not self.context.client_info
612+
use_cimd = need_registration and should_use_client_metadata_url(
613+
self.context.oauth_metadata, self.context.client_metadata_url
614+
)
615+
616+
if need_registration:
617+
if use_cimd:
618+
# Use URL-based client ID (CIMD)
619+
async with self.context.lock:
591620
logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
592621
client_information = create_client_info_from_metadata_url(
593622
self.context.client_metadata_url, # type: ignore[arg-type]
594623
redirect_uris=self.context.client_metadata.redirect_uris,
595624
)
596625
self.context.client_info = client_information
597626
await self.context.storage.set_client_info(client_information)
598-
else:
599-
# Fallback to Dynamic Client Registration
627+
else:
628+
# Fallback to Dynamic Client Registration
629+
async with self.context.lock:
600630
registration_request = create_client_registration_request(
601631
self.context.oauth_metadata,
602632
self.context.client_metadata,
603633
self.context.get_authorization_base_url(self.context.server_url),
604634
)
605-
registration_response = yield registration_request
635+
registration_response = yield registration_request
636+
async with self.context.lock:
606637
client_information = await handle_registration_response(registration_response)
607638
self.context.client_info = client_information
608639
await self.context.storage.set_client_info(client_information)
609640

610-
# Step 5: Perform authorization and complete token exchange
611-
token_response = yield await self._perform_authorization()
641+
# Step 5: Perform authorization and complete token exchange
642+
async with self.context.lock:
643+
authorization_request = await self._perform_authorization()
644+
token_response = yield authorization_request
645+
async with self.context.lock:
612646
await self._handle_token_response(token_response)
613-
except Exception: # pragma: no cover
614-
logger.exception("OAuth flow error")
615-
raise
647+
except Exception: # pragma: no cover
648+
logger.exception("OAuth flow error")
649+
raise
616650

617-
# Retry with new tokens
651+
# Retry with new tokens
652+
async with self.context.lock:
618653
self._add_auth_header(request)
619-
yield request
620-
elif response.status_code == 403:
621-
# Step 1: Extract error field from WWW-Authenticate header
622-
error = extract_field_from_www_auth(response, "error")
623-
624-
# Step 2: Check if we need to step-up authorization
625-
if error == "insufficient_scope": # pragma: no branch
626-
try:
627-
# Step 2a: Update the required scopes
654+
yield request
655+
elif response.status_code == 403:
656+
# Step 1: Extract error field from WWW-Authenticate header
657+
error = extract_field_from_www_auth(response, "error")
658+
659+
# Step 2: Check if we need to step-up authorization
660+
if error == "insufficient_scope": # pragma: no branch
661+
try:
662+
# Step 2a: Update the required scopes
663+
async with self.context.lock:
628664
self.context.client_metadata.scope = get_client_metadata_scopes(
629665
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
630666
)
631667

632-
# Step 2b: Perform (re-)authorization and token exchange
633-
token_response = yield await self._perform_authorization()
668+
# Step 2b: Perform (re-)authorization and token exchange
669+
async with self.context.lock:
670+
authorization_request = await self._perform_authorization()
671+
token_response = yield authorization_request
672+
async with self.context.lock:
634673
await self._handle_token_response(token_response)
635-
except Exception: # pragma: no cover
636-
logger.exception("OAuth flow error")
637-
raise
674+
except Exception: # pragma: no cover
675+
logger.exception("OAuth flow error")
676+
raise
638677

639-
# Retry with new tokens
678+
# Retry with new tokens
679+
async with self.context.lock:
640680
self._add_auth_header(request)
641-
yield request
681+
yield request

0 commit comments

Comments
 (0)