CP Tests batching using subprocess worker pool #2993
Open
sudhakarsingh27 wants to merge 27 commits into NVIDIA:main from sudhakarsingh27:sudhakars/cp_batching_pool
+576
−70
27 commits
ccf7321 Batch CP attention tests via a persistent NCCL pool (sudhakarsingh27)
59609ac Reset FP8 state and barrier between pool cases (sudhakarsingh27)
73e8cef Deep-copy ModelConfig in run_dpa_with_cp (sudhakarsingh27)
311137c Skip deterministic configs incompatible with FusedAttention (sudhakarsingh27)
49878d6 Reseed RNG between pool cases; reset before, not after (sudhakarsingh27)
385e966 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
86b334b Robustify pool: capture worker stderr, tighten timeout, add timing knob (sudhakarsingh27)
ae5298c [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
e162a9e Address PR review: NCCL leak, stdout protocol, Windows note (sudhakarsingh27)
169be82 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
557bd80 [PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward (sudhakarsingh27)
4815883 Address PR review (R2): drop dead code in pool worker and PoolWorker (sudhakarsingh27)
d15bfce Address PR review (items 2+3): reuse CP groups across pool cases (sudhakarsingh27)
dd1d802 Merge branch 'main' of github.com:NVIDIA/TransformerEngine into sudha… (sudhakarsingh27)
87c67ac Flatten try/finally wrap in run_dpa_with_cp (sudhakarsingh27)
adb84af [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
70ad33d Merge branch 'main' of github.com:NVIDIA/TransformerEngine into sudha… (sudhakarsingh27)
b590cd8 Merge branch 'sudhakars/cp_batching_pool' of https://github.com/sudha… (sudhakarsingh27)
a018a53 Set test_essential=True to match shipping default (sudhakarsingh27)
dc565ff Retry once on pool-infrastructure failures with stderr-logged flake t… (sudhakarsingh27)
5f37995 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
36b65fb [PyTorch] Pool: redirect non-rank-0 stdout to /dev/null; drop sentinel (sudhakarsingh27)
0313809 Merge branch 'main' into sudhakars/cp_batching_pool (sudhakarsingh27)
06ec3db Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in… (sudhakarsingh27)
9ba49be Address PR review (R3): backend-cache, pool isolation, group-kill, de… (sudhakarsingh27)
75e1479 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
3bfe03a Merge branch 'main' into sudhakars/cp_batching_pool (sudhakarsingh27)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,221 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
| | ||
| """ | ||
| Persistent worker for batched CP attention tests. | ||
| | ||
| Launched ONCE per (pytest session, world_size) by torchrun. All ranks init | ||
| NCCL, then enter a dispatch loop: | ||
| | ||
|     rank 0: | ||
|         read one JSON request line from stdin | ||
|         broadcast it to all ranks | ||
|     all ranks: | ||
|         call run_dpa_with_cp(**kwargs), the same work function the | ||
|         per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the | ||
|         function reuses our PG instead of re-initing it | ||
|         torch.cuda.empty_cache() per case | ||
|         all ranks gather (ok, error_msg) to rank 0 | ||
|     rank 0: | ||
|         write one JSON response line to stdout | ||
| | ||
| Protocol (line-delimited JSON over rank-0 stdio): | ||
|     request : {"op": "run", "kwargs": {...}} | ||
|               {"op": "shutdown"} | ||
|     response: {"ok": true} | ||
|               {"ok": false, "error": "first failing rank's traceback"} | ||
| """ | ||
| import json | ||
| import os | ||
| import sys | ||
| import time | ||
| import traceback | ||
| | ||
| import torch | ||
| import torch.distributed as dist | ||
| | ||
| # Make sibling modules importable when launched directly. | ||
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
| | ||
| from run_attention_with_cp import run_dpa_with_cp | ||
| from transformer_engine.pytorch.quantization import FP8GlobalStateManager | ||
| | ||
| | ||
| def _recv_request(rank: int) -> dict: | ||
|     box = [None] | ||
|     if rank == 0: | ||
|         line = sys.stdin.readline() | ||
|         box[0] = {"op": "shutdown"} if not line else json.loads(line) | ||
|     dist.broadcast_object_list(box, src=0) | ||
|     return box[0] | ||
| | ||
| | ||
| def _send_response(rank: int, payload: dict) -> None: | ||
|     if rank == 0: | ||
|         sys.stdout.write(json.dumps(payload) + "\n") | ||
|         sys.stdout.flush() | ||
| | ||
| | ||
| def _silence_non_rank0_stdout(rank: int) -> None: | ||
|     """Redirect non-rank-0 stdout to /dev/null at fd level. | ||
| | ||
|     All ranks share rank 0's stdout fd (torchrun inherits it from the launcher), | ||
|     so Python/library writes on rank>0 would interleave with rank 0's JSON | ||
|     protocol on the parent's pipe. Redirecting fd 1 at the OS level on rank>0 | ||
|     catches both Python (``print``) and C-level (NCCL, etc.) writes. | ||
|     """ | ||
|     if rank == 0: | ||
|         return | ||
|     devnull = os.open(os.devnull, os.O_WRONLY) | ||
|     os.dup2(devnull, 1) | ||
|     os.close(devnull) | ||
|     sys.stdout = open(1, "w", closefd=False) | ||
| | ||
| | ||
| def _reset_between_cases() -> None: | ||
|     """Drop state that would otherwise cascade across cases. | ||
| | ||
|     Matches the per-case startup of the single-shot worker | ||
|     (``_run_single_config`` on the per-case-subprocess branch): identical RNG | ||
|     seed at the start of every case, FP8 state cleared, allocator clean. | ||
|     ``run_dpa_with_cp`` re-sets ``NVTE_FUSED_ATTN``/``NVTE_FLASH_ATTN`` | ||
|     unconditionally and pops the other transient env vars itself, so no | ||
|     explicit pop is needed here. | ||
|     """ | ||
|     torch.manual_seed(1234) | ||
|     torch.cuda.manual_seed(1234) | ||
|     FP8GlobalStateManager.reset() | ||
|     torch.cuda.empty_cache() | ||
|     # Invalidate DPA's module-level backend cache so the per-case | ||
|     # NVTE_FLASH_ATTN/NVTE_FUSED_ATTN env-var toggle actually takes effect | ||
|     # instead of reusing the previous case's resolved backend. | ||
|     try: | ||
|         from transformer_engine.pytorch.attention.dot_product_attention import dot_product_attention | ||
| | ||
|         dot_product_attention._attention_backends["backend_selection_requires_update"] = True | ||
|     except (ImportError, AttributeError, KeyError): | ||
|         pass | ||
| | ||
| | ||
| _case_counter = 0 | ||
| | ||
| | ||
| def _run_one(req: dict, rank: int) -> tuple[bool, str]: | ||
|     global _case_counter | ||
|     op = req["op"] | ||
|     if op != "run": | ||
|         return False, f"unknown op: {op}" | ||
|     # Reset BEFORE the case so the first case also starts from a known RNG seed | ||
|     # and clean FP8 state, same as the single-shot worker's per-process startup. | ||
|     _reset_between_cases() | ||
|     t0 = time.monotonic() | ||
|     ok = True | ||
|     err = "" | ||
|     try: | ||
|         run_dpa_with_cp(**req.get("kwargs", {})) | ||
|     except Exception: | ||
|         ok = False | ||
|         err = f"[Rank {rank}] {traceback.format_exc()}" | ||
|     wall = time.monotonic() - t0 | ||
|     # Per-case wall time on rank 0, opt-in via NVTE_CP_POOL_TIMING=1. | ||
|     # Used to tune POOL_SUBMIT_TIMEOUT_SEC against the observed distribution. | ||
|     if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")): | ||
|         _case_counter += 1 | ||
|         sys.stderr.write( | ||
|             f"[POOL-TIMING] case_idx={_case_counter} " | ||
|             f"world_size={int(os.environ.get('WORLD_SIZE', 0))} " | ||
|             f"wall_s={wall:.3f} ok={ok}\n" | ||
|         ) | ||
|         sys.stderr.flush() | ||
|     return ok, err | ||
| | ||
| | ||
| def _create_cp_comm_groups(rank: int, world_size: int) -> tuple: | ||
|     """Pre-create the CP collective groups for this pool. | ||
| | ||
|     world_size and the rank set are constant for the lifetime of one pool, so | ||
|     the world group and the a2a+p2p sub-groups are deterministic. Creating | ||
|     them once here and reusing them across every case eliminates ~50-100 ms | ||
|     of NCCL setup per case (cyanguwa's review feedback on PR #2993). | ||
| | ||
|     Returns ``(world_group, a2a_p2p_sub_groups)``. ``a2a_p2p_sub_groups`` is | ||
|     empty when world_size is too small to support a2a+p2p (needs an even | ||
|     world_size ≥ 4); cases with cp_comm_type='a2a+p2p' wouldn't be routed to | ||
|     such a pool anyway. | ||
|     """ | ||
|     world_group = dist.new_group(range(world_size), backend="nccl") | ||
|     sub_groups: list = [] | ||
|     if world_size >= 4 and world_size % 2 == 0: | ||
|         # Mirror the layout in run_attention_with_cp.py: cp_size/2 pairs along | ||
|         # axis 0, plus 2 stride-2 groups along axis 1. | ||
|         cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] | ||
|         cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] | ||
|         for sub_ranks in cp_comm_sub_ranks: | ||
|             sub_group = dist.new_group(sub_ranks, backend="nccl") | ||
|             if rank in sub_ranks: | ||
|                 sub_groups.append(sub_group) | ||
|     return world_group, sub_groups | ||
| | ||
| | ||
| def main() -> None: | ||
|     rank = int(os.environ["RANK"]) | ||
|     world_size = int(os.environ["WORLD_SIZE"]) | ||
|     _silence_non_rank0_stdout(rank) | ||
|     torch.cuda.set_device(rank % torch.cuda.device_count()) | ||
|     dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) | ||
|     os.environ["NVTE_CP_POOL_PG"] = "1" | ||
| | ||
|     # Stash pool-shared CP groups on the run_attention_with_cp module so | ||
|     # run_dpa_with_cp can read them per case. Imported here (after the env var | ||
|     # is set) to keep import-time side effects minimal. | ||
|     import run_attention_with_cp as _rac | ||
| | ||
|     _rac._pool_cp_comm_group, _rac._pool_cp_comm_sub_groups = _create_cp_comm_groups( | ||
|         rank, world_size | ||
|     ) | ||
| | ||
|     try: | ||
|         while True: | ||
|             req = _recv_request(rank) | ||
|             if req.get("op") == "shutdown": | ||
|                 break | ||
| | ||
|             ok, msg = _run_one(req, rank) | ||
| | ||
|             gathered: list[tuple[bool, str]] = [None] * world_size  # type: ignore[list-item] | ||
|             # gather_object is itself a collective synchronization point: if | ||
|             # every rank reached it, none is ahead. No extra barrier needed. | ||
|             dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0) | ||
| | ||
|             if rank == 0: | ||
|                 all_ok = all(o for o, _ in gathered) | ||
|                 if all_ok: | ||
|                     _send_response(rank, {"ok": True}) | ||
|                 else: | ||
|                     first_err = next(m for o, m in gathered if not o) | ||
|                     _send_response(rank, {"ok": False, "error": first_err}) | ||
|             # Release the allocator cache so this pool doesn't squat on | ||
|             # GPUs that an overlapping different-world-size pool needs. | ||
|             torch.cuda.empty_cache() | ||
|     finally: | ||
|         # Tear down pool-shared CP groups before the main PG (NCCL requires | ||
|         # sub-groups to be destroyed first). Each destroy is independently | ||
|         # guarded so a wedged communicator on one group doesn't leak the rest. | ||
|         if _rac._pool_cp_comm_group is not None: | ||
|             try: | ||
|                 dist.destroy_process_group(_rac._pool_cp_comm_group) | ||
|             except Exception: | ||
|                 pass | ||
|         for g in _rac._pool_cp_comm_sub_groups: | ||
|             try: | ||
|                 dist.destroy_process_group(g) | ||
|             except Exception: | ||
|                 pass | ||
|         _rac._pool_cp_comm_group = None | ||
|         _rac._pool_cp_comm_sub_groups = [] | ||
|         dist.destroy_process_group() | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
|     main() | ||
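The sub-group layout built in `_create_cp_comm_groups` can be inspected without NCCL by computing only the rank lists. The sketch below is illustrative: `cp_sub_group_ranks` and `groups_for_rank` are hypothetical helper names that do not appear in the PR, and they return plain lists where the worker calls `dist.new_group`.

```python
def cp_sub_group_ranks(world_size):
    """Rank lists for the a2a+p2p sub-groups, mirroring the comprehension in the worker."""
    # The worker only builds sub-groups for an even world_size of at least 4.
    if world_size < 4 or world_size % 2 != 0:
        return []
    # world_size // 2 groups of adjacent pairs ...
    pairs = [list(range(i * 2, (i + 1) * 2)) for i in range(world_size // 2)]
    # ... plus 2 stride-2 groups (even ranks, odd ranks).
    strided = [list(range(i, world_size, 2)) for i in range(2)]
    return pairs + strided


def groups_for_rank(rank, world_size):
    """Sub-groups a given rank would keep (the `if rank in sub_ranks` filter)."""
    return [g for g in cp_sub_group_ranks(world_size) if rank in g]


layout = cp_sub_group_ranks(8)
# [[0, 1], [2, 3], [4, 5], [6, 7], [0, 2, 4, 6], [1, 3, 5, 7]]
mine = groups_for_rank(3, 8)
# [[2, 3], [1, 3, 5, 7]]
```

Every rank therefore ends up holding exactly two sub-groups: its adjacent pair and its stride-2 group. Note that in the worker every rank still calls `dist.new_group` for every rank list, including the ones it does not belong to.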
`torchrun` (and the worker processes it spawns for ranks 1–N) all inherit the same stdout file descriptor as rank 0. If `torchrun` writes any status line to stdout, or if any non-rank-0 worker accidentally prints (e.g. via a `print` call in a library, NCCL debug output, or a Python warning), those bytes are interleaved with rank 0's JSON responses. The parent's `readline()` in `PoolWorker.submit` would then receive a non-JSON line and raise a `json.JSONDecodeError`, killing the pool and failing the test with a misleading error message.

Consider redirecting `torchrun`'s own output or adding a sentinel prefix to every response line so the reader can skip unrecognised lines.
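The sentinel-prefix option mentioned above could look roughly like this. It is a minimal sketch, not the PR's code: the marker string and the names `RESP_PREFIX`, `encode_response`, and `first_response` are assumptions chosen for illustration.

```python
import json

# Hypothetical marker; any string unlikely to appear in stray log output would do.
RESP_PREFIX = "@@POOL-RESP@@ "


def encode_response(payload):
    """Writer side: tag the JSON line so the reader can recognise it."""
    return RESP_PREFIX + json.dumps(payload) + "\n"


def first_response(lines):
    """Reader side: skip lines without the marker, parse the first tagged one."""
    for line in lines:
        if line.startswith(RESP_PREFIX):
            return json.loads(line[len(RESP_PREFIX):])
    raise ValueError("no tagged response line found")


stream = [
    "NCCL INFO some debug output\n",        # stray write from another rank
    "UserWarning: something deprecated\n",  # stray Python warning
    encode_response({"ok": True}),
]
parsed = first_response(stream)
```

Skipping only helps when stray output arrives as whole lines. Bytes written into the middle of a tagged line would still corrupt it, which is one reason to prefer silencing the other writers at the file-descriptor level.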
Implemented the stronger fix in 36b65fb3: redirect non-rank-0 stdout to `/dev/null` at the fd level via `dup2` at worker startup. Catches both Python `print` and C-level (NCCL, libc, etc.) writes; the sentinel-prefix approach could only mitigate the former by scanning + skipping.

With non-rank-0 stdout silenced, rank 0's JSON line is the only thing that reaches the parent's pipe, so the `_RESP_PREFIX` machinery + the sentinel-scanning while loop in `PoolWorker._submit_once` are gone. The reader collapses to one `select` + one `readline` + one `json.loads`.

Validated on 8×H100 (`test_essential=True`, flash-attn): 9 passed / 55 skipped / 0 failed in 56 s; no `JSONDecodeError`, no protocol corruption. Closes the M2 follow-up listed in the PR description.
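A reader of the "one `select` + one `readline` + one `json.loads`" shape might look like the sketch below. It is not the PR's `PoolWorker._submit_once`: the `submit_once` helper, the `CHILD` stand-in worker, and the timeout value are assumptions, and `select` on pipes is assumed to be available (POSIX).

```python
import json
import select
import subprocess
import sys

# Hypothetical stand-in worker: answers each "run" request with {"ok": true}
# and exits on "shutdown" or EOF.
CHILD = r'''
import json, sys
while True:
    line = sys.stdin.readline()
    if not line:
        break
    req = json.loads(line)
    if req.get("op") == "shutdown":
        break
    sys.stdout.write(json.dumps({"ok": True}) + "\n")
    sys.stdout.flush()
'''


def submit_once(proc, request, timeout_s=10.0):
    """Send one request line, wait for one response line, parse it."""
    proc.stdin.write(json.dumps(request) + "\n")
    proc.stdin.flush()
    # One select: bounded wait for the worker's stdout to become readable.
    ready, _, _ = select.select([proc.stdout], [], [], timeout_s)
    if not ready:
        raise TimeoutError("worker did not respond in time")
    # One readline: the response is exactly one line.
    line = proc.stdout.readline()
    if not line:
        raise RuntimeError("worker closed stdout")
    # One json.loads: no prefix stripping or line skipping needed.
    return json.loads(line)


proc = subprocess.Popen(
    [sys.executable, "-c", CHILD],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)
reply = submit_once(proc, {"op": "run", "kwargs": {}})
proc.stdin.write(json.dumps({"op": "shutdown"}) + "\n")
proc.stdin.flush()
proc.wait(timeout=10)
proc.stdin.close()
proc.stdout.close()
```

This simplification relies on the guarantee described above: once nothing else writes to the pipe, every line the parent reads is a complete JSON response.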