
Commit 8dee07b

Merge main into pr_better_errors and update new assertions
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
2 parents: 221d723 + 6e0085a

134 files changed

Lines changed: 12969 additions & 2947 deletions


.github/workflows/build.yml

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,10 @@ name: 'Build'
 on:
   pull_request:
   workflow_dispatch:
+concurrency:
+  # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes)
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
 jobs:
   core:
     name: 'Core'

.github/workflows/docs.yml

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@ on:
   pull_request:
   workflow_dispatch:
   workflow_call:
+concurrency:
+  # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes)
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
 jobs:
   build_docs:
     name: 'Build'

.github/workflows/lint.yml

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,10 @@ name: 'Lint'
 on:
   pull_request:
   workflow_dispatch:
+concurrency:
+  # Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes)
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
 jobs:
   pytorch_cpplint:
     name: 'PyTorch C++'

docs/examples/attention/attention.ipynb

Lines changed: 2 additions & 1 deletion
@@ -151,6 +151,7 @@
 "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
 "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n",
 "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n",
+"- **Sliding window attention (SWA):** flash-attention has SWA(left, right) support for all mask types except top-left causal masks, with or without dropout, and without bias. cuDNN attention supports SWA(left, 0) starting from 9.2 and SWA(left, right) starting from 9.6, without dropout, and with `bias_type=\"no_bias\"`.\n",
 "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
 "\n",
 "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
@@ -389,7 +390,7 @@
 "\n",
 "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
 "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
-"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
+"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | Yes (cuDNN 9.2+) | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
 "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n",
 "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
 "\n",

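The notebook text above compares backend support for sliding window attention. As a rough, hedged sketch (not part of this commit), the snippet below shows one way a user might pin the attention backend via the NVTE_FLASH_ATTN / NVTE_FUSED_ATTN environment variables and request SWA through Transformer Engine's PyTorch DotProductAttention; argument names and defaults are taken from the attention documentation and may differ across versions.

# Hedged sketch: pin the cuDNN (fused) attention backend and enable SWA.
# Set NVTE_FLASH_ATTN=1 / NVTE_FUSED_ATTN=0 instead to compare against flash-attention.
import os
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,                        # per-head hidden size
    attn_mask_type="causal_bottom_right",  # bottom-right diagonal causal mask
    window_size=(128, 0),                  # SWA(left, right): attend to the last 128 tokens only
    qkv_format="sbhd",
)

# [seq, batch, heads, head_dim] tensors for the sbhd layout
q, k, v = (torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda") for _ in range(3))
out = attn(q, k, v)
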
docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 "source": [
 "### Question 1: Why choose Striped>1 ?\n",
 "\n",
+"\n",
 "Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n",
 "\n",
 "#### I. Striped (`stripe_size=1`)\n",

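As a loose illustration of the striping idea discussed in this notebook cell (a hypothetical sketch, not the reordering code touched by this commit), the following lines show how a stripe_size parameter changes which context-parallel rank each token index is assigned to; TE's actual reordering also handles segments and the THD layout, which are omitted here.

# Hypothetical illustration: map token indices to CP ranks under a striped layout.
import numpy as np

def striped_assignment(seq_len, cp_size, stripe_size):
    """For each token index, return the context-parallel rank it lands on."""
    idx = np.arange(seq_len)
    return (idx // stripe_size) % cp_size

print(striped_assignment(16, cp_size=4, stripe_size=1))
# [0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3]  -> per-token interleaving
print(striped_assignment(16, cp_size=4, stripe_size=2))
# [0 0 1 1 2 2 3 3 0 0 1 1 2 2 3 3]  -> coarser stripes, easier to reason about after all-gather
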
examples/jax/collective_gemm/common.py

Lines changed: 0 additions & 4 deletions
@@ -131,10 +131,6 @@ def _initialize_distributed(args):
     )

     _distributed_initialized = True
-    jax.clear_caches()
-    jax.config.update(
-        "jax_use_shardy_partitioner", False
-    )  # CollectiveGEMM does not work with Shardy yet

     assert jax.local_device_count() == 1, (
         f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"

examples/jax/collective_gemm/test_gemm.py

Lines changed: 1 addition & 4 deletions
@@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard
 def run_gemm_tests(args, mesh=None):
     """Execute GEMM tests."""
     print(args)
-    # Collective GEMM requires Shardy partitioner to be disabled
-    jax.config.update("jax_use_shardy_partitioner", False)

     # Initialize distributed with provided arguments
     _initialize_distributed(args)
@@ -137,8 +135,7 @@ def run_gemm_tests(args, mesh=None):
         bias_sharded,
         contracting_dims=((2,), (0,)),
         collective_op=collective_op,
-        # CollectiveGEMM output should have a correct sharding without applying sharding constraint
-        output_sharding=None,
+        output_sharding=output_sharding,
     )
     assert (
         ref_output.sharding == output.sharding

examples/jax/collective_gemm/test_layernorm_mlp_grad.py

Lines changed: 0 additions & 2 deletions
@@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp(
 def run_layernorm_mlp_grad_tests(args, mesh=None):
     """Execute Dense Gradient tests."""
     print(args)
-    # Collective GEMM requires Shardy partitioner to be disabled
-    jax.config.update("jax_use_shardy_partitioner", False)

     # Initialize distributed with provided arguments
     _initialize_distributed(args)

examples/jax/encoder/run_test_multiprocessing_encoder.sh

Lines changed: 0 additions & 4 deletions
@@ -11,10 +11,6 @@ TEST_CASES=(
     "test_te_current_scaling_fp8"
     "test_te_mxfp8"
    "test_te_nvfp4"
-    "test_te_bf16_shardy"
-    "test_te_delayed_scaling_fp8_shardy"
-    "test_te_current_scaling_fp8_shardy"
-    "test_te_nvfp4_shardy"
 )

 : ${TE_PATH:=/opt/transformerengine}

examples/jax/encoder/test_model_parallel_encoder.py

Lines changed: 0 additions & 68 deletions
@@ -239,7 +239,6 @@ def check_fp8(state, var_collect, inputs, masks, labels):
 def train_and_evaluate(args):
     """Execute model training and evaluation loop."""
     print(args)
-    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

     train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

@@ -474,9 +473,6 @@ def encoder_parser(args):
     parser.add_argument(
         "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
     )
-    parser.add_argument(
-        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
-    )

     return parser.parse_args(args)

@@ -559,70 +555,6 @@ def test_te_nvfp4_with_sp(self):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.40 and actual[1] > 0.82

-    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
-    def test_te_bf16_shardy(self):
-        """Test Transformer Engine with BF16"""
-        self.args.enable_shardy = True
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_fp8_supported, fp8_reason)
-    def test_te_delayed_scaling_fp8_shardy(self):
-        """Test Transformer Engine with DelayedScaling FP8"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "DelayedScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.362 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_fp8_supported, fp8_reason)
-    def test_te_delayed_scaling_fp8_with_sp_shardy(self):
-        """Test Transformer Engine with DelayedScaling FP8 + SP"""
-        self.args.enable_shardy = True
-        self.args.enable_sp = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "DelayedScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.362 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    def test_te_mxfp8_shardy(self):
-        """Test Transformer Engine with MXFP8"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "MXFP8BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
-    def test_te_nvfp4_shardy(self):
-        """Test Transformer Engine with NVFP4"""
-        self.args.enable_shardy = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "NVFP4BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.40 and actual[1] > 0.82
-
-    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    def test_te_mxfp8_with_sp_shardy(self):
-        """Test Transformer Engine with MXFP8 + SP"""
-        self.args.enable_shardy = True
-        self.args.enable_sp = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "MXFP8BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.36 and actual[1] > 0.84
-
-    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
-    def test_te_nvfp4_with_sp_shardy(self):
-        """Test Transformer Engine with NVFP4"""
-        self.args.enable_shardy = True
-        self.args.enable_sp = True
-        self.args.use_fp8 = True
-        self.args.fp8_recipe = "NVFP4BlockScaling"
-        actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.40 and actual[1] > 0.82
-

 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))
