@@ -249,6 +249,35 @@ def get_all_awaited_by(pid):
249249 raise RuntimeError ("Failed to get all awaited_by after retries" )
250250
251251
252+ def _get_stack_trace_with_retry (unwinder , timeout = SHORT_TIMEOUT ):
253+ """Get stack trace from an existing unwinder with retry for transient errors.
254+
255+ This handles the case where we want to reuse an existing RemoteUnwinder
256+ instance but still handle transient failures like "Failed to parse initial
257+ frame in chain" that can occur when sampling at an inopportune moment.
258+
259+ Args:
260+ unwinder: An existing RemoteUnwinder instance
261+ timeout: Maximum time to retry (default SHORT_TIMEOUT)
262+
263+ Returns:
264+ The stack trace result from unwinder.get_stack_trace()
265+
266+ Raises:
267+ RuntimeError: If all retry attempts fail
268+ """
269+ last_error = None
270+ for _ in busy_retry (timeout ):
271+ try :
272+ return unwinder .get_stack_trace ()
273+ except (OSError , RuntimeError ) as e :
274+ last_error = e
275+ continue
276+ raise RuntimeError (
277+ f"Failed to get stack trace after retries: { last_error } "
278+ )
279+
280+
252281# ============================================================================
253282# Base test class with shared infrastructure
254283# ============================================================================
@@ -1704,16 +1733,16 @@ def main_work():
17041733
17051734 # Get stack trace with all threads
17061735 unwinder_all = RemoteUnwinder (p .pid , all_threads = True )
1707- for _ in range ( MAX_TRIES ):
1708- all_traces = unwinder_all . get_stack_trace ()
1709- found = self . _find_frame_in_trace (
1710- all_traces ,
1711- lambda f : f . funcname == "main_work"
1712- and f . location . lineno > 12 ,
1713- )
1714- if found :
1715- break
1716- time . sleep ( 0.1 )
1736+ for _ in busy_retry ( SHORT_TIMEOUT ):
1737+ with contextlib . suppress ( OSError , RuntimeError ):
1738+ all_traces = unwinder_all . get_stack_trace ()
1739+ found = self . _find_frame_in_trace (
1740+ all_traces ,
1741+ lambda f : f . funcname == "main_work"
1742+ and f . location . lineno > 12 ,
1743+ )
1744+ if found :
1745+ break
17171746 else :
17181747 self .fail (
17191748 "Main thread did not start its busy work on time"
@@ -1723,7 +1752,7 @@ def main_work():
17231752 unwinder_gil = RemoteUnwinder (
17241753 p .pid , only_active_thread = True
17251754 )
1726- gil_traces = unwinder_gil . get_stack_trace ( )
1755+ gil_traces = _get_stack_trace_with_retry ( unwinder_gil )
17271756
17281757 # Count threads
17291758 total_threads = sum (
@@ -1998,15 +2027,15 @@ def busy():
19982027 mode = mode ,
19992028 skip_non_matching_threads = False ,
20002029 )
2001- for _ in range (MAX_TRIES ):
2002- traces = unwinder .get_stack_trace ()
2003- statuses = self ._get_thread_statuses (traces )
2030+ for _ in busy_retry (SHORT_TIMEOUT ):
2031+ with contextlib .suppress (OSError , RuntimeError ):
2032+ traces = unwinder .get_stack_trace ()
2033+ statuses = self ._get_thread_statuses (traces )
20042034
2005- if check_condition (
2006- statuses , sleeper_tid , busy_tid
2007- ):
2008- break
2009- time .sleep (0.5 )
2035+ if check_condition (
2036+ statuses , sleeper_tid , busy_tid
2037+ ):
2038+ break
20102039
20112040 return statuses , sleeper_tid , busy_tid
20122041 finally :
@@ -2150,29 +2179,29 @@ def busy_thread():
21502179 mode = PROFILING_MODE_ALL ,
21512180 skip_non_matching_threads = False ,
21522181 )
2153- for _ in range ( MAX_TRIES ):
2154- traces = unwinder . get_stack_trace ()
2155- statuses = self . _get_thread_statuses ( traces )
2156-
2157- # Check ALL mode provides both GIL and CPU info
2158- if (
2159- sleeper_tid in statuses
2160- and busy_tid in statuses
2161- and not (
2162- statuses [ sleeper_tid ]
2163- & THREAD_STATUS_ON_CPU
2164- )
2165- and not (
2166- statuses [ sleeper_tid ]
2167- & THREAD_STATUS_HAS_GIL
2168- )
2169- and ( statuses [ busy_tid ] & THREAD_STATUS_ON_CPU )
2170- and (
2171- statuses [ busy_tid ] & THREAD_STATUS_HAS_GIL
2172- )
2173- ):
2174- break
2175- time . sleep ( 0.5 )
2182+ for _ in busy_retry ( SHORT_TIMEOUT ):
2183+ with contextlib . suppress ( OSError , RuntimeError ):
2184+ traces = unwinder . get_stack_trace ( )
2185+ statuses = self . _get_thread_statuses ( traces )
2186+
2187+ # Check ALL mode provides both GIL and CPU info
2188+ if (
2189+ sleeper_tid in statuses
2190+ and busy_tid in statuses
2191+ and not (
2192+ statuses [ sleeper_tid ]
2193+ & THREAD_STATUS_ON_CPU
2194+ )
2195+ and not (
2196+ statuses [ sleeper_tid ]
2197+ & THREAD_STATUS_HAS_GIL
2198+ )
2199+ and (statuses [ busy_tid ] & THREAD_STATUS_ON_CPU )
2200+ and (
2201+ statuses [ busy_tid ] & THREAD_STATUS_HAS_GIL
2202+ )
2203+ ):
2204+ break
21762205
21772206 self .assertIsNotNone (
21782207 sleeper_tid , "Sleeper thread id not received"
@@ -2296,18 +2325,18 @@ def test_thread_status_exception_detection(self):
22962325 mode = PROFILING_MODE_ALL ,
22972326 skip_non_matching_threads = False ,
22982327 )
2299- for _ in range ( MAX_TRIES ):
2300- traces = unwinder . get_stack_trace ()
2301- statuses = self . _get_thread_statuses ( traces )
2302-
2303- if (
2304- exception_tid in statuses
2305- and normal_tid in statuses
2306- and ( statuses [ exception_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2307- and not (statuses [normal_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2308- ):
2309- break
2310- time . sleep ( 0.5 )
2328+ for _ in busy_retry ( SHORT_TIMEOUT ):
2329+ with contextlib . suppress ( OSError , RuntimeError ):
2330+ traces = unwinder . get_stack_trace ( )
2331+ statuses = self . _get_thread_statuses ( traces )
2332+
2333+ if (
2334+ exception_tid in statuses
2335+ and normal_tid in statuses
2336+ and (statuses [exception_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2337+ and not ( statuses [ normal_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2338+ ):
2339+ break
23112340
23122341 self .assertIn (exception_tid , statuses )
23132342 self .assertIn (normal_tid , statuses )
@@ -2339,18 +2368,18 @@ def test_thread_status_exception_mode_filtering(self):
23392368 mode = PROFILING_MODE_EXCEPTION ,
23402369 skip_non_matching_threads = True ,
23412370 )
2342- for _ in range ( MAX_TRIES ):
2343- traces = unwinder . get_stack_trace ()
2344- statuses = self . _get_thread_statuses ( traces )
2345-
2346- if exception_tid in statuses :
2347- self . assertNotIn (
2348- normal_tid ,
2349- statuses ,
2350- "Normal thread should be filtered out in exception mode" ,
2351- )
2352- return
2353- time . sleep ( 0.5 )
2371+ for _ in busy_retry ( SHORT_TIMEOUT ):
2372+ with contextlib . suppress ( OSError , RuntimeError ):
2373+ traces = unwinder . get_stack_trace ( )
2374+ statuses = self . _get_thread_statuses ( traces )
2375+
2376+ if exception_tid in statuses :
2377+ self . assertNotIn (
2378+ normal_tid ,
2379+ statuses ,
2380+ "Normal thread should be filtered out in exception mode" ,
2381+ )
2382+ return
23542383
23552384 self .fail ("Never found exception thread in exception mode" )
23562385
@@ -2504,18 +2533,17 @@ def _check_exception_status(self, p, thread_tid, expect_exception):
25042533
25052534 # Collect multiple samples for reliability
25062535 results = []
2507- for _ in range (MAX_TRIES ):
2508- traces = unwinder .get_stack_trace ()
2509- statuses = self ._get_thread_statuses (traces )
2510-
2511- if thread_tid in statuses :
2512- has_exc = bool (statuses [thread_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2513- results .append (has_exc )
2536+ for _ in busy_retry (SHORT_TIMEOUT ):
2537+ with contextlib .suppress (OSError , RuntimeError ):
2538+ traces = unwinder .get_stack_trace ()
2539+ statuses = self ._get_thread_statuses (traces )
25142540
2515- if len (results ) >= 3 :
2516- break
2541+ if thread_tid in statuses :
2542+ has_exc = bool (statuses [thread_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2543+ results .append (has_exc )
25172544
2518- time .sleep (0.2 )
2545+ if len (results ) >= 3 :
2546+ break
25192547
25202548 # Check majority of samples match expected
25212549 if not results :
0 commit comments