diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend
index 0258951d4d..b372d39879 160000
--- a/3rdparty/cudnn-frontend
+++ b/3rdparty/cudnn-frontend
@@ -1 +1 @@
-Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93
+Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2
diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh
index ee9ce130aa..2ab04eae39 100644
--- a/qa/L0_jax_unittest/test.sh
+++ b/qa/L0_jax_unittest/test.sh
@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 
 mkdir -p "$XML_LOG_DIR"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py"
 
 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 21eed28367..a13dfada79 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
 NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
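
Note: both new CI entries above rerun the attention suites with `NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. As a reading aid only, here is a minimal sketch of how that variable is typically folded into a boolean "deterministic" flag; it mirrors the `_deterministic` expression added to `tests/pytorch/attention/test_attention.py` below, and the helper name is hypothetical (not part of this change):

```python
import os


def _deterministic_requested() -> bool:
    # NVTE_ALLOW_NONDETERMINISTIC_ALGO defaults to "1" (non-deterministic algorithms allowed);
    # exporting "0", as the new CI lines do, requests deterministic kernels instead.
    return not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))


if __name__ == "__main__":
    # Under `NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest ...` this prints True.
    print(_deterministic_requested())
```
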
supported" - ) + + if get_device_compute_capability(0) >= 100 and self.is_training: + if FusedAttnHelper.is_non_deterministic_allowed() and ( + (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) + or get_cudnn_version() < 90700 + ): + pytest.skip( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with" + " dropout" + ) + if not FusedAttnHelper.is_non_deterministic_allowed() and ( + self.dropout_prob != 0.0 + or self.attn_bias_type != AttnBiasType.NO_BIAS + or get_cudnn_version() < 91801 + ): + pytest.skip( + "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or" + " dropout" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index eb7905bcd5..9111d3511c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -72,6 +72,14 @@ f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" ) + +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + + # Reset RNG seed and states seed = 1234 reset_rng_states() @@ -160,6 +168,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -170,6 +179,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -886,11 +896,14 @@ def _run_dot_product_attention( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create seqlens @@ -1292,6 +1305,7 @@ def test_transformer_layer( qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -1305,6 +1319,7 @@ def test_transformer_layer( else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -1432,10 +1447,13 @@ def _run_transformer_layer( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" 
_attention_backends["backend_selection_requires_update"] = True # Create input tensor @@ -1629,6 +1647,7 @@ def test_dpa_fp8_extra_state(model, dtype): qkv_dtype=torch.float8_e4m3fn, qkv_layout="sb3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported and not flash_attn_supported: @@ -1819,6 +1838,7 @@ def test_mha_fp8_vs_f16( fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported_fp8 < 1: @@ -1830,6 +1850,7 @@ def test_mha_fp8_vs_f16( qkv_dtype=dtype, qkv_layout=qkv_format.replace("hd", "h3d"), is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported_f16, _ = available_backends if not fused_attn_supported_f16: @@ -1838,6 +1859,7 @@ def test_mha_fp8_vs_f16( if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -1847,6 +1869,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_fp8: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -1856,6 +1879,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_f16: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( @@ -2068,6 +2092,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported < 1: @@ -2078,6 +2103,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -2088,6 +2114,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2097,6 +2124,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if unfused_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True 
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index fde0d38921..415bfae063 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -206,7 +206,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
@@ -440,7 +440,16 @@
         // 9.13.1+: vanilla, off-by-one, learnable
         (cudnn_runtime_version >= 91301 ||
          (cudnn_runtime_version < 91301 &&
-          softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
+          softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) &&
+        // determinism on Blackwell
+        // pre-9.18.1: fwd: deterministic; bwd: non-deterministic
+        // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic
+        (sm_arch_ < 100 ||
+         (sm_arch_ >= 100 && (!is_training ||
+                              (is_training && !deterministic &&
+                               (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) ||
+                              (is_training && deterministic && cudnn_runtime_version >= 91801 &&
+                               dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) {
       flag_arb = true;
     }
     if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
@@ -553,7 +562,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
-      cuda_graph);
+      cuda_graph, false);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -595,7 +604,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
         wkspace, stream, handle);
 #else
     NVTE_ERROR(
-        "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
+        "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
+        "\n");
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
@@ -669,7 +679,8 @@ void nvte_fused_attn_bwd_qkvpacked(
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
-      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph,
+      deterministic);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -855,7 +866,7 @@ void nvte_fused_attn_fwd_kvpacked(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, false);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -897,7 +908,8 @@
         input_page_table_v, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
-        "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
+        "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. "
+        "\n");
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
@@ -982,10 +994,10 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
 
-  NVTE_Fused_Attn_Backend fused_attention_backend =
-      nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
-                                  softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
-                                  d, window_size_left, window_size_right, false, cuda_graph);
+  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false,
+      cuda_graph, deterministic);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -1166,7 +1178,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
      h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, false);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -1189,7 +1201,8 @@
         input_page_table_v, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
-        "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
+        "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
+        "\n");
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
@@ -1262,7 +1275,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
       h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
-      cuda_graph);
+      cuda_graph, deterministic);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
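
The new clause in `nvte_get_fused_attn_backend()` above gates the arbitrary-sequence-length backend (`flag_arb`) on Blackwell based on training, determinism, dropout, bias, and cuDNN version. A hedged Python restatement of just that clause, as a reading aid rather than the implementation (it deliberately mirrors only this condition; other version checks apply elsewhere):

```python
def blackwell_arbitrary_seqlen_ok(
    sm_arch: int,
    is_training: bool,
    deterministic: bool,
    dropout: float,
    has_bias: bool,
    cudnn_runtime_version: int,
) -> bool:
    # Mirrors the condition added to nvte_get_fused_attn_backend() above.
    if sm_arch < 100 or not is_training:
        return True
    if not deterministic:
        # Non-deterministic bwd: allowed unless dropout and bias are combined.
        return dropout == 0.0 or not has_bias
    # Deterministic bwd: cuDNN 9.18.1+ only, and neither dropout nor bias.
    return cudnn_runtime_version >= 91801 and dropout == 0.0 and not has_bias


assert blackwell_arbitrary_seqlen_ok(100, True, False, 0.1, False, 90700)
assert not blackwell_arbitrary_seqlen_ok(100, True, True, 0.0, False, 91300)
```

Note also that the forward-only entry points above pass `false` for the new argument, while the backward entry points forward their existing `deterministic` flag.
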
diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index dd70ccf8df..0fabb81aef 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  *  \param[in]     window_size_right    Sliding window size (the right half).
  *  \param[in]     return_max_logit     Whether to produce Max and Sum_Exp, or Stats.
  *  \param[in]     cuda_graph           Whether cuda graph capture is enabled or not.
+ *  \param[in]     deterministic        Whether determinism is required or not.
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
 
 /*! \brief Compute dot product attention with packed QKV input.
  *
" + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -982,10 +994,10 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, - softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1166,7 +1178,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1189,7 +1201,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -1262,7 +1275,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph); + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index dd70ccf8df..0fabb81aef 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_right Sliding window size (the right half). * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] deterministic Whether determinism is required or not. 
diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h
index a83a1e0a80..5f93392633 100644
--- a/transformer_engine/jax/csrc/extensions.h
+++ b/transformer_engine/jax/csrc/extensions.h
@@ -113,7 +113,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
     size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
-    int64_t window_size_right);
+    int64_t window_size_right, bool deterministic);
 
 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index 540aeb8b2d..4fe8e728a3 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
     size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
-    int64_t window_size_right) {
+    int64_t window_size_right, bool deterministic) {
   auto backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
       q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false, false);
+      false, false, deterministic);
   return backend;
 }
 
@@ -266,7 +266,7 @@ static void FusedAttnForwardImpl(
       is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
       bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
       q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false, false);
+      false, false, deterministic);
   nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
 
   /* Auxiliary tensors (to be propagated to the backward pass later) */
@@ -522,7 +522,7 @@ static void FusedAttnBackwardImpl(
       is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
       bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
       q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false, false);
+      false, false, deterministic);
   PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                      bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                      softmax_aux, rng_state, bias, softmax_offset);
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index bf19388d7e..cb74a15e77 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -994,6 +994,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
             window_size[1],
             return_max_logit,
             cuda_graph,
+            deterministic,
         )
         if fused_attention_backend == FusedAttnBackend["No_Backend"]:
             logger.debug("Disabling FusedAttention as no backend supports the provided input")
@@ -1064,10 +1065,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
             logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
             use_fused_attention = False
             fused_attention_backend = None
-        if is_training and device_compute_capability >= (10, 0):
-            logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
-            use_fused_attention = False
-            fused_attention_backend = None
 
     # use_flash_attention may have been set above
     use_flash_attention_2 = use_flash_attention and use_flash_attention_2
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 9dc0d1f37b..591c89f83f 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -81,7 +81,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
 
 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
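
With the blanket "disable FusedAttention when training on Blackwell" branch removed from utils.py above, the decision is delegated to `nvte_get_fused_attn_backend` via the new `deterministic` argument. A hedged usage sketch of how determinism is requested on the PyTorch side, based on how the tests derive `_deterministic`; the exact plumbing inside utils.py may differ:

```python
import os

import torch

# Either switch requests deterministic fused-attention kernels; the test suite treats the
# combination below as "determinism requested".
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # TE-specific switch
torch.use_deterministic_algorithms(True)  # global PyTorch switch

deterministic = (
    not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    or torch.are_deterministic_algorithms_enabled()
)
assert deterministic
```
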
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index b455e03757..be645d91b9 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
       bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
       max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, deterministic);
   return fused_attention_backend;
 }