Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ccf7321
Batch CP attention tests via a persistent NCCL pool
sudhakarsingh27 May 12, 2026
59609ac
Reset FP8 state and barrier between pool cases
sudhakarsingh27 May 12, 2026
73e8cef
Deep-copy ModelConfig in run_dpa_with_cp
sudhakarsingh27 May 12, 2026
311137c
Skip deterministic configs incompatible with FusedAttention
sudhakarsingh27 May 14, 2026
49878d6
Reseed RNG between pool cases; reset before, not after
sudhakarsingh27 May 14, 2026
385e966
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
86b334b
Robustify pool: capture worker stderr, tighten timeout, add timing knob
sudhakarsingh27 May 14, 2026
ae5298c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
e162a9e
Address PR review: NCCL leak, stdout protocol, Windows note
sudhakarsingh27 May 14, 2026
169be82
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
557bd80
[PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward
sudhakarsingh27 May 15, 2026
4815883
Address PR review (R2): drop dead code in pool worker and PoolWorker
sudhakarsingh27 May 15, 2026
d15bfce
Address PR review (items 2+3): reuse CP groups across pool cases
sudhakarsingh27 May 15, 2026
dd1d802
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into sudha…
sudhakarsingh27 May 15, 2026
87c67ac
Flatten try/finally wrap in run_dpa_with_cp
sudhakarsingh27 May 15, 2026
adb84af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
70ad33d
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into sudha…
sudhakarsingh27 May 18, 2026
b590cd8
Merge branch 'sudhakars/cp_batching_pool' of https://github.com/sudha…
sudhakarsingh27 May 18, 2026
a018a53
Set test_essential=True to match shipping default
sudhakarsingh27 May 18, 2026
dc565ff
Retry once on pool-infrastructure failures with stderr-logged flake t…
sudhakarsingh27 May 19, 2026
5f37995
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2026
36b65fb
[PyTorch] Pool: redirect non-rank-0 stdout to /dev/null; drop sentinel
sudhakarsingh27 May 19, 2026
0313809
Merge branch 'main' into sudhakars/cp_batching_pool
sudhakarsingh27 May 19, 2026
06ec3db
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 May 20, 2026
9ba49be
Address PR review (R3): backend-cache, pool isolation, group-kill, de…
sudhakarsingh27 May 21, 2026
75e1479
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2026
3bfe03a
Merge branch 'main' into sudhakars/cp_batching_pool
sudhakarsingh27 May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 79 additions & 23 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# See LICENSE for license information.

import copy
import os
import sys
import logging
Expand Down Expand Up @@ -29,6 +30,15 @@
)
from utils import ModelConfig, compare_and_assert

# Pool mode (NVTE_CP_POOL_PG=1) only: shared CP collective groups, created once
# per pool by run_attention_with_cp_pool.main() and reused across every case in
# that pool. world_size and the rank set don't change per case, so re-creating
# these per call would be wasted NCCL setup (~50-100 ms each). Single-shot
# subprocess mode leaves these None / [] and run_dpa_with_cp creates/destroys
# its own groups inline.
_pool_cp_comm_group = None
_pool_cp_comm_sub_groups: list = []

dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}


Expand Down Expand Up @@ -209,10 +219,13 @@ def run_dpa_with_cp(
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
# Deep-copy: the module-level dict is shared across pool cases; the
# THD branch below rewrites attn_mask_type in place, which would
# otherwise leak into subsequent cases reusing the same model key.
config = copy.deepcopy(model_configs_flash_attn[model])
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
config = copy.deepcopy(model_configs_fused_attn[model])
assert config.attn_mask_type in [
"causal",
"no_mask",
Expand All @@ -226,6 +239,9 @@ def run_dpa_with_cp(
# set up distributed group
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
# When NVTE_CP_POOL_PG=1, the pool runner owns the lifecycle of the main
# process group across many cases; here we only reuse it.
_pool_managed_pg = os.getenv("NVTE_CP_POOL_PG", "0") == "1"
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
Expand All @@ -234,25 +250,35 @@ def run_dpa_with_cp(
device = rank % device_count
torch.cuda.set_device(device)
logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

# set up communication group for CP
if not _pool_managed_pg:
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

# Set up communication group for CP. In pool mode, the pool worker has
# already pre-created world-scoped and a2a+p2p sub-groups once and stashed
# them in module-level pointers; we reuse those and the pool destroys them
# at shutdown. In single-shot mode we create them per call and destroy in
# the finally below.
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
"{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size"
" = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)

_reusing_pool_groups = _pool_managed_pg and _pool_cp_comm_group is not None
cp_comm_group = None
cp_comm_sub_groups: list = []
if _reusing_pool_groups:
cp_comm_group = _pool_cp_comm_group
cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else []
else:
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
"{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has"
" cp_size = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
if scaling_mode == "delayed":
fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
Expand Down Expand Up @@ -564,7 +590,10 @@ def run_dpa_with_cp(
seq_kv_size = dbias.shape[-1]
# Reshape to split seq_q dimension
dbias = dbias.view(
*shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size
*shape_before_seq,
2 * world_size,
seq_q_size // (2 * world_size),
seq_kv_size,
)
# Index select on the newly created dimension (now at position seq_q_dim)
dbias = dbias.index_select(seq_q_dim, seq_idx)
Expand Down Expand Up @@ -754,16 +783,43 @@ def run_dpa_with_cp(
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
t,
tensors_cp[i],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
else:
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")

# destroy distribution group
dist.destroy_process_group()
# Teardown on the success path. Pool mode: cp_comm_group / cp_comm_sub_groups
# point at pool-shared groups owned by the pool runner (which destroys them
# at pool shutdown), and the main PG is also pool-owned — both branches
# below are no-ops. Single-shot mode: destroy what we created here. If the
# body above raises, we skip this — the subprocess dies at function return
# and NCCL releases the communicators with the process.
if not _reusing_pool_groups:
if cp_comm_group is not None:
try:
dist.destroy_process_group(cp_comm_group)
except Exception:
pass
for g in cp_comm_sub_groups:
try:
dist.destroy_process_group(g)
except Exception:
pass
if not _pool_managed_pg:
try:
dist.destroy_process_group()
except Exception:
pass


def main(**kwargs):
Expand Down
221 changes: 221 additions & 0 deletions tests/pytorch/attention/run_attention_with_cp_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Persistent worker for batched CP attention tests.

Launched ONCE per (pytest session, world_size) by torchrun. All ranks init
NCCL, then enter a dispatch loop:

rank 0:
read one JSON request line from stdin
broadcast it to all ranks
all ranks:
call run_dpa_with_cp(**kwargs) — the same work function the
per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the
function reuses our PG instead of re-initing it
torch.cuda.empty_cache() per case
all ranks gather (ok, error_msg) to rank 0
rank 0:
write one JSON response line to stdout

Protocol (line-delimited JSON over rank-0 stdio):
request : {"op": "run", "kwargs": {...}}
{"op": "shutdown"}
response: {"ok": true}
{"ok": false, "error": "first failing rank's traceback"}
"""
import json
import os
import sys
import time
import traceback

import torch
import torch.distributed as dist

# Make sibling modules importable when launched directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from run_attention_with_cp import run_dpa_with_cp
from transformer_engine.pytorch.quantization import FP8GlobalStateManager


def _recv_request(rank: int) -> dict:
    """Read the next request on rank 0 and broadcast it to every rank.

    Rank 0 reads one line from stdin. An empty read (EOF) becomes a
    ``{"op": "shutdown"}`` request. A line that is not valid JSON, or that
    decodes to something other than a JSON object, is replaced by a request
    carrying an unrecognised op (``"invalid"``) rather than raising: raising
    on rank 0 before the broadcast would leave the other ranks waiting in the
    collective, and a non-dict payload would break the callers' ``req.get``.
    The dispatcher is expected to reject the unrecognised op as a failed case.

    Args:
        rank: This process's rank; only rank 0 touches stdin.

    Returns:
        The request dict, identical on all ranks after the broadcast.
    """
    box = [None]
    if rank == 0:
        line = sys.stdin.readline()
        if not line:
            # EOF: the parent closed our stdin, so wind the pool down.
            box[0] = {"op": "shutdown"}
        else:
            try:
                parsed = json.loads(line)
                if not isinstance(parsed, dict):
                    raise ValueError(
                        f"request must be a JSON object, got {type(parsed).__name__}"
                    )
            except ValueError as exc:
                # json.JSONDecodeError is a ValueError subclass. Keep the
                # detail on stderr; stdout is reserved for protocol lines.
                sys.stderr.write(f"[POOL] rejecting malformed request: {exc}\n")
                sys.stderr.flush()
                parsed = {"op": "invalid", "error": str(exc)}
            box[0] = parsed
    # Collective call: every rank must reach this, including on bad input.
    dist.broadcast_object_list(box, src=0)
    return box[0]


def _send_response(rank: int, payload: dict) -> None:
    """Write ``payload`` as a single JSON line on stdout (rank 0 only).

    Other ranks return without writing anything. The stream is flushed
    after the write so the line is not held back in Python's buffer.
    """
    if rank != 0:
        return
    stream = sys.stdout
    stream.write(json.dumps(payload) + "\n")
    stream.flush()
Comment on lines +54 to +57
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 stdout pollution can silently corrupt the JSON protocol

torchrun (and the worker processes it spawns for ranks 1–N) all inherit the same stdout file descriptor as rank 0. If torchrun writes any status line to stdout, or if any non-rank-0 worker accidentally prints (e.g. via a print call in a library, NCCL debug output, or a Python warning), those bytes are interleaved with rank 0's JSON responses. The parent's readline() in PoolWorker.submit would then receive a non-JSON line and raise a json.JSONDecodeError, killing the pool and failing the test with a misleading error message.

Consider redirecting torchrun's own output or adding a sentinel prefix to every response line so the reader can skip unrecognised lines.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implemented the stronger fix in 36b65fb3: redirect non-rank-0 stdout to /dev/null at the fd level via dup2 at worker startup. Catches both Python print and C-level (NCCL, libc, etc.) writes — the sentinel-prefix approach could only mitigate the former by scanning + skipping.

With non-rank-0 stdout silenced, rank 0's JSON line is the only thing that reaches the parent's pipe, so the _RESP_PREFIX machinery + the sentinel-scanning while loop in PoolWorker._submit_once are gone. The reader collapses to one select + one readline + one json.loads.

Validated on 8×H100 (test_essential=True, flash-attn): 9 passed / 55 skipped / 0 failed in 56 s; no JSONDecodeError, no protocol corruption. Closes the M2 follow-up listed in the PR description.



def _silence_non_rank0_stdout(rank: int) -> None:
    """Point fd 1 at the null device on every rank except rank 0.

    Rank 0's stdout carries the JSON protocol, so output from the other
    ranks must not reach the same pipe. Replacing file descriptor 1 itself
    (rather than only ``sys.stdout``) also covers writes made below the
    Python layer. ``sys.stdout`` is then rebound to a new text wrapper over
    the redirected descriptor.

    Args:
        rank: This process's rank; rank 0 is left untouched.
    """
    if rank != 0:
        null_fd = os.open(os.devnull, os.O_WRONLY)
        os.dup2(null_fd, 1)
        # fd 1 now refers to the null device; the temporary fd is redundant.
        os.close(null_fd)
        # closefd=False: the wrapper must never close fd 1 when collected.
        sys.stdout = open(1, "w", closefd=False)


def _reset_between_cases() -> None:
    """Clear per-process state so one case cannot influence the next.

    Re-seeds the CPU and CUDA RNGs with a fixed seed, resets
    ``FP8GlobalStateManager``, and releases the CUDA allocator cache. It then
    marks the attention backend selection as needing an update, so that the
    per-case ``NVTE_FLASH_ATTN``/``NVTE_FUSED_ATTN`` settings are honoured
    instead of a previously resolved backend being reused.
    """
    seed = 1234
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    FP8GlobalStateManager.reset()
    torch.cuda.empty_cache()
    try:
        from transformer_engine.pytorch.attention.dot_product_attention import (
            dot_product_attention as dpa_module,
        )

        dpa_module._attention_backends["backend_selection_requires_update"] = True
    except (ImportError, AttributeError, KeyError):
        # Best effort: if the module or its cache is not laid out as expected,
        # carry on without invalidating it.
        pass


# Number of cases timed so far; only advanced on rank 0 when
# NVTE_CP_POOL_TIMING is enabled.
_case_counter = 0


def _run_one(req: dict, rank: int) -> tuple[bool, str]:
    """Execute one dispatched request on this rank.

    Args:
        req: Request dict. Only ``{"op": "run", "kwargs": {...}}`` is
            executed; any other op, or a missing ``"op"`` key, is reported
            as a failure instead of raising.
        rank: This process's rank, used to tag error text and to gate the
            optional timing line.

    Returns:
        ``(ok, err)``: ``ok`` is True when ``run_dpa_with_cp`` returned
        without raising; otherwise ``err`` holds the rank-tagged traceback
        (or the "unknown op" message).
    """
    global _case_counter
    # .get() rather than [] so a request with no "op" yields an error result
    # on every rank instead of a KeyError that escapes outside the try below.
    op = req.get("op")
    if op != "run":
        return False, f"unknown op: {op}"
    # Reset BEFORE the case so the first case also starts from a known RNG
    # seed and clean FP8 state.
    _reset_between_cases()
    t0 = time.monotonic()
    ok = True
    err = ""
    try:
        run_dpa_with_cp(**req.get("kwargs", {}))
    except Exception:
        ok = False
        err = f"[Rank {rank}] {traceback.format_exc()}"
    wall = time.monotonic() - t0
    # Per-case wall time on rank 0, opt-in via NVTE_CP_POOL_TIMING=1.
    # Written to stderr so it never mixes with the stdout JSON protocol.
    if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")):
        _case_counter += 1
        sys.stderr.write(
            f"[POOL-TIMING] case_idx={_case_counter} "
            f"world_size={int(os.environ.get('WORLD_SIZE', 0))} "
            f"wall_s={wall:.3f} ok={ok}\n"
        )
        sys.stderr.flush()
    return ok, err


def _create_cp_comm_groups(rank: int, world_size: int) -> tuple:
    """Build the CP process groups that one pool shares across its cases.

    A group spanning all ranks is always created. When ``world_size`` is
    even and at least 4, the a2a+p2p sub-groups are created as well:
    ``world_size // 2`` groups of adjacent rank pairs, followed by two
    stride-2 groups (the even ranks and the odd ranks). Every rank calls
    ``dist.new_group`` for every sub-group in the same order, but keeps only
    the groups it belongs to.

    Args:
        rank: This process's rank.
        world_size: Total number of ranks in the pool.

    Returns:
        ``(world_group, sub_groups)``; ``sub_groups`` is an empty list when
        the world size does not meet the a2a+p2p condition above.
    """
    world_group = dist.new_group(range(world_size), backend="nccl")
    if world_size < 4 or world_size % 2 != 0:
        return world_group, []

    adjacent_pairs = [range(2 * i, 2 * i + 2) for i in range(world_size // 2)]
    strided = [range(offset, world_size, 2) for offset in range(2)]

    member_groups: list = []
    for ranks in adjacent_pairs + strided:
        # Called on every rank, member or not, to keep creation order aligned.
        group = dist.new_group(ranks, backend="nccl")
        if rank in ranks:
            member_groups.append(group)
    return world_group, member_groups


def main() -> None:
    """Run the persistent pool worker: set up once, then serve requests.

    Reads ``RANK``/``WORLD_SIZE`` from the environment, silences stdout on
    non-zero ranks, selects the CUDA device, initialises the NCCL process
    group and sets ``NVTE_CP_POOL_PG=1``. The shared CP groups are created
    once and stored on the ``run_attention_with_cp`` module. The loop then
    receives a request, runs it, gathers each rank's ``(ok, msg)`` on rank 0
    and writes one JSON response, until a ``shutdown`` request arrives.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    _silence_non_rank0_stdout(rank)
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    os.environ["NVTE_CP_POOL_PG"] = "1"

    # Bind the module object so the pool-shared groups can be stored on it
    # for run_dpa_with_cp to read on each case.
    import run_attention_with_cp as _rac

    shared_world, shared_subs = _create_cp_comm_groups(rank, world_size)
    _rac._pool_cp_comm_group = shared_world
    _rac._pool_cp_comm_sub_groups = shared_subs

    is_root = rank == 0
    try:
        while True:
            request = _recv_request(rank)
            if request.get("op") == "shutdown":
                break

            outcome = _run_one(request, rank)

            results: list[tuple[bool, str]] = [None] * world_size  # type: ignore[list-item]
            # gather_object is a collective every rank must reach, so it
            # also serves as the per-case synchronisation point.
            dist.gather_object(outcome, results if is_root else None, dst=0)

            if is_root:
                failures = [msg for ok, msg in results if not ok]
                if failures:
                    # Report only the first failing rank's message.
                    reply = {"ok": False, "error": failures[0]}
                else:
                    reply = {"ok": True}
                _send_response(rank, reply)
            # Hand cached allocator blocks back between cases.
            torch.cuda.empty_cache()
    finally:
        # Destroy the pool-shared CP groups before the main process group.
        # Each destroy is guarded on its own so one failure does not stop
        # the remaining groups from being torn down.
        doomed = []
        if _rac._pool_cp_comm_group is not None:
            doomed.append(_rac._pool_cp_comm_group)
        doomed.extend(_rac._pool_cp_comm_sub_groups)
        for group in doomed:
            try:
                dist.destroy_process_group(group)
            except Exception:
                pass
        _rac._pool_cp_comm_group = None
        _rac._pool_cp_comm_sub_groups = []
        dist.destroy_process_group()


# Script entry point: start the pool worker loop only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
main()
Loading
Loading