@@ -503,7 +503,14 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
503503 raise OAuthFlowError (f"Protected resource { prm_resource } does not match expected { default_resource } " )
504504
505505 async def async_auth_flow (self , request : httpx .Request ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
506- """HTTPX auth flow integration."""
506+ """HTTPX auth flow integration.
507+
508+ Note: We acquire/release the lock around each yield point to avoid
509+ holding the lock across generator suspensions, which can cause
510+ "current task is not holding this lock" errors when the generator
511+ is resumed in a different task context.
512+ """
513+ # Phase 1: Initialize and check token validity (with lock)
507514 async with self .context .lock :
508515 if not self ._initialized :
509516 await self ._initialize () # pragma: no cover
@@ -514,33 +521,46 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
514521 if not self .context .is_token_valid () and self .context .can_refresh_token ():
515522 # Try to refresh token
516523 refresh_request = await self ._refresh_token () # pragma: no cover
517- refresh_response = yield refresh_request # pragma: no cover
524+ else :
525+ refresh_request = None # pragma: no cover
518526
519- if not await self ._handle_refresh_response (refresh_response ): # pragma: no cover
527+ if self .context .is_token_valid ():
528+ self ._add_auth_header (request )
529+
530+ # Phase 2: Refresh token if needed (yield WITHOUT lock held)
531+ if refresh_request is not None : # pragma: no cover
532+ refresh_response = yield refresh_request
533+
534+ async with self .context .lock :
535+ if not await self ._handle_refresh_response (refresh_response ):
520536 # Refresh failed, need full re-authentication
521537 self ._initialized = False
522538
523- if self .context .is_token_valid ():
524- self ._add_auth_header (request )
539+ async with self .context .lock :
540+ if self .context .is_token_valid ():
541+ self ._add_auth_header (request )
525542
526543 response = yield request
544+ else :
545+ response = yield request
527546
528- if response .status_code == 401 :
529- # Perform full OAuth flow
530- try :
531- # OAuth flow must be inline due to generator constraints
532- www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth (response )
547+ if response .status_code == 401 :
548+ # Perform full OAuth flow (each step releases lock around yield)
549+ try :
550+ # OAuth flow must be inline due to generator constraints
551+ www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth (response )
533552
534- # Step 1: Discover protected resource metadata (SEP-985 with fallback support)
535- prm_discovery_urls = build_protected_resource_metadata_discovery_urls (
536- www_auth_resource_metadata_url , self .context .server_url
537- )
553+ # Step 1: Discover protected resource metadata (SEP-985 with fallback support)
554+ prm_discovery_urls = build_protected_resource_metadata_discovery_urls (
555+ www_auth_resource_metadata_url , self .context .server_url
556+ )
538557
539- for url in prm_discovery_urls : # pragma: no branch
540- discovery_request = create_oauth_metadata_request (url )
558+ for url in prm_discovery_urls : # pragma: no branch
559+ discovery_request = create_oauth_metadata_request (url )
541560
542- discovery_response = yield discovery_request # sending request
561+ discovery_response = yield discovery_request # sending request
543562
563+ async with self .context .lock :
544564 prm = await handle_protected_resource_response (discovery_response )
545565 if prm :
546566 # Validate PRM resource matches server URL (RFC 8707)
@@ -553,19 +573,22 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
553573 ) # this is always true as authorization_servers has a min length of 1
554574
555575 self .context .auth_server_url = str (prm .authorization_servers [0 ])
556- break
557- else :
558- logger .debug (f"Protected resource metadata discovery failed: { url } " )
576+ if prm :
577+ break
578+ else :
579+ logger .debug (f"Protected resource metadata discovery failed: { url } " )
559580
581+ async with self .context .lock :
560582 asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls (
561583 self .context .auth_server_url , self .context .server_url
562584 )
563585
564- # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
565- for url in asm_discovery_urls : # pragma: no branch
566- oauth_metadata_request = create_oauth_metadata_request (url )
567- oauth_metadata_response = yield oauth_metadata_request
586+ # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
587+ for url in asm_discovery_urls : # pragma: no branch
588+ oauth_metadata_request = create_oauth_metadata_request (url )
589+ oauth_metadata_response = yield oauth_metadata_request
568590
591+ async with self .context .lock :
569592 ok , asm = await handle_auth_metadata_response (oauth_metadata_response )
570593 if not ok :
571594 break
@@ -575,67 +598,84 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
575598 else :
576599 logger .debug (f"OAuth metadata discovery failed: { url } " )
577600
578- # Step 3: Apply scope selection strategy
601+ # Step 3: Apply scope selection strategy
602+ async with self .context .lock :
579603 self .context .client_metadata .scope = get_client_metadata_scopes (
580604 extract_scope_from_www_auth (response ),
581605 self .context .protected_resource_metadata ,
582606 self .context .oauth_metadata ,
583607 )
584608
585- # Step 4: Register client or use URL-based client ID (CIMD)
586- if not self .context .client_info :
587- if should_use_client_metadata_url (
588- self .context .oauth_metadata , self .context .client_metadata_url
589- ):
590- # Use URL-based client ID (CIMD)
609+ # Step 4: Register client or use URL-based client ID (CIMD)
610+ async with self .context .lock :
611+ need_registration = not self .context .client_info
612+ use_cimd = need_registration and should_use_client_metadata_url (
613+ self .context .oauth_metadata , self .context .client_metadata_url
614+ )
615+
616+ if need_registration :
617+ if use_cimd :
618+ # Use URL-based client ID (CIMD)
619+ async with self .context .lock :
591620 logger .debug (f"Using URL-based client ID (CIMD): { self .context .client_metadata_url } " )
592621 client_information = create_client_info_from_metadata_url (
593622 self .context .client_metadata_url , # type: ignore[arg-type]
594623 redirect_uris = self .context .client_metadata .redirect_uris ,
595624 )
596625 self .context .client_info = client_information
597626 await self .context .storage .set_client_info (client_information )
598- else :
599- # Fallback to Dynamic Client Registration
627+ else :
628+ # Fallback to Dynamic Client Registration
629+ async with self .context .lock :
600630 registration_request = create_client_registration_request (
601631 self .context .oauth_metadata ,
602632 self .context .client_metadata ,
603633 self .context .get_authorization_base_url (self .context .server_url ),
604634 )
605- registration_response = yield registration_request
635+ registration_response = yield registration_request
636+ async with self .context .lock :
606637 client_information = await handle_registration_response (registration_response )
607638 self .context .client_info = client_information
608639 await self .context .storage .set_client_info (client_information )
609640
610- # Step 5: Perform authorization and complete token exchange
611- token_response = yield await self ._perform_authorization ()
641+ # Step 5: Perform authorization and complete token exchange
642+ async with self .context .lock :
643+ authorization_request = await self ._perform_authorization ()
644+ token_response = yield authorization_request
645+ async with self .context .lock :
612646 await self ._handle_token_response (token_response )
613- except Exception : # pragma: no cover
614- logger .exception ("OAuth flow error" )
615- raise
647+ except Exception : # pragma: no cover
648+ logger .exception ("OAuth flow error" )
649+ raise
616650
617- # Retry with new tokens
651+ # Retry with new tokens
652+ async with self .context .lock :
618653 self ._add_auth_header (request )
619- yield request
620- elif response .status_code == 403 :
621- # Step 1: Extract error field from WWW-Authenticate header
622- error = extract_field_from_www_auth (response , "error" )
623-
624- # Step 2: Check if we need to step-up authorization
625- if error == "insufficient_scope" : # pragma: no branch
626- try :
627- # Step 2a: Update the required scopes
654+ yield request
655+ elif response .status_code == 403 :
656+ # Step 1: Extract error field from WWW-Authenticate header
657+ error = extract_field_from_www_auth (response , "error" )
658+
659+ # Step 2: Check if we need to step-up authorization
660+ if error == "insufficient_scope" : # pragma: no branch
661+ try :
662+ # Step 2a: Update the required scopes
663+ async with self .context .lock :
628664 self .context .client_metadata .scope = get_client_metadata_scopes (
629665 extract_scope_from_www_auth (response ), self .context .protected_resource_metadata
630666 )
631667
632- # Step 2b: Perform (re-)authorization and token exchange
633- token_response = yield await self ._perform_authorization ()
668+ # Step 2b: Perform (re-)authorization and token exchange
669+ async with self .context .lock :
670+ authorization_request = await self ._perform_authorization ()
671+ token_response = yield authorization_request
672+ async with self .context .lock :
634673 await self ._handle_token_response (token_response )
635- except Exception : # pragma: no cover
636- logger .exception ("OAuth flow error" )
637- raise
674+ except Exception : # pragma: no cover
675+ logger .exception ("OAuth flow error" )
676+ raise
638677
639- # Retry with new tokens
678+ # Retry with new tokens
679+ async with self .context .lock :
640680 self ._add_auth_header (request )
641- yield request
681+ yield request
0 commit comments