CP Tests batching using subprocess worker pool #2993
The existing test path spawns one torchrun per parametrized case, paying NCCL init + CUDA context + Python startup on every call. With ~hundreds of cases the launch overhead dominates wall time and was a primary driver of the L3 timeout that prior batching PRs worked around.
This change replaces the per-case subprocess with one long-lived torchrun per (world_size). NCCL is initialized once at session start and reused across cases. Pytest sends one JSON request per case over rank-0 stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers (ok, error) from every rank, and writes one JSON response on rank-0 stdout.
run_attention_with_cp.py is left almost untouched; a new NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and dist.destroy_process_group() calls so the function reuses the pool's main PG instead of creating its own. The per-case cp_comm_group (and a2a+p2p sub-groups) are explicitly destroyed at function exit to prevent communicator leakage across cases.
The PoolWorker class adds two pieces of error recovery that the prior subprocess-per-case design got for free: a select-based per-call timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on worker death or timeout. A test-level exception is reported as an AssertionError and the pool keeps running for the next case.
Two pool sizes are needed because cp_comm_type='a2a+p2p' requires world_size=4 and the others use world_size=2; you can't resize an active PG. Pools are spawned lazily so a 2-GPU-only run never pays the 4-GPU init.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
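A minimal sketch of the request/response loop described in this commit message, assuming one JSON object per line in each direction. The names `serve`, `run_case`, `merge_rank_results` and the response keys `ok`/`error` are hypothetical and chosen for illustration only; the real worker dispatches to run_dpa_with_cp and gathers per-rank results through torch.distributed, which is replaced here by a plain Python list.

```python
import io
import json
import traceback
from typing import Callable, List, Optional, TextIO, Tuple

RankResult = Tuple[bool, Optional[str]]  # (ok, error text) reported by one rank


def merge_rank_results(results: List[RankResult]) -> dict:
    """Collapse per-rank (ok, error) pairs into a single response payload."""
    errors = [f"[rank {r}] {err}" for r, (ok, err) in enumerate(results) if not ok]
    return {"ok": not errors, "error": "\n".join(errors) or None}


def serve(stdin: TextIO, stdout: TextIO, run_case: Callable[..., List[RankResult]]) -> None:
    """Read one JSON request per line and answer with one JSON line each."""
    for line in stdin:
        line = line.strip()
        if not line:
            continue
        try:
            kwargs = json.loads(line)
            payload = merge_rank_results(run_case(**kwargs))
        except Exception:
            # A failing case is reported, and the loop keeps serving requests.
            payload = {"ok": False, "error": traceback.format_exc()}
        stdout.write(json.dumps(payload) + "\n")
        stdout.flush()
```

Passing the streams in as arguments keeps the loop testable with in-memory buffers; a real worker would presumably hand it `sys.stdin` and `sys.stdout` on rank 0.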
Two resilience fixes carried over from the existing batching PR
(sudhakars/cp_test_batching_pr) without which the pool will
cascade-fail FP8 tests and silently propagate NCCL desync.
1. FP8GlobalStateManager.reset() between cases. FP8 quantizer state
(recipe handles, autocast counters) lives in module-level globals.
Reusing one Python process across cases otherwise carries that state
forward. The prior batching PR landed an explicit fix for the same
issue ("Fix FP8 cascade failures") after observing real test
failures from this.
2. dist.barrier() after each case. If one rank's case errored before
its last collective, the others can be stuck waiting on a comm that
will never complete. The barrier here surfaces that immediately as
a timeout in this case rather than letting the corruption leak into
the next case's collectives.
Also pops the transient NVTE_* env vars run_dpa_with_cp sets at the
top of each call. run_dpa_with_cp already sets them unconditionally so
this is defensive, but cheap insurance against future variants that
might not.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The model_configs_{flash,fused}_attn dicts are module-level and shared
across pool cases. The THD branch below rewrites config.attn_mask_type
in place (causal -> padding_causal, no_mask -> padding). With the
persistent-pool runner, the next case looking up the same model key
gets the mutated config and fails the "causal or no_mask only" assert.
Caught at benchmark time on cp_2_0 + thd, identical to the cascade the
existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the
same way in commit 6355f62.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
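The commit message above does not spell out the exact fix, so the following is only one common way to avoid this class of bug: hand each case a private copy of the shared config before rewriting it. `ModelConfig`, `MODEL_CONFIGS`, `THD_MASK` and `config_for_case` are hypothetical stand-ins, not names from the test suite.

```python
import copy
from dataclasses import dataclass


@dataclass
class ModelConfig:  # hypothetical stand-in for the real config class
    attn_mask_type: str = "causal"


# Module-level, so shared by every case served by one long-lived process.
MODEL_CONFIGS = {"cp_2_0": ModelConfig("causal")}

# The THD rewrite described above: causal -> padding_causal, no_mask -> padding.
THD_MASK = {"causal": "padding_causal", "no_mask": "padding"}


def config_for_case(model: str, qkv_format: str) -> ModelConfig:
    """Return a per-case copy so the THD rewrite never touches the shared dict."""
    config = copy.deepcopy(MODEL_CONFIGS[model])
    assert config.attn_mask_type in THD_MASK, "causal or no_mask only"
    if qkv_format == "thd":
        config.attn_mask_type = THD_MASK[config.attn_mask_type]
    return config
```

With a copy per case, a second lookup of the same model key sees the original mask type and the assert keeps passing regardless of what earlier cases did.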
Mirrors the two pre-emptive skips on the PR-batching branch:
* non-vanilla softmax with FusedAttention is not deterministic
* post_scale_bias with requires_grad is not deterministic
Without these skips, the corresponding configs propagate into the pool worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside run_dpa_with_cp instead of being marked SKIPPED.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The pool worker reused RNG state across cases, which produced small numerical drift on some non-FP8 fused-attention configs (cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the single-shot worker. Matches the per-case startup of the single-shot worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at the start of every case, alongside the existing FP8 / env / cache resets. Moved the reset call from the post-case finally block to the start of _run_one so the first case is also seeded consistently with subsequent cases. Otherwise the first case would inherit the process-default RNG and only the second-and-later cases would be deterministic. Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed). Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
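The ordering point in this commit (reset before the case, not in a `finally` after it) can be sketched as follows. This is an illustration only: Python's `random` module stands in for `torch.manual_seed` / `torch.cuda.manual_seed`, and `reset_per_case_state`, `run_one` and `sample_case` are hypothetical names.

```python
import random
from typing import Callable, List

SEED = 1234


def reset_per_case_state() -> None:
    """Stand-in for the torch seeding plus the FP8 / env / cache resets."""
    random.seed(SEED)


def run_one(case: Callable[[], List[float]]) -> List[float]:
    # Resetting at the start means the first case of the process is seeded
    # exactly like every later one; a post-case reset would miss case one.
    reset_per_case_state()
    return case()


def sample_case() -> List[float]:
    """A toy case whose output depends on the RNG state it starts with."""
    return [random.random() for _ in range(3)]
```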
for more information, see https://pre-commit.ci
Greptile Summary
This PR replaces per-test
Confidence Score: 5/5
Safe to merge. The stream-race fix is correct, the pool protocol is well-guarded with sentinel prefixes and a 90s per-case timeout, and test_essential is left as True.
The stream-race fix in context_parallel.py is a precise, well-commented one-liner that inserts the missing wait_stream at exactly the right point: it is a no-op for the default-stream case (i=1) and fires only when reading across stream boundaries (i=2). The pool infrastructure handles all three terminal outcomes (response, timeout, worker death) with appropriate cleanup and respawn. Cross-rank failure detail via gather_object is strictly better than the all_reduce-min approach. No new uncovered failure modes were found beyond those already discussed in earlier review threads.
No files require special attention.
Important Files Changed
| def _send_response(rank: int, payload: dict) -> None: | ||
| if rank == 0: | ||
| sys.stdout.write(json.dumps(payload) + "\n") | ||
| sys.stdout.flush() |
stdout pollution can silently corrupt the JSON protocol
torchrun (and the worker processes it spawns for ranks 1–N) all inherit the same stdout file descriptor as rank 0. If torchrun writes any status line to stdout, or if any non-rank-0 worker accidentally prints (e.g. via a print call in a library, NCCL debug output, or a Python warning), those bytes are interleaved with rank 0's JSON responses. The parent's readline() in PoolWorker.submit would then receive a non-JSON line and raise a json.JSONDecodeError, killing the pool and failing the test with a misleading error message.
Consider redirecting torchrun's own output or adding a sentinel prefix to every response line so the reader can skip unrecognised lines.
Implemented the stronger fix in 36b65fb3: redirect non-rank-0 stdout to /dev/null at the fd level via dup2 at worker startup. Catches both Python print and C-level (NCCL, libc, etc.) writes — the sentinel-prefix approach could only mitigate the former by scanning + skipping.
With non-rank-0 stdout silenced, rank 0's JSON line is the only thing that reaches the parent's pipe, so the _RESP_PREFIX machinery + the sentinel-scanning while loop in PoolWorker._submit_once are gone. The reader collapses to one select + one readline + one json.loads.
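A minimal sketch of an fd-level redirect of the kind described in this reply, assuming a POSIX host. `redirect_fd_to_devnull` and `silence_stdout_unless_rank0` are hypothetical names; only `os.open`, `os.dup2`, `os.close` and `os.devnull` are real standard-library calls.

```python
import os
import sys


def redirect_fd_to_devnull(fd: int) -> None:
    """Point an OS-level file descriptor at /dev/null."""
    devnull = os.open(os.devnull, os.O_WRONLY)
    try:
        # After dup2, every write to `fd` (Python or C level) is discarded.
        os.dup2(devnull, fd)
    finally:
        os.close(devnull)


def silence_stdout_unless_rank0(rank: int) -> None:
    """Hypothetical startup hook: only rank 0 keeps the pipe to the parent."""
    if rank != 0:
        sys.stdout.flush()  # flush Python-level buffers before swapping the fd
        redirect_fd_to_devnull(sys.stdout.fileno())
```

Because the swap happens on the descriptor rather than on `sys.stdout`, it also covers writes that never pass through Python, which is the property the reply relies on.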
Validated on 8×H100 (test_essential=True, flash-attn): 9 passed / 55 skipped / 0 failed in 56 s; no JSONDecodeError, no protocol corruption. Closes the M2 follow-up listed in the PR description.
Three changes that bring the pool's failure semantics on par with the per-batch torchrun approach in PR NVIDIA#2965 and remove a couple of footguns:
1. Capture pool-worker stderr into a ring buffer and attach the tail to crash-path AssertionErrors. Equivalent in spirit to PR NVIDIA#2965's run_distributed(): CI JUnit XML now shows the actual cause (NCCL error, Python traceback, OOM) inline with the failing test, instead of just "pool worker died mid-request" / "timed out". A daemon drainer thread reads stderr line-by-line into a deque(maxlen=200) and also echoes to sys.stderr so pytest's per-test capture still gets every line. Maximum buffered footprint ~40 KB.
2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s). 90 s gives ~6x headroom over the worst observed case while still detecting a genuine hang within ~1.5 min instead of ~10 min. Env var still overrides for slower machines or expanded test matrices.
3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr on rank 0 only. Grep-friendly; lets future tuning recalibrate the timeout against the observed distribution. Off by default so normal runs stay quiet.
Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True, with no perf regression vs the un-patched 256 s.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
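The stderr capture in item 1 could look roughly like the sketch below. `start_stderr_drainer` and `stderr_tail` are hypothetical helper names, and the 4096-character tail is an assumption taken from the "last 4 KB" figure in the PR description; the `deque(maxlen=200)` and the daemon thread come from the commit message.

```python
import collections
import subprocess
import sys
import threading
from typing import Deque, IO, Tuple


def start_stderr_drainer(stream: IO[str], maxlen: int = 200) -> Tuple[Deque[str], threading.Thread]:
    """Drain `stream` line by line into a bounded deque, echoing to sys.stderr."""
    tail: Deque[str] = collections.deque(maxlen=maxlen)

    def drain() -> None:
        for line in stream:  # the loop ends when the child closes its stderr
            tail.append(line)  # old lines fall off the front automatically
            sys.stderr.write(line)

    thread = threading.Thread(target=drain, daemon=True)
    thread.start()
    return tail, thread


def stderr_tail(tail: Deque[str], max_chars: int = 4096) -> str:
    """Last `max_chars` characters of buffered stderr, for error messages."""
    return "".join(list(tail))[-max_chars:]
```

Draining continuously also keeps the child from blocking on a full stderr pipe, independent of whether the tail is ever attached to an error.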
Three fixes responding to NVIDIA#2993 review comments:
P1: NCCL communicator leak on exception (run_attention_with_cp.py)
run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups) near the top, but the destroy_process_group() calls ran only on the success path at the end of the function. Any exception in between (tensor assertion, OOM, NCCL error) skipped the cleanup, leaking communicators in pool mode. Long sessions with repeated failures could exhaust NCCL internal tracking. Wrap the test work in try/finally so the destroy logic always runs. Initialise cp_comm_sub_groups = [] unconditionally so the finally block is safe even when cp_comm_type != "a2a+p2p" (or when an assert fires before the populate loop). Each destroy is itself try/except so a destroy failure on one group doesn't leak the others.
P2: stdout protocol can be corrupted by interleaved chatter
torchrun and ranks 1..N share rank 0's stdout fd. Any non-rank-0 print, NCCL debug line, or torchrun status output interleaves with the JSON response and breaks json.loads, killing the pool with a misleading "json decode error". Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py and have PoolWorker.submit() scan stdout for sentinel-prefixed lines, echoing non-protocol lines to stderr for visibility. Bounded scan (MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent.
P2 (doc): select.select on a pipe fd is Linux/macOS only
Added a short comment noting Windows portability. CP attention tests run on Linux GPU hosts; this is a documentation issue, not a real bug.
Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True (was 248 s pre-P2; the +22 s is the new sentinel-scan loop's per-line overhead at ~600 ms/case, within noise).
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Thanks for the review @greptile-apps, all three findings addressed in
P1: NCCL communicator leak on exception ( Wrapped the body of
P2: stdout protocol can be corrupted by interleaved chatter ( Worker now prefixes every response with
P2: Added a short comment noting the Windows portability gap. Tests target Linux GPU hosts only; this is a doc issue, not a real bug, but flagged for future readers.
PR-body note
The "default 600 s" inconsistency you flagged on the env-var table was already reconciled in the previous push
Validation
Re-ran The 5 state-cascade failures on
All three fixes look correct. Let me note one small remaining issue in the deadline logic before signing off. In
deadline = None
while scanned < self._MAX_NOISE_LINES:
    remaining = timeout if deadline is None else max(0.0, deadline - time.monotonic())
    ready, _, _ = select.select([self.proc.stdout], [], [], remaining)
    if deadline is None:
        deadline = time.monotonic() + timeout  # ← set after first select completes
If the first line (noise or response) arrives after, say, 50 s, the deadline is then set to
The fix is to set the deadline once, before the loop:
deadline = time.monotonic() + timeout
while scanned < self._MAX_NOISE_LINES:
    remaining = max(0.0, deadline - time.monotonic())
    ready, _, _ = select.select([self.proc.stdout], [], [], remaining)
    if not ready:
        ...
In practice this won't matter much, since torchrun status chatter arrives nearly instantaneously before the response, but the current code doesn't actually enforce the stated
Everything else looks good:
The deadline fix above is the only remaining item; everything else is addressed cleanly.
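For reference, a self-contained sketch of a reader with the deadline fixed once before the loop, so that noise lines cannot extend the total wait. It works on a raw file descriptor rather than on `self.proc.stdout`, and `read_response` is a hypothetical name; the `[CP_POOL_RESP] ` prefix is the sentinel named in the commit message above. Any bytes left in the buffer after the response line are discarded here, which a real reader would need to carry over.

```python
import os
import select
import time
from typing import Optional

PREFIX = b"[CP_POOL_RESP] "


def read_response(fd: int, timeout: float) -> Optional[bytes]:
    """Return the first sentinel-prefixed line, or None if `timeout` elapses."""
    deadline = time.monotonic() + timeout  # set once; noise cannot push it back
    buf = b""
    while True:
        while b"\n" in buf:
            line, buf = buf.split(b"\n", 1)
            if line.startswith(PREFIX):
                return line[len(PREFIX):]
            # Any other complete line is treated as noise and skipped.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return None
        ready, _, _ = select.select([fd], [], [], remaining)
        if not ready:
            return None  # the overall deadline expired while waiting
        chunk = os.read(fd, 4096)
        if not chunk:
            return None  # the writer closed the pipe
        buf += chunk
```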
| cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") | ||
| # Always defined so the finally cleanup below is safe even when cp_comm_type != "a2a+p2p". | ||
| cp_comm_sub_groups = [] | ||
| if cp_comm_type == "a2a+p2p": | ||
| assert world_size % 2 == 0, ( | ||
| "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" | ||
| " = 2." | ||
| ) | ||
| cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] | ||
| cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] | ||
| cp_comm_sub_groups = [] | ||
| for sub_ranks in cp_comm_sub_ranks: | ||
| sub_group = dist.new_group(sub_ranks, backend="nccl") | ||
| if rank in sub_ranks: | ||
| cp_comm_sub_groups.append(sub_group) |
Group creation sits outside the try/finally guard
cp_comm_group (line 250) and the sub-group creation loop (lines 260–263) all run before the try block that starts at line 268. If the sub-group loop raises — e.g., NCCL rejects a new_group call on a partially recovered communicator after a previous case failed — the finally block is never entered, so cp_comm_group and any sub-groups already appended to cp_comm_sub_groups are never destroyed. In a long-running pool this accumulates a communicator leak on every such failure, eventually hitting NCCL's internal communicator limit and breaking unrelated subsequent cases. Moving all three dist.new_group sites inside the try block (or at minimum wrapping the sub-group loop and the unconditional cp_comm_group creation in a nested try/finally) would close the gap.
Same fix as #3243445773 — the pool worker now owns the group lifecycle, so the dist.new_group calls live in the single-shot branch only (run_attention_with_cp.py:275-285, inside if not _reusing_pool_groups:). The original accumulation concern doesn't apply because the pool never creates per-case groups to leak; the single-shot path's subprocess exits at function return and the OS reclaims communicators. (01db9f5 + fc0d929)
cyanguwa
left a comment
I think the structure is a lot simpler/cleaner than #2965 - thanks!
For the 5 remaining failing tests, I wonder if there's something else we need to reset for FP8 (please check with @ksivaman).
Also, please compare the number of tests before/after this PR and make sure we're still running the same number of tests! If the new, reduced runtime allows now, we can turn on test_essential=False, but I'll leave that to you.
If no major issues, I approve! Thanks!
| if _pool_managed_pg: | ||
| # Pool owns the main PG; only clean up groups created for this case. | ||
| try: | ||
| dist.destroy_process_group(cp_comm_group) |
Do we want to destroy (and create earlier in the file) cp_comm_group for every config? I feel if the pool is the same, the world_size would be the same, and so is cp_comm_group?
Yes — exactly that, fixed in this direction. The pool worker creates cp_comm_group once per world_size at startup (run_attention_with_cp_pool.py:154); run_dpa_with_cp in pool mode reads the cached pointer instead of dist.new_group'ing per call (run_attention_with_cp.py:271-273). Destruction is once-per-worker at shutdown (run_attention_with_cp_pool.py:182-193). The a2a+p2p sub-groups are pre-created the same way. This dropped per-case NCCL overhead and also closed the leak class greptile raised in #3243445773 / #3244111437. Thanks for the catch — clean simplification. (01db9f5 + fc0d929)
In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is written inside `with torch.cuda.stream(flash_attn_streams[i])`. For i=1, flash_attn_streams[1] is cp_stream — i.e. *not* the default stream. Later, at loop iteration i=2, the code reads max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])` which runs on the default stream. Without an explicit wait_stream, this is a read-after-write race across streams. The post-loop `current_stream().wait_stream(cp_stream)` is too late — the race has already fired. The race is latent: outcome depends on stream scheduling. In a fresh-process subprocess (one-torchrun-per-test path), streams are cleanly initialised and timing happens to put the write before the read. In a long-running persistent-worker process — exposed by PR NVIDIA#2993's pool design — prior workloads shape stream state differently, the read can fire before the write completes, and max_logit ends up with stale values in some heads (~0.3 abs diff, 3/12 elements wrong on the H100 matrix). Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])` before the torch.maximum read. No-op when the streams are identical (i=1 case, where flash_attn_streams[0] is current_stream), only fires when reading from cp_stream (i=2 case). Validated: 8xH100, test_essential=False, 348 passed / 0 failed in 27m 10s (was 323 passed + 5 failed at this commit's parent, all 5 failing on cp_comm_type=all_gather with mismatched max_logit). The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now pass under the pool — confirming the race was the sole root cause. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Line-level cleanups from the second reviewer pass on PR NVIDIA#2993. Each item is dead/redundant; none changes behaviour. Full-matrix test_essential=False on 8xH100 still passes 348/0 in 26m 23s after these.
run_attention_with_cp_pool.py:
- Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top and pops the FP8 ones itself. The pop loop was defensive against a hypothetical "future caller that doesn't re-set them" that doesn't exist.
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create no Python reference cycles between iterations and empty_cache only frees CUDA blocks PyTorch already considers free; the combination was no-op here.
- Drop dist.barrier() after dist.gather_object(). gather_object is itself a collective synchronization point: if every rank reaches it, none is ahead. The "surface a wedged communicator here" comment was wishful: a wedged communicator would already wedge the gather.
test_attention_with_cp.py (PoolWorker):
- Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the unreachable post-loop "1000+ lines" branch. select()'s deadline already bounds the loop; the line-count cap was redundant and the over-limit branch was unreachable in practice.
- Inline _stderr_tail() into _diag(). Single caller, single use.
- Drop the _stderr_thread attribute. The drainer is daemon and self-terminates when the pipe closes; we never read the field anywhere, so initialising and nulling it was bookkeeping for no reason.
- Drop the dead assert in submit(): _ensure_alive() on the prior line already guarantees proc/stdin/stdout exist.
Deferred to a follow-up:
- L8 (drop try/except around dist.destroy_process_group). Real semantic change: hides errors that occur when a previous test wedged the communicator. Worth doing but needs its own validation.
- R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var), M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit CUDA_VISIBLE_DEVICES per pool). Same reasoning: separate PRs.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
| torch.cuda.manual_seed(seed) | ||
|
|
||
| test_essential = True | ||
| test_essential = False |
test_essential left as False — will break CI
The default was True (38 essential configs, all passing) and has been changed to False (full matrix, 328 runnable, 5 known failures). The PR description explicitly calls these out as "state-cascade failures" under Known Issues. Merging with test_essential = False will cause CI to report 5 failures by default, regressing from the current baseline.
Reverted in a018a53 — test_essential = True at line 47 to match the shipping default. The full-matrix run (test_essential=False, 348/0) in the PR description is a one-off validation pass, not the CI default.
world_size and the rank set don't change for the lifetime of one pool, so recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms of NCCL setup each. Pre-create them once in the pool worker (new helper _create_cp_comm_groups), stash on the run_attention_with_cp module via module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them once at shutdown. Also move per-case dist.new_group() calls inside the try/finally in run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population otherwise leaks every communicator created before the failure. The finally now only destroys groups we created locally (cp_comm_group / sub_groups populated in the else-branch), leaving pool-owned groups alone for reuse. cyanguwa's review feedback on PR NVIDIA#2993. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…kars/cp_batching_pool
The Round-1 P1 NCCL-communicator-leak fix (e162a9e) wrapped the ~540-line body of run_dpa_with_cp in try/finally. The wrap itself was tiny but it re-indented every line of the body by one level, inflating the PR diff of run_attention_with_cp.py to ~1000 lines against origin/main. Items 2+3 (d15bfce) since made the wrap unnecessary: - In pool mode, cp_comm_group and cp_comm_sub_groups are owned by the pool worker (which destroys them once at pool shutdown). run_dpa_with_cp neither creates nor destroys them, so an in-body exception can't leak communicators. - In single-shot mode, groups are still created locally, but the subprocess exits at function return; NCCL releases everything at process teardown, so a stray exception leaks communicators only for the milliseconds before the process dies — a bounded one-off cost, not the unbounded accumulation that Round-1 flagged for pool mode. Removing the wrap drops the run_attention_with_cp.py diff against origin/main from ~1000 lines to ~120 lines without changing observable behaviour. Smoke-tested: 4 representative cases pass. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…kars/cp_batching_pool
…karsingh27/TransformerEngine into sudhakars/cp_batching_v5
/te-ci pytorch L3 |
Round-3 review (greptile, discussion_r3250016711) flagged that the working tree had test_essential=False — i.e. the full ~328-config matrix instead of the ~38-config essential subset that the rest of the CI matrix expects. Flipping back to True so CI doesn't regress baseline on the known H1-style cascade configs that only appear in the full matrix. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Items 2+3 (cp_comm_group reuse via pool worker ownership) closed the NCCL communicator leak that the Round-1 P1 fix had wrapped against: in pool mode run_dpa_with_cp no longer creates or destroys groups, so no in-body exception can leak them. In single-shot mode the subprocess exits at function return, releasing any unfreed communicators at process teardown. The explicit try/finally wrap from e162a9e is now redundant. Removing it drops the run_attention_with_cp.py diff against origin/main from ~1072 lines to ~210 (most of the previous diff was indentation deltas from wrapping ~540 lines of function body). Ported from PR NVIDIA#2993's 87c67ac. Pad_between_seqs additions are preserved at their original indent. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…race The pool worker subprocess can die mid-case due to async NCCL aborts or flaky 4-GPU collective state that doesn't reproduce on a fresh pool. Without retry, these manifest as one-off CI failures attributable to infrastructure, not the PR's content. Add a single-attempt retry around PoolWorker.submit() that fires only on infrastructure failure modes (pool-worker-died, timeout, broken-pipe-pre-send). Test-assertion failures from the worker (resp["error"]) carry full per-rank tracebacks and propagate without retry — so a real bug still surfaces as FAILED. Visibility: every retry attempt writes a [POOL-RETRY] line to stderr. pytest captures per-test stderr and writes it into JUnit <testcase>/<system-err>. A flaky test will appear as PASSED in the case row but with a [POOL-RETRY] line in <system-err> — visible to the reviewer, and queryable by CI dashboards looking for flake patterns (e.g. "same test_id retries across multiple CI runs"). If both attempts die, a [POOL-RETRY-FAIL] line is also logged with the first error's headline, then the second attempt's full traceback propagates as the test failure. Smoke-tested: 3 representative cases (p2p, a2a flash; p2p fused) still PASS in 19 s. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
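The retry policy described in this commit message can be sketched as below. `PoolInfraError`, `submit_with_retry` and the `submit_once` callable are hypothetical names; the `[POOL-RETRY]` and `[POOL-RETRY-FAIL]` tags are the ones named above. The sketch assumes infrastructure failures are raised as a dedicated exception type, while test-assertion failures surface as `AssertionError` and are never retried.

```python
import sys
from typing import Any, Callable


class PoolInfraError(Exception):
    """Hypothetical: worker died, timed out, or the pipe broke before send."""


def _log_stderr(message: str) -> None:
    print(message, file=sys.stderr)


def submit_with_retry(
    submit_once: Callable[[Any], Any],
    request: Any,
    log: Callable[[str], None] = _log_stderr,
) -> Any:
    """Retry exactly once, and only on infrastructure failures."""
    try:
        return submit_once(request)
    except PoolInfraError as exc:
        # Keep the headline; the exception name is unbound after this block.
        text = str(exc)
        first_headline = text.splitlines()[0] if text else type(exc).__name__
        log(f"[POOL-RETRY] infrastructure failure, retrying once: {first_headline}")
    try:
        return submit_once(request)
    except PoolInfraError:
        log(f"[POOL-RETRY-FAIL] first attempt failed with: {first_headline}")
        raise  # the second attempt's error becomes the test failure
```

Routing the log lines through stderr is what would let pytest's per-test capture carry them into the JUnit report, as the message describes.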
…race The pool worker subprocess can die mid-case due to async NCCL aborts or flaky 4-GPU collective state that doesn't reproduce on a fresh pool. Without retry, these manifest as one-off CI failures attributable to infrastructure, not the PR's content. Add a single-attempt retry around PoolWorker.submit() that fires only on infrastructure failure modes (pool-worker-died, timeout, broken-pipe-pre-send). Test-assertion failures from the worker (resp["error"]) carry full per-rank tracebacks and propagate without retry — so a real bug still surfaces as FAILED. Visibility: every retry attempt writes a [POOL-RETRY] line to stderr. pytest captures per-test stderr and writes it into JUnit <testcase>/<system-err>. A flaky test will appear as PASSED in the case row but with a [POOL-RETRY] line in <system-err> — visible to the reviewer, and queryable by CI dashboards looking for flake patterns (e.g. "same test_id retries across multiple CI runs"). If both attempts die, a [POOL-RETRY-FAIL] line is also logged with the first error's headline, then the second attempt's full traceback propagates as the test failure. Smoke-tested: 3 representative cases (p2p, a2a flash; p2p fused) still PASS in 19 s. Ported from PR NVIDIA#2993 (dc565ff). Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Replaces the [CP_POOL_RESP] sentinel-prefix protocol with a stronger fix at the source: on rank>0, close stdout at the fd level via dup2 to /dev/null at worker startup. Catches both Python `print` writes and C-level (NCCL, libc, etc.) writes that the sentinel could only mitigate by scanning + skipping non-protocol lines. With non-rank-0 stdout silenced, rank 0's JSON line is the only thing that reaches the parent's pipe, so PoolWorker._submit_once collapses from a sentinel-scanning while loop to a single select + readline + json.loads. Closes follow-up M2 from the PR description; addresses greptile's review comment on stdout pollution. Validated on 8xH100 with the test_essential=True flash-attn pool path (9 passed / 55 skipped / 0 failed in 56s; no JSONDecodeError, no protocol corruption). Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Design
Problem
8×H100, test_essential=True, 38 runnable CP attention configs: torchrun (baseline measured in [PyTorch] Batch CP attention tests in single torchrun to amortize NCC… #2965).
We need that overhead amortised, without changing how tests are written or how skips report. #2965 takes one approach (dry-run + per-batch torchrun). This PR proposes a simpler alternative.
Approach
One long-lived torchrun per world_size, fed by JSON-over-stdio. No dry-run, no batch chunking, no two-pass dispatch.
A session-scoped fixture spawns at most two workers (one for world_size=2, one for world_size=4, both lazily). Each test body calls pool.submit(kwargs), which writes one JSON request line to rank-0's stdin and reads one JSON response line from rank-0's stdout. NCCL initialises once per pool; every subsequent case reuses the same process group.
How the pool works
Worker (run_attention_with_cp_pool.py) is launched once per world_size and loops:
run_dpa_with_cp checks NVTE_CP_POOL_PG and skips its own dist.init_process_group call; on teardown it destroys only the per-case CP sub-groups, leaving the main PG intact for the next case.
Pytest side (PoolWorker) is a thin wrapper around subprocess.Popen of python -m torch.distributed.run --standalone. submit() writes one request line, selects on stdout with a configurable timeout, reads one sentinel-prefixed JSON response line. On timeout or pipe error it terminates the pool; the next submit() lazily respawns it.
A daemon thread drains the worker's stderr into a bounded ring buffer (200 lines / ~40 KB) and echoes each line live to sys.stderr. On a crash/timeout, the last 4 KB of that buffer is attached to the AssertionError raised on the failing test, so CI JUnit XML carries the actual cause (NCCL error, Python traceback, OOM) inline rather than just "pool worker died". Equivalent in spirit to PR #2965's run_distributed() stderr capture.
Bug found while validating: stream race in AttnFuncWithCPAndKVAllGather.forward
A latent TE bug surfaced once the persistent pool ran the full test_essential=False matrix. In the all-gather CP forward, max_logit_per_step[i] is written inside `with torch.cuda.stream(flash_attn_streams[i]):`, and for i=1 that's cp_stream, not the default stream. Later, at loop iteration i=2, the code reads max_logit_per_step[1] via torch.maximum(...) on the default stream without a wait_stream in between. The post-loop `current_stream().wait_stream(cp_stream)` is too late: the race has already fired.
The race is latent: outcome depends on stream scheduling. In a fresh-process subprocess (one-torchrun-per-test path) streams are cleanly initialised and timing happens to put the write before the read. In a long-running persistent-worker process, which is exactly what the pool exposes, prior workloads leave stream state in a different shape and the read can fire before the write completes. The result was max_logit stale on 3 of 12 head entries (~0.3 abs diff), surfacing in 5 specific configs (all cp_comm_type=all_gather, cp_1_0/cp_1_1, bshd or fp16).
Fix is one line in context_parallel.py: `current_stream().wait_stream(flash_attn_streams[i-1])` before the torch.maximum read. No-op when streams are identical (i=1), only fires when reading from cp_stream (i=2). Independently useful for anyone running CP attention in a long-lived process, not just the pool.
Performance
Measured back-to-back on the same 8×H100 box.
test_essential=True (38 runnable configs: 34 × 2-GPU + 4 × 4-GPU)
torchrun spawns; B=16; world_size
test_essential=False (full matrix, 50 976 collected, 348 runnable)
Pool is ~7 % faster on the full matrix. Note the runnable count rises from 328 to 348 once the race fix lands; the 20 extra cases were silently dropped earlier by the same numerical-corruption path.
Knobs
NVTE_CP_POOL_TIMEOUT_SEC=N: … test_essential=True) is ~15 s; 90 s gives ~6× headroom. Override for slower machines or heavier matrices.
NVTE_CP_POOL_TIMING=1: … [POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B on rank-0 stderr per case. Off by default. Used to recalibrate the timeout against a new matrix.
NVTE_CP_POOL_PG=1: … run_dpa_with_cp; tells that function not to call dist.init_process_group and to leave the main PG alone on teardown. Not for end-user use.

There is intentionally no batch-size knob: there's no concept of a batch to size.
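A minimal sketch of how the two user-facing knobs could be read and how a timing line could be formatted. This is illustrative only: read_pool_knobs and format_timing_line are hypothetical names, falling back to the default on an invalid or non-positive value is an assumption, and rendering ok=B as True/False is a guess at the exact format.

```python
import os

DEFAULT_TIMEOUT_SEC = 90.0  # default quoted in this description


def read_pool_knobs(env=None):
    """Parse the pool knobs from an environment mapping (hypothetical helper)."""
    env = os.environ if env is None else env
    raw = env.get("NVTE_CP_POOL_TIMEOUT_SEC", "")
    try:
        timeout = float(raw) if raw else DEFAULT_TIMEOUT_SEC
    except ValueError:
        timeout = DEFAULT_TIMEOUT_SEC  # assumption: ignore unparsable values
    if timeout <= 0:
        timeout = DEFAULT_TIMEOUT_SEC  # assumption: ignore non-positive values
    return {
        "timeout_sec": timeout,
        "timing": env.get("NVTE_CP_POOL_TIMING", "0") == "1",
    }


def format_timing_line(case_idx, world_size, wall_s, ok):
    """Build one per-case timing line in the [POOL-TIMING] shape shown above."""
    return (
        f"[POOL-TIMING] case_idx={case_idx} world_size={world_size} "
        f"wall_s={wall_s:.3f} ok={ok}"
    )


knobs = read_pool_knobs({"NVTE_CP_POOL_TIMEOUT_SEC": "120", "NVTE_CP_POOL_TIMING": "1"})
line = format_timing_line(3, 2, 4.7712, True)
```

Passing the environment as a plain mapping keeps the parsing testable without touching os.environ.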
Adding a pooled test
1. @pytest.mark.parametrize stack + inline pytest.skip(...) checks.
2. Add cp_pool to the function signature.
3. pool = cp_pool(num_gpus); _submit(pool, **kwargs) where kwargs becomes run_dpa_with_cp(**kwargs).

That's it. No two-pass logic, no fixture stubs.
Failure semantics
pytest.skip(...) fires: … pool.submit, no pool work.
@pytest.mark.skip(if) marker fires: …
… dist.gather_object). Other ranks' tracebacks visible in subprocess stderr. No retry: real test failures propagate.
… [POOL-RETRY] line in <system-err>. If the retry also dies, FAIL with "pool worker (world_size=N) timed out after 90s; ..." plus the last 4 KB of the worker's stderr and a [POOL-RETRY-FAIL] line summarising both attempts.
… "pool worker died mid-request".
… "pool worker died before request could be sent".

Cross-rank failure detail is strictly better than
all_reduce(ok, op=MIN) (which only tells you some rank failed): gather_object brings back each rank's (ok, traceback) tuple so the reported error is the actual non-zero-rank stack trace, not "see subprocess stderr."

What happens when a pool worker stalls
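As a rough illustration of the stall handling covered in this section, the sketch below classifies one wait on a worker's stdout pipe into the three terminal states, then applies the SIGTERM-then-SIGKILL teardown. It assumes a POSIX system (select on pipes) and uses trivial child processes in place of a torchrun worker; read_response, kill, and spawn are hypothetical names, not the PoolWorker API.

```python
import os
import select
import subprocess
import sys
import time


def read_response(proc, timeout_sec):
    """Wait for one line on proc.stdout and classify the outcome.

    Returns ("response", line), ("timeout", None) or ("died", None).
    """
    fd = proc.stdout.fileno()
    deadline = time.monotonic() + timeout_sec
    buf = b""
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return "timeout", None
        ready, _, _ = select.select([fd], [], [], remaining)
        if not ready:
            return "timeout", None
        chunk = os.read(fd, 4096)
        if not chunk:  # EOF: the worker closed stdout, i.e. it exited
            return "died", None
        buf += chunk
        if b"\n" in buf:
            line, _, _ = buf.partition(b"\n")
            return "response", line.decode()


def kill(proc, grace_sec=5.0):
    """SIGTERM, wait up to grace_sec, then SIGKILL."""
    proc.terminate()
    try:
        proc.wait(timeout=grace_sec)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
    proc.stdout.close()


def spawn(code):
    """Start a child Python process standing in for a pool worker."""
    return subprocess.Popen([sys.executable, "-c", code], stdout=subprocess.PIPE)


outcomes = []
for code in ('print("{}")', "import time; time.sleep(30)", "raise SystemExit(1)"):
    worker = spawn(code)
    state, _ = read_response(worker, timeout_sec=3.0)
    outcomes.append(state)
    kill(worker)
```

The three children respectively answer, hang, and exit without answering, so each terminal state is exercised once.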
Three terminal states for any
submit(): response arrives → normal handling; no response after POOL_SUBMIT_TIMEOUT_SEC → timeout path; process died → mid-request-death path. Any stall (application hang, NCCL deadlock, GPU wedge, even a stdout-pipe-full self-deadlock) eventually resolves to one of the latter two. The pool is SIGTERM'd (5 s grace) then SIGKILL'd, and submit() lazily respawns a fresh pool of the same world_size (~6–9 s NCCL re-init) for the retry attempt. The other pool (different world_size) is unaffected. Blast radius: one respawn per stall; the test only FAILs if the retry also dies.

Single-retry on pool-infrastructure failures
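The retry policy this section describes can be sketched as below. This is a simplified model rather than the PoolWorker code: submit_with_retry and FlakyPool are hypothetical, every failure is modelled as an AssertionError, and the only parts taken from this description are the three marker strings and the [POOL-RETRY] / [POOL-RETRY-FAIL] tags.

```python
import io
import sys

# Substrings quoted in this description as marking pool-infrastructure failures.
INFRA_MARKERS = ("pool worker died", "timed out", "before request could be sent")


def is_infra_failure(message):
    return any(marker in message for marker in INFRA_MARKERS)


def submit_with_retry(submit_once, request, log=None):
    """Call submit_once(request); retry once, and only on infrastructure failures."""
    log = sys.stderr if log is None else log
    try:
        return submit_once(request)
    except AssertionError as exc:
        if not is_infra_failure(str(exc)):
            raise  # real test failure: propagate without retrying
        first_error = str(exc)
    print(f"[POOL-RETRY] retrying after: {first_error}", file=log)
    try:
        return submit_once(request)
    except AssertionError:
        print(f"[POOL-RETRY-FAIL] first attempt: {first_error}", file=log)
        raise


class FlakyPool:
    """Hypothetical stand-in that fails its first `failures` calls."""

    def __init__(self, failures, message):
        self.failures = failures
        self.message = message
        self.calls = 0

    def submit_once(self, request):
        self.calls += 1
        if self.calls <= self.failures:
            raise AssertionError(self.message)
        return {"ok": True}


log = io.StringIO()
flaky = FlakyPool(1, "pool worker died mid-request")
result = submit_with_retry(flaky.submit_once, {"case": 0}, log=log)
```

A case that passes on the second attempt still leaves its [POOL-RETRY] line in the log, which is what makes flaky infrastructure visible even when the test row is green.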
The pool worker subprocess can die mid-case due to async NCCL aborts or flaky 4-GPU collective state that doesn't reproduce on a fresh pool — we observed this once per ~350-case full-matrix run on
cp_comm_type=a2a+p2p (4-GPU pool, fused-attention). Each submit() retries once if the failure mode matches "pool worker died", "timed out", or "before request could be sent". Test-assertion failures (the worker returned {"ok": false, "error": ...}) do not retry: those are real bugs and propagate.

Every retry leaves a
[POOL-RETRY] line in stderr; if both attempts die, a [POOL-RETRY-FAIL] line summarising the first attempt's error is also written. pytest captures per-test stderr and writes it into the JUnit <testcase>/<system-err> field, so a test that PASSED on retry is visible to anyone querying CI artifacts for POOL-RETRY: a 50 %-flaky test would show up as ~50 % of its CI runs having a [POOL-RETRY] line in <system-err>, even though the case row is green.

Mitigations for shared-process state
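One of the mitigations in this section, the deep copy of the model config, is easy to demonstrate in isolation. The sketch below is a toy model: plain dicts stand in for the real config objects, run_case is a hypothetical helper, and the exact "padding_" rewrite is an assumption about what the THD branch does to attn_mask_type.

```python
import copy


def run_case(table, model, qkv_format, isolate):
    """Mimic one case: the THD branch rewrites attn_mask_type in place."""
    # With isolate=True the case works on a private copy of the config.
    config = copy.deepcopy(table[model]) if isolate else table[model]
    if qkv_format == "thd":
        # Assumed rewrite, for illustration only.
        config["attn_mask_type"] = "padding_" + config["attn_mask_type"]
    return config["attn_mask_type"]


# Without deepcopy: the THD case mutates the shared table, and the next case sees it.
shared_table = {"cp_1_0": {"attn_mask_type": "causal"}}
leaked = [run_case(shared_table, "cp_1_0", fmt, isolate=False) for fmt in ("thd", "bshd")]

# With deepcopy: each case starts from the original config.
clean_table = {"cp_1_0": {"attn_mask_type": "causal"}}
clean = [run_case(clean_table, "cp_1_0", fmt, isolate=True) for fmt in ("thd", "bshd")]
```

In a one-process-per-case runner the leak is invisible because the table is rebuilt on every launch; it only matters once cases share a process.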
All cases share one Python process and one NCCL world per
world_size, so anything that needs a clean per-test starting point is reset before each case (in _run_one, not in a finally block, so the first case is also clean):

- torch.manual_seed(1234) + torch.cuda.manual_seed(1234): RNG reseeded so input tensors are reproducible per case.
- FP8GlobalStateManager.reset(): drops FP8 amax history etc. that would otherwise leak across cases.
- torch.cuda.empty_cache().
- copy.deepcopy(model_configs_*[model]) inside run_dpa_with_cp: the THD branch rewrites attn_mask_type in place; without deepcopy the change leaks into the module-level dict.
- run_dpa_with_cp itself re-sets NVTE_FUSED_ATTN / NVTE_FLASH_ATTN unconditionally at the top of every call and pops the FP8-related transient env vars, so no explicit env-key reset is needed in the pool worker.

The single-shot
run_dpa_with_cp does some of these inherently (it's a fresh process). For the pooled path we replicate them explicitly so the two execution modes produce identical per-case state.

Edge cases
1. torch.distributed.run --standalone picks a free rendezvous port at bind time. No MASTER_PORT plumbing needed; parallel pytest sessions (e.g., L1_pytorch_distributed_unittest's det vs non-det concurrent runs on disjoint GPU sets) cannot collide.
2. _kill() terminates the pool; the retry attempt lazily respawns a fresh worker. Blast radius: one respawn (~6–9 s) plus, if the retry also dies, one FAIL.
3. BrokenPipeError / empty read → AssertionError. The wrapper retries once on a fresh pool; the failing test only surfaces as FAILED if the retry also dies.
4. torch.cuda.set_device(rank % device_count) means both pools claim GPUs 0–N starting at 0. They never run CUDA concurrently (pytest serialises tests), so the idle pool only holds ~1 GB of CUDA context per shared GPU, well within H100's 80 GB. NCCL worlds are independent. No collision. This is intentional: setting an explicit CUDA_VISIBLE_DEVICES per pool would prevent the parallel-pytest-session pattern used by qa/L1_pytorch_distributed_unittest/test.sh (det / non-det on disjoint GPU sets); see Follow-ups.
5. NVTE_CP_POOL_PG collision. The env var is set by the pool worker after dist.init_process_group and only read by run_dpa_with_cp (no other consumer). If end-user code accidentally sets it without an init'd PG, run_dpa_with_cp will fail when it tries to use the (non-existent) PG. Harmless; same failure as setting any other internal flag incorrectly.

Validation
[results table: rows test_essential=True and test_essential=False; values not recoverable]

Per-case wall-time distribution on H100 (
test_essential=True, with NVTE_CP_POOL_TIMING=1): min 1.87 s, p50 4.77 s, p95 12.43 s, max 15.39 s. Drove the POOL_SUBMIT_TIMEOUT_SEC default of 90 s (~6× max).

Deterministic mode
qa/L1_pytorch_distributed_unittest/test.sh on this branch runs test_attention_with_cp.py twice (once with the default NVTE_ALLOW_NONDETERMINISTIC_ALGO=1, once with =0) in parallel on disjoint GPU sets when ≥8 GPUs are available (sequentially otherwise). Both modes pass cleanly.

An inline skip at
test_attention_with_cp.py:605-609 handles a known cuDNN sm90 issue:

[…]

Without this skip, 5 fused-attention THD configs OOM under
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 with "Tried to allocate 512.01 GiB" inside tex.fused_attn_bwd (a workspace-size pathology in cuDNN's deterministic backward path on Hopper, not actual memory exhaustion). The skip is independent of the pool design but lives in the same test file; it is documented here so deterministic-mode skips don't get mistaken for batching-infra bugs.

Follow-ups (not in this PR)
Items intentionally deferred to keep this PR scoped:
- Remove the try/except Exception: pass around dist.destroy_process_group calls. Real semantic change: would surface errors that the swallow currently hides. Needs its own validation.
- Replace the NVTE_CP_POOL_PG env-var contract between the pool worker and run_dpa_with_cp with a module-level flag (run_attention_with_cp._POOL_MANAGED_PG = True). Cleaner, type-checkable; same effect.
- … /dev/null at worker startup instead of sentinel-prefix scanning. Closes the stdout-pollution class at the source rather than papering over it.
- Setting an explicit CUDA_VISIBLE_DEVICES per pool (so 2-GPU and 4-GPU pools claim disjoint GPUs) would break the parallel-pytest-session pattern used by qa/L1_pytorch_distributed_unittest/test.sh. Each top-level pytest session sets CUDA_VISIBLE_DEVICES itself (e.g. det on GPUs 0-3, non-det on 4-7); the pool inherits that and uses rank % device_count to map within the session's slice. Adding per-pool device pinning would override the session's slice and produce overlap across sessions. The current "overlap within a session, idle context only" behaviour (edge case 4) is the right trade-off.

Files
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: wait_stream fix in AttnFuncWithCPAndKVAllGather.forward (the latent stream race the pool exposed).
- tests/pytorch/attention/test_attention_with_cp.py: cp_pool session fixture + PoolWorker (lazy spawn, --standalone, 90 s timeout, kill-and-respawn, stderr drainer thread + tail attached to crash AssertionErrors, single retry on pool-infrastructure failures with [POOL-RETRY] stderr tag).
- tests/pytorch/attention/run_attention_with_cp_pool.py: the pool worker (init NCCL once, dispatch loop, _reset_between_cases (seed/FP8/cache), gather_object for cross-rank failure detail, optional NVTE_CP_POOL_TIMING=1 per-case timing log).
- tests/pytorch/attention/run_attention_with_cp.py: run_dpa_with_cp honours NVTE_CP_POOL_PG (skips its own PG init/destroy), deep-copies model configs.
- qa/L1_pytorch_distributed_unittest/test.sh: runs test_attention_with_cp.py twice (non-det + det) in parallel on disjoint GPU sets when ≥8 GPUs available; sequential fallback otherwise.

Comparison to #2965
Both PRs solve the same problem; this one is structurally smaller and has fewer concepts.
[comparison table, #2965 → this PR; only partially recoverable, surviving cells in order]
… origin/main …
CP_TEST_BATCH_SIZE, CP_TEST_BATCH_RETRY → NVTE_CP_POOL_TIMEOUT_SEC (default 90 s), NVTE_CP_POOL_TIMING
_COLLECT_MODE, _DummyRequest, _item_static_skip, _BACKEND_CACHE, batch chunking, atomic JSON flush, singleton retry → …
all_reduce(ok, MIN), boolean only → gather_object, full traceback per rank
… run_distributed() attaches last 4 KB) → …
MASTER_PORT env per parallel pytest session → --standalone, automatic
… test_essential=True)

Type of change
Checklist
test_essential=False: 348 / 0, both det and non-det)