34 commits
259662b
update FE to 1.17
cyanguwa Jan 10, 2026
c243794
add determinism flag
cyanguwa Jan 10, 2026
5578a60
add determinism to test
cyanguwa Jan 12, 2026
9bc1d64
add determinism to qa/
cyanguwa Jan 12, 2026
b1bdab7
move bias/dbias/versioning/dropout logic to C API
cyanguwa Jan 12, 2026
ea109c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2026
70fc94b
Update qa/L0_pytorch_unittest/test.sh
cyanguwa Jan 12, 2026
e82bd96
add determinism to Jax extension
cyanguwa Jan 13, 2026
8365962
add determinism to Jax tests
cyanguwa Jan 13, 2026
c7db02b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
4aaa627
Update tests/jax/test_fused_attn.py
cyanguwa Jan 13, 2026
0ee6b87
Update transformer_engine/common/fused_attn/fused_attn.cpp
cyanguwa Jan 13, 2026
4bd5e95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
bd31e01
fix the AI fixes
cyanguwa Jan 13, 2026
eb2e055
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
6f0e515
fix Jax extension call
cyanguwa Jan 13, 2026
2c22cbf
minor fixes based on comments
cyanguwa Jan 13, 2026
279f2f6
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 13, 2026
aae98f3
fix selection logic and fwd arg
cyanguwa Jan 14, 2026
b962d32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2026
1068594
fix version check in Jax test
cyanguwa Jan 14, 2026
fed07f2
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 14, 2026
c51cf44
fix pytorch CI failures
cyanguwa Jan 15, 2026
3885684
fix Jax CI failures
cyanguwa Jan 15, 2026
8bf3a0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
f3a0234
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 15, 2026
f526569
fix non-/determinism logic and CI
cyanguwa Jan 15, 2026
0cb374a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
9162ff6
fix formatting
cyanguwa Jan 15, 2026
ee90c5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
77d0f2a
Update transformer_engine/common/fused_attn/fused_attn.cpp
cyanguwa Jan 15, 2026
6a3346f
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 18, 2026
65a67c6
update to 9.18.1 for requirement
cyanguwa Jan 19, 2026
7187d02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 102 files
1 change: 1 addition & 0 deletions qa/L0_jax_unittest/test.sh
@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py"
Collaborator
It seems like this will first run the non-deterministic fused attn tests as part of L31, which runs all non-distributed tests, followed by the deterministic fused attn tests as part of L32.
Is that the intention: to run fused attn twice, with and without determinism?

That will greatly increase our test time and might be unnecessary. The last pipeline launched was for L1, so I am not sure I can track the effect this change will have on timing, since this is an L0 change. Could you report that in the PR, please?
Thanks!

Collaborator
Maybe we could come up with an approach that runs half the fused attn tests deterministically and the other half non-deterministically?
Or run all of them deterministically only?
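A minimal sketch of that half/half idea, assuming the backend selection re-reads the flag per test; `NVTE_ALLOW_NONDETERMINISTIC_ALGO` is the real environment variable, but the fixture name and the hash-based split below are hypothetical:

```python
import zlib

import pytest


@pytest.fixture(autouse=True)
def split_determinism(request, monkeypatch):
    # Stable 50/50 split: a hash of the test id picks the mode, so each
    # configuration still runs exactly once and total runtime stays flat.
    deterministic = zlib.crc32(request.node.nodeid.encode()) % 2 == 0
    monkeypatch.setenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "0" if deterministic else "1")
    return deterministic
```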

Collaborator (Author)
Yes, this extra line tests test_fused_attn.py with determinism, while the line before tests everything with non-determinism. The extra test_fused_attn.py run takes ~20 minutes on Blackwell:

================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_backward                                                | 5040x | 1336.28s | avg:   0.27s
================================================================================
TOTAL RUNTIME                                                |      | 1336.28s |
================================================================================
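For anyone following along, a rough sketch of what the deterministic mode is meant to guarantee, standalone illustration rather than the PR's actual test code: repeated backward passes with identical inputs and RNG state should produce bitwise-identical gradients.

```python
import torch


def backward_is_deterministic(fn, *inputs, trials=2):
    # Run fn forward+backward several times with identical inputs and RNG
    # state, and report whether the input gradients match bitwise.
    all_grads = []
    for _ in range(trials):
        torch.manual_seed(0)  # identical RNG state, e.g. for dropout
        cloned = [x.detach().clone().requires_grad_(True) for x in inputs]
        fn(*cloned).sum().backward()
        all_grads.append([x.grad.clone() for x in cloned])
    first = all_grads[0]
    return all(
        torch.equal(g, h) for grads in all_grads[1:] for g, h in zip(first, grads)
    )
```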


pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
28 changes: 19 additions & 9 deletions tests/jax/test_fused_attn.py
@@ -413,15 +413,25 @@ def _check_configs(self):
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
# TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
if (
get_device_compute_capability(0) >= 100
and self.dropout_prob == 0.1
and self.attn_bias_type is not AttnBiasType.NO_BIAS
):
pytest.skip(
"For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
)

if get_device_compute_capability(0) >= 100 and self.is_training:
if FusedAttnHelper.is_non_deterministic_allowed() and (
(self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
or get_cudnn_version() < 90700
):
pytest.skip(
"For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
" dropout"
)
if not FusedAttnHelper.is_non_deterministic_allowed() and (
self.dropout_prob != 0.0
or self.attn_bias_type != AttnBiasType.NO_BIAS
or get_cudnn_version() < 91801
):
pytest.skip(
"For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
" dropout"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
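Restating the two skip branches above as a standalone predicate may make the sm100+ support matrix easier to scan; this mirrors the test's checks and is not a cuDNN or Transformer Engine API:

```python
def sm100_fused_attn_bprop_supported(
    deterministic: bool, dropout_prob: float, has_bias: bool, cudnn_version: int
) -> bool:
    # Mirror of the skip conditions in _check_configs for sm100+ training runs.
    if deterministic:
        # Deterministic bprop needs cuDNN 9.18.1+ and supports neither bias nor dropout.
        return cudnn_version >= 91801 and dropout_prob == 0.0 and not has_bias
    # Non-deterministic bprop needs cuDNN 9.7+ and supports bias or dropout,
    # just not both together.
    return cudnn_version >= 90700 and not (dropout_prob != 0.0 and has_bias)
```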
41 changes: 40 additions & 1 deletion tests/pytorch/attention/test_attention.py
@@ -72,6 +72,14 @@
f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
)


# Get determinism
_deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
or torch.are_deterministic_algorithms_enabled()
)


# Reset RNG seed and states
seed = 1234
reset_rng_states()
@@ -160,6 +168,7 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
Expand All @@ -170,6 +179,7 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

@@ -886,11 +896,14 @@ def _run_dot_product_attention(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

# Create seqlens
@@ -1292,6 +1305,7 @@ def test_transformer_layer(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
Expand All @@ -1305,6 +1319,7 @@ def test_transformer_layer(
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

@@ -1432,10 +1447,13 @@ def _run_transformer_layer(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

# Create input tensor
@@ -1629,6 +1647,7 @@ def test_dpa_fp8_extra_state(model, dtype):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="sb3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported and not flash_attn_supported:
@@ -1819,6 +1838,7 @@ def test_mha_fp8_vs_f16(
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
@@ -1830,6 +1850,7 @@ def test_mha_fp8_vs_f16(
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported_f16:
@@ -1838,6 +1859,7 @@ def test_mha_fp8_vs_f16(
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
@@ -1847,6 +1869,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_fp8:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
@@ -1856,6 +1879,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_f16:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
@@ -2068,6 +2092,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1:
@@ -2078,6 +2103,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
@@ -2088,6 +2114,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2097,6 +2124,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if unfused_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2105,6 +2133,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
@@ -2113,6 +2142,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if config.dropout_p == 0.0:
# test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
@@ -2367,13 +2397,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not (fused_attn_backends and unfused_attn_supported):
pytest.skip("Not enough backends to run this test with.")

fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
dtype, config, "UnfusedDotProductAttention"
)

atol = 5e-1
rtol = 5e-1
@@ -2406,10 +2439,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

inp = 0.0001 * torch.randint(
@@ -2460,10 +2496,13 @@ def _run_ref_mha_f16(dtype, config, backend):

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

inp = torch.load("qkv.pt").to(device="cuda")