@@ -140,15 +140,19 @@ class _TrackedChannelState:
140140class _InFlightChannelTracker :
141141 def __init__ (self ):
142142 self ._lock = Lock ()
143- self ._states : dict [int , _TrackedChannelState ] = {}
143+ # Keyed on the channel itself; gRPC channels are hashable by identity
144+ # and we keep a strong reference via _TrackedChannelState so reuse-after-
145+ # GC isn't a concern. Using the channel directly (instead of ``id()``)
146+ # makes the invariant local to this class rather than something a
147+ # reader has to verify by tracing _TrackedChannelState lifetimes.
148+ self ._states : dict [Any , _TrackedChannelState ] = {}
144149
145150 def acquire (self , channel : Any ):
146- channel_key = id (channel )
147151 with self ._lock :
148- state = self ._states .get (channel_key )
152+ state = self ._states .get (channel )
149153 if state is None :
150154 state = _TrackedChannelState (channel = channel )
151- self ._states [channel_key ] = state
155+ self ._states [channel ] = state
152156 state .ref_count += 1
153157
154158 released = False
@@ -161,26 +165,25 @@ def release() -> None:
161165
162166 channel_to_close = None
163167 with self ._lock :
164- state = self ._states .get (channel_key )
168+ state = self ._states .get (channel )
165169 if state is None :
166170 return
167171
168172 state .ref_count -= 1
169173 if state .ref_count == 0 :
170174 if state .close_when_released :
171175 channel_to_close = state .channel
172- del self ._states [channel_key ]
176+ del self ._states [channel ]
173177
174178 if channel_to_close is not None :
175179 self ._close_channel (channel_to_close )
176180
177181 return release
178182
179183 def retire (self , channel : Any ) -> None :
180- channel_key = id (channel )
181184 channel_to_close = None
182185 with self ._lock :
183- state = self ._states .get (channel_key )
186+ state = self ._states .get (channel )
184187 if state is None :
185188 channel_to_close = channel
186189 else :
@@ -533,6 +536,10 @@ def __init__(
533536 self ._shutdown = Event ()
534537 self ._is_running = False
535538 self ._channel = channel
539+ # The SDK owns (and may recreate) the gRPC channel only when the caller
540+ # did not provide one. Mirrors ``TaskHubGrpcClient._owns_channel`` so
541+ # both files use the same name for the same concept.
542+ self ._owns_channel = channel is None
536543 self ._secure_channel = secure_channel
537544 self ._payload_store = payload_store
538545 self ._channel_options = channel_options
@@ -598,9 +605,6 @@ def _should_count_worker_failure(
598605 ) -> bool :
599606 return is_worker_transport_failure (status_code )
600607
601- def _can_recreate_channel (self ) -> bool :
602- return self ._channel is None
603-
604608 def add_orchestrator (self , fn : task .Orchestrator [TInput , TOutput ]) -> str :
605609 """Registers an orchestrator function with the worker."""
606610 if self ._is_running :
@@ -742,16 +746,7 @@ def create_fresh_connection():
742746 current_stub = None
743747 raise
744748
745- def wrap_execution (handler , release ):
746- def wrapped (* args , ** kwargs ):
747- try :
748- return handler (* args , ** kwargs )
749- finally :
750- release ()
751-
752- return wrapped
753-
754- def wrap_cancellation (handler , release ):
749+ def wrap_with_release (handler , release ):
755750 def wrapped (* args , ** kwargs ):
756751 try :
757752 return handler (* args , ** kwargs )
@@ -772,8 +767,8 @@ def submit_work_item(
772767 release = in_flight_channel_tracker .acquire (channel )
773768 try :
774769 submit_func (
775- wrap_execution (handler , release ),
776- wrap_cancellation (cancellation_handler , release ),
770+ wrap_with_release (handler , release ),
771+ wrap_with_release (cancellation_handler , release ),
777772 request ,
778773 stub ,
779774 completion_token ,
@@ -808,7 +803,7 @@ def invalidate_connection(
808803
809804 if (
810805 current_channel is not None
811- and self ._can_recreate_channel ()
806+ and self ._owns_channel
812807 and (recreate_channel or close_channel )
813808 ):
814809 in_flight_channel_tracker .retire (current_channel )
@@ -837,7 +832,7 @@ def should_invalidate_connection(rpc_error):
837832 if self ._should_count_worker_failure (error_code ):
838833 recreate_channel = (
839834 failure_tracker .record_failure ()
840- and self ._can_recreate_channel ()
835+ and self ._owns_channel
841836 )
842837 invalidate_connection (recreate_channel = recreate_channel )
843838 conn_retry_count += 1
@@ -995,7 +990,7 @@ def stream_reader():
995990 )
996991 recreate_channel = (
997992 failure_tracker .record_failure ()
998- and self ._can_recreate_channel ()
993+ and self ._owns_channel
999994 )
1000995 invalidate_connection (recreate_channel = recreate_channel )
1001996 conn_retry_count += 1
@@ -1010,7 +1005,7 @@ def stream_reader():
10101005 if should_invalidate and self ._should_count_worker_failure (error_code ):
10111006 recreate_channel = (
10121007 failure_tracker .record_failure ()
1013- and self ._can_recreate_channel ()
1008+ and self ._owns_channel
10141009 )
10151010 if should_invalidate :
10161011 invalidate_connection (recreate_channel = recreate_channel )
@@ -2893,6 +2888,7 @@ def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logg
28932888 self ._pending_orchestration_work : list = []
28942889 self ._pending_entity_batch_work : list = []
28952890 self .thread_pool = self ._create_thread_pool ()
2891+ self ._pool_is_shutdown = False
28962892 self ._shutdown = False
28972893
28982894 def _create_thread_pool (self ) -> ThreadPoolExecutor :
@@ -2902,8 +2898,12 @@ def _create_thread_pool(self) -> ThreadPoolExecutor:
29022898 )
29032899
29042900 def _ensure_thread_pool (self ) -> None :
2905- if getattr (self .thread_pool , "_shutdown" , False ):
2901+ # Track the pool's shutdown state explicitly instead of reading
2902+ # ``ThreadPoolExecutor._shutdown`` (which is a CPython implementation
2903+ # detail and not part of ``concurrent.futures``'s public API).
2904+ if self ._pool_is_shutdown :
29062905 self .thread_pool = self ._create_thread_pool ()
2906+ self ._pool_is_shutdown = False
29072907
29082908 def prepare_for_run (self ) -> None :
29092909 self ._shutdown = False
@@ -3045,8 +3045,9 @@ async def run(self):
30453045 self ._logger .error (f"Uncaught error while cancelling entity batch work item: { cancellation_exception } " )
30463046 self .shutdown ()
30473047 finally :
3048- if not getattr ( self .thread_pool , "_shutdown" , False ) :
3048+ if not self ._pool_is_shutdown :
30493049 self .thread_pool .shutdown (wait = True )
3050+ self ._pool_is_shutdown = True
30503051
30513052 async def _consume_queue (self , queue : asyncio .Queue , semaphore : asyncio .Semaphore ):
30523053 # List to track running tasks
0 commit comments