Skip to content

Commit 345d941

Browse files
Bernd Verst and Copilot committed
Address andystaples PR review feedback
- Make FailureTracker thread-safe with an internal lock so multi-threaded sync clients can't race the consecutive-failure counter (review [3/10]).
- Track _AsyncWorkerManager pool shutdown via an explicit _pool_is_shutdown flag instead of reading ThreadPoolExecutor._shutdown (CPython private API, review [4/10]).
- Collapse identical wrap_execution/wrap_cancellation closures in the worker stream loop into a single wrap_with_release helper (review [5/10]).
- Promote the retired-channel close delay and jitter exponent cap to named module-level constants (review [7/10]).
- Key _InFlightChannelTracker on the channel object instead of id(channel) so the lifetime invariant is local to the tracker (review [9/10]).
- Rename TaskHubGrpcWorker._can_recreate_channel() to the existing _owns_channel attribute used by the clients, so both files use the same name for the same concept (review [2/10]).
- Add regression tests for FailureTracker concurrency and for thread-pool recreation after manager shutdown.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 78829a0 commit 345d941

5 files changed

Lines changed: 106 additions & 37 deletions

File tree

durabletask/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationStat
166166
failure_details)
167167

168168

169+
# Grace period before a retired SDK-owned channel is force-closed. Long enough
170+
# for in-flight unary RPCs to drain on their own, short enough that recreate
171+
# storms don't pile up dozens of half-closed channels.
172+
_RETIRED_CHANNEL_CLOSE_DELAY_SECONDS = 30.0
173+
174+
169175
class TaskHubGrpcClient:
170176
def __init__(self, *,
171177
host_address: Optional[str] = None,
@@ -264,7 +270,7 @@ def _maybe_recreate_channel(self) -> None:
264270
self._last_recreate_time = now
265271
self._client_failure_tracker.record_success()
266272
close_timer = threading.Timer(
267-
30.0,
273+
_RETIRED_CHANNEL_CLOSE_DELAY_SECONDS,
268274
self._close_retired_channel,
269275
args=(old_channel,),
270276
)
@@ -730,7 +736,7 @@ async def _maybe_recreate_channel(self) -> None:
730736

731737
async def _close_retired_channel(self, channel: grpc.aio.Channel) -> None:
732738
try:
733-
await asyncio.sleep(30.0)
739+
await asyncio.sleep(_RETIRED_CHANNEL_CLOSE_DELAY_SECONDS)
734740
await channel.close()
735741
finally:
736742
async with self._recreate_lock:

durabletask/internal/grpc_resiliency.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,59 @@
22
# Licensed under the MIT License.
33

44
import random
5-
from dataclasses import dataclass
5+
import threading
6+
from dataclasses import dataclass, field
67

78
import grpc
89

10+
# Sidecar RPCs that legitimately block on the server until an instance reaches a
11+
# terminal state. ``DEADLINE_EXCEEDED`` on these is the caller's chosen timeout
12+
# expiring rather than a transport failure, so we do not treat it as one.
913
LONG_POLL_METHODS = {"WaitForInstanceStart", "WaitForInstanceCompletion"}
1014

15+
# Cap the attempt number fed into ``2 ** attempt`` to keep the jitter calculation
16+
# bounded for callers that retry indefinitely; once we hit the cap, the upper
17+
# bound is fully governed by ``cap_seconds``.
18+
_MAX_JITTER_ATTEMPT_EXPONENT = 30
19+
1120

1221
def get_full_jitter_delay_seconds(
1322
attempt: int,
1423
*,
1524
base_seconds: float,
1625
cap_seconds: float,
1726
) -> float:
18-
capped_attempt = min(attempt, 30)
27+
capped_attempt = min(attempt, _MAX_JITTER_ATTEMPT_EXPONENT)
1928
upper_bound = min(cap_seconds, base_seconds * (2 ** capped_attempt))
2029
return random.random() * upper_bound
2130

2231

2332
@dataclass
2433
class FailureTracker:
34+
"""Counts consecutive transport failures with thread-safe mutation.
35+
36+
The sync ``TaskHubGrpcClient`` is commonly invoked from multiple worker
37+
threads, so ``record_failure``/``record_success`` need a lock to keep the
38+
increment-and-compare atomic. The async client only mutates this from a
39+
single event loop, but the extra lock has negligible cost on that path.
40+
"""
41+
2542
threshold: int
2643
consecutive_failures: int = 0
44+
_lock: threading.Lock = field(
45+
default_factory=threading.Lock, init=False, repr=False, compare=False
46+
)
2747

2848
def record_failure(self) -> bool:
2949
if self.threshold <= 0:
3050
return False
31-
self.consecutive_failures += 1
32-
return self.consecutive_failures >= self.threshold
51+
with self._lock:
52+
self.consecutive_failures += 1
53+
return self.consecutive_failures >= self.threshold
3354

3455
def record_success(self) -> None:
35-
self.consecutive_failures = 0
56+
with self._lock:
57+
self.consecutive_failures = 0
3658

3759

3860
def is_client_transport_failure(method_name: str, status_code: grpc.StatusCode) -> bool:

durabletask/worker.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,19 @@ class _TrackedChannelState:
140140
class _InFlightChannelTracker:
141141
def __init__(self):
142142
self._lock = Lock()
143-
self._states: dict[int, _TrackedChannelState] = {}
143+
# Keyed on the channel itself; gRPC channels are hashable by identity
144+
# and we keep a strong reference via _TrackedChannelState so reuse-after-
145+
# GC isn't a concern. Using the channel directly (instead of ``id()``)
146+
# makes the invariant local to this class rather than something a
147+
# reader has to verify by tracing _TrackedChannelState lifetimes.
148+
self._states: dict[Any, _TrackedChannelState] = {}
144149

145150
def acquire(self, channel: Any):
146-
channel_key = id(channel)
147151
with self._lock:
148-
state = self._states.get(channel_key)
152+
state = self._states.get(channel)
149153
if state is None:
150154
state = _TrackedChannelState(channel=channel)
151-
self._states[channel_key] = state
155+
self._states[channel] = state
152156
state.ref_count += 1
153157

154158
released = False
@@ -161,26 +165,25 @@ def release() -> None:
161165

162166
channel_to_close = None
163167
with self._lock:
164-
state = self._states.get(channel_key)
168+
state = self._states.get(channel)
165169
if state is None:
166170
return
167171

168172
state.ref_count -= 1
169173
if state.ref_count == 0:
170174
if state.close_when_released:
171175
channel_to_close = state.channel
172-
del self._states[channel_key]
176+
del self._states[channel]
173177

174178
if channel_to_close is not None:
175179
self._close_channel(channel_to_close)
176180

177181
return release
178182

179183
def retire(self, channel: Any) -> None:
180-
channel_key = id(channel)
181184
channel_to_close = None
182185
with self._lock:
183-
state = self._states.get(channel_key)
186+
state = self._states.get(channel)
184187
if state is None:
185188
channel_to_close = channel
186189
else:
@@ -533,6 +536,10 @@ def __init__(
533536
self._shutdown = Event()
534537
self._is_running = False
535538
self._channel = channel
539+
# The SDK owns (and may recreate) the gRPC channel only when the caller
540+
# did not provide one. Mirrors ``TaskHubGrpcClient._owns_channel`` so
541+
# both files use the same name for the same concept.
542+
self._owns_channel = channel is None
536543
self._secure_channel = secure_channel
537544
self._payload_store = payload_store
538545
self._channel_options = channel_options
@@ -598,9 +605,6 @@ def _should_count_worker_failure(
598605
) -> bool:
599606
return is_worker_transport_failure(status_code)
600607

601-
def _can_recreate_channel(self) -> bool:
602-
return self._channel is None
603-
604608
def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
605609
"""Registers an orchestrator function with the worker."""
606610
if self._is_running:
@@ -742,16 +746,7 @@ def create_fresh_connection():
742746
current_stub = None
743747
raise
744748

745-
def wrap_execution(handler, release):
746-
def wrapped(*args, **kwargs):
747-
try:
748-
return handler(*args, **kwargs)
749-
finally:
750-
release()
751-
752-
return wrapped
753-
754-
def wrap_cancellation(handler, release):
749+
def wrap_with_release(handler, release):
755750
def wrapped(*args, **kwargs):
756751
try:
757752
return handler(*args, **kwargs)
@@ -772,8 +767,8 @@ def submit_work_item(
772767
release = in_flight_channel_tracker.acquire(channel)
773768
try:
774769
submit_func(
775-
wrap_execution(handler, release),
776-
wrap_cancellation(cancellation_handler, release),
770+
wrap_with_release(handler, release),
771+
wrap_with_release(cancellation_handler, release),
777772
request,
778773
stub,
779774
completion_token,
@@ -808,7 +803,7 @@ def invalidate_connection(
808803

809804
if (
810805
current_channel is not None
811-
and self._can_recreate_channel()
806+
and self._owns_channel
812807
and (recreate_channel or close_channel)
813808
):
814809
in_flight_channel_tracker.retire(current_channel)
@@ -837,7 +832,7 @@ def should_invalidate_connection(rpc_error):
837832
if self._should_count_worker_failure(error_code):
838833
recreate_channel = (
839834
failure_tracker.record_failure()
840-
and self._can_recreate_channel()
835+
and self._owns_channel
841836
)
842837
invalidate_connection(recreate_channel=recreate_channel)
843838
conn_retry_count += 1
@@ -995,7 +990,7 @@ def stream_reader():
995990
)
996991
recreate_channel = (
997992
failure_tracker.record_failure()
998-
and self._can_recreate_channel()
993+
and self._owns_channel
999994
)
1000995
invalidate_connection(recreate_channel=recreate_channel)
1001996
conn_retry_count += 1
@@ -1010,7 +1005,7 @@ def stream_reader():
10101005
if should_invalidate and self._should_count_worker_failure(error_code):
10111006
recreate_channel = (
10121007
failure_tracker.record_failure()
1013-
and self._can_recreate_channel()
1008+
and self._owns_channel
10141009
)
10151010
if should_invalidate:
10161011
invalidate_connection(recreate_channel=recreate_channel)
@@ -2893,6 +2888,7 @@ def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logg
28932888
self._pending_orchestration_work: list = []
28942889
self._pending_entity_batch_work: list = []
28952890
self.thread_pool = self._create_thread_pool()
2891+
self._pool_is_shutdown = False
28962892
self._shutdown = False
28972893

28982894
def _create_thread_pool(self) -> ThreadPoolExecutor:
@@ -2902,8 +2898,12 @@ def _create_thread_pool(self) -> ThreadPoolExecutor:
29022898
)
29032899

29042900
def _ensure_thread_pool(self) -> None:
2905-
if getattr(self.thread_pool, "_shutdown", False):
2901+
# Track the pool's shutdown state explicitly instead of reading
2902+
# ``ThreadPoolExecutor._shutdown`` (which is a CPython implementation
2903+
# detail and not part of ``concurrent.futures``'s public API).
2904+
if self._pool_is_shutdown:
29062905
self.thread_pool = self._create_thread_pool()
2906+
self._pool_is_shutdown = False
29072907

29082908
def prepare_for_run(self) -> None:
29092909
self._shutdown = False
@@ -3045,8 +3045,9 @@ async def run(self):
30453045
self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
30463046
self.shutdown()
30473047
finally:
3048-
if not getattr(self.thread_pool, "_shutdown", False):
3048+
if not self._pool_is_shutdown:
30493049
self.thread_pool.shutdown(wait=True)
3050+
self._pool_is_shutdown = True
30503051

30513052
async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
30523053
# List to track running tasks

tests/durabletask/test_grpc_resiliency.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,26 @@ def test_failure_tracker_threshold_zero_never_trips():
147147
assert tracker.consecutive_failures == 0
148148

149149

150+
def test_failure_tracker_record_failure_is_thread_safe():
151+
import threading
152+
153+
tracker = FailureTracker(threshold=10_000)
154+
iterations = 500
155+
workers = 8
156+
157+
def increment() -> None:
158+
for _ in range(iterations):
159+
tracker.record_failure()
160+
161+
threads = [threading.Thread(target=increment) for _ in range(workers)]
162+
for thread in threads:
163+
thread.start()
164+
for thread in threads:
165+
thread.join()
166+
167+
assert tracker.consecutive_failures == iterations * workers
168+
169+
150170
@pytest.mark.parametrize(
151171
"method_name",
152172
[

tests/durabletask/test_worker_resiliency.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,26 @@ async def test_async_worker_manager_honors_shutdown_requested_before_run():
131131
await asyncio.wait_for(manager.run(), timeout=1.0)
132132

133133

134+
@pytest.mark.asyncio
135+
async def test_async_worker_manager_recreates_thread_pool_after_run():
136+
manager = _AsyncWorkerManager(
137+
ConcurrencyOptions(maximum_thread_pool_workers=1),
138+
MagicMock(),
139+
)
140+
141+
original_pool = manager.thread_pool
142+
143+
manager.shutdown()
144+
await asyncio.wait_for(manager.run(), timeout=1.0)
145+
146+
assert manager._pool_is_shutdown is True
147+
148+
manager.prepare_for_run()
149+
150+
assert manager._pool_is_shutdown is False
151+
assert manager.thread_pool is not original_pool
152+
153+
134154
def test_worker_start_clears_prior_shutdown_request():
135155
worker = TaskHubGrpcWorker()
136156
worker._shutdown.set()
@@ -190,7 +210,7 @@ def test_worker_counts_only_transport_failures_for_recreation():
190210

191211
def test_worker_does_not_recreate_caller_owned_channel():
192212
worker = TaskHubGrpcWorker(channel=MagicMock())
193-
assert worker._can_recreate_channel() is False
213+
assert worker._owns_channel is False
194214

195215

196216
@pytest.mark.asyncio

0 commit comments

Comments
 (0)