From 259662b11bd702203d91c521746ea3bf6dbf4ef9 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Sat, 10 Jan 2026 00:31:53 -0800
Subject: [PATCH 01/30] update FE to 1.17

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 3rdparty/cudnn-frontend | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend
index 0258951d4d..b372d39879 160000
--- a/3rdparty/cudnn-frontend
+++ b/3rdparty/cudnn-frontend
@@ -1 +1 @@
-Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93
+Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2

From c2437949886fd955e1b0e2c9c3c50f4f4a384ed7 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Sat, 10 Jan 2026 00:35:37 -0800
Subject: [PATCH 02/30] add determinism flag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 .../common/fused_attn/fused_attn.cpp | 14 +++++++-------
 .../common/include/transformer_engine/fused_attn.h | 3 ++-
 .../jax/csrc/extensions/attention.cpp | 6 +++---
 .../attention/dot_product_attention/utils.py | 3 ++-
 transformer_engine/pytorch/csrc/extensions.h | 2 +-
 .../pytorch/csrc/extensions/attention.cpp | 4 ++--
 6 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index fde0d38921..980ae42112 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -206,7 +206,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
@@ -553,7 +553,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
-      cuda_graph);
+      cuda_graph, true);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -669,7 +669,7 @@ void nvte_fused_attn_bwd_qkvpacked(
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
-      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, deterministic);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -855,7 +855,7 @@ void nvte_fused_attn_fwd_kvpacked(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, true);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -985,7 +985,7 @@ void nvte_fused_attn_bwd_kvpacked(
   NVTE_Fused_Attn_Backend fused_attention_backend =
       nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
                                   softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
-                                  d, window_size_left, window_size_right, false, cuda_graph);
+                                  d, window_size_left, window_size_right, false, cuda_graph, deterministic);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -1166,7 +1166,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, true);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -1262,7 +1262,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
       h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
-      cuda_graph);
+      cuda_graph, deterministic);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index dd70ccf8df..0fabb81aef 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  * \param[in] window_size_right Sliding window size (the right half).
  * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
  * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
+ * \param[in] deterministic Whether determinism is required or not.
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
 
 /*! \brief Compute dot product attention with packed QKV input.
* diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 540aeb8b2d..af76fd8a58 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -21,7 +21,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, true); return backend; } @@ -266,7 +266,7 @@ static void FusedAttnForwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, true); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -522,7 +522,7 @@ static void FusedAttnBackwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, true); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bf19388d7e..98aa4fe328 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -994,6 +994,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt window_size[1], return_max_logit, cuda_graph, + deterministic, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") @@ -1064,7 +1065,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None - if is_training and device_compute_capability >= (10, 0): + if is_training and device_compute_capability >= (10, 0) and (cudnn_version < (9, 18, 0) or core_attention_bias_type != "no_bias"): logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") use_fused_attention = False fused_attention_backend = None diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 52ef02a347..b64658439c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph); + int64_t 
window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index b455e03757..be645d91b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, deterministic); return fused_attention_backend; } From 5578a60214b1c22822a48c0c3921a4117ad94c8d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 11 Jan 2026 23:37:12 -0800 Subject: [PATCH 03/30] add determinism to test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index eb7905bcd5..fa26b59390 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -72,6 +72,14 @@ f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" ) + +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() + ) + + # Reset RNG seed and states seed = 1234 reset_rng_states() @@ -160,6 +168,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -170,6 +179,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -886,11 +896,14 @@ def _run_dot_product_attention( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create seqlens @@ -1292,6 +1305,7 @@ def test_transformer_layer( qkv_format.replace("hd", "h3d") if fused_qkv_params 
else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -1305,6 +1319,7 @@ def test_transformer_layer( else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -1432,10 +1447,13 @@ def _run_transformer_layer( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create input tensor @@ -1629,6 +1647,7 @@ def test_dpa_fp8_extra_state(model, dtype): qkv_dtype=torch.float8_e4m3fn, qkv_layout="sb3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported and not flash_attn_supported: @@ -1819,6 +1838,7 @@ def test_mha_fp8_vs_f16( fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported_fp8 < 1: @@ -1830,6 +1850,7 @@ def test_mha_fp8_vs_f16( qkv_dtype=dtype, qkv_layout=qkv_format.replace("hd", "h3d"), is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported_f16, _ = available_backends if not fused_attn_supported_f16: @@ -1838,6 +1859,7 @@ def test_mha_fp8_vs_f16( if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -1847,6 +1869,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_fp8: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -1856,6 +1879,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_f16: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( @@ -2068,6 +2092,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported < 1: @@ -2078,6 +2103,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if 
not fused_attn_supported: @@ -2088,6 +2114,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2097,6 +2124,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if unfused_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2105,6 +2133,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2113,6 +2142,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") @@ -2367,6 +2397,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): qkv_dtype=torch.float8_e4m3fn, qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not (fused_attn_backends and unfused_attn_supported): @@ -2406,10 +2437,13 @@ def _run_custom_mha_fp8(dtype, config, backend): reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = 0.0001 * torch.randint( @@ -2460,10 +2494,13 @@ def _run_ref_mha_f16(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = torch.load("qkv.pt").to(device="cuda") From 9bc1d645b4deda1481230b7674842a36ac0e3714 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 11 Jan 2026 23:37:34 -0800 Subject: [PATCH 04/30] add determinism to qa/ Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L0_pytorch_unittest/test.sh 
b/qa/L0_pytorch_unittest/test.sh index 21eed28367..01a84b23e7 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" From b1bdab7ce058d3b15249df355e567d375892d43d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 12 Jan 2026 00:43:00 -0800 Subject: [PATCH 05/30] move bias/dbias/versioning/dropout logic to C API Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 3 ++- .../pytorch/attention/dot_product_attention/utils.py | 4 ---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 980ae42112..ebd0e44201 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -440,7 +440,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.13.1+: vanilla, off-by-one, learnable (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && + !(sm_arch_ >= 100 && is_training && deterministic && (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 98aa4fe328..cb74a15e77 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1065,10 +1065,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None - if is_training and device_compute_capability >= (10, 0) and (cudnn_version < (9, 18, 0) or 
core_attention_bias_type != "no_bias"): - logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") - use_fused_attention = False - fused_attention_backend = None # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 From ea109c23e29f379a4e73fca8f82592e0c9d7e487 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 08:51:52 +0000 Subject: [PATCH 06/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 6 +++--- .../common/fused_attn/fused_attn.cpp | 15 +++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index fa26b59390..f34baad41c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -75,9 +75,9 @@ # Get determinism _deterministic = ( - not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) - or torch.are_deterministic_algorithms_enabled() - ) + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) # Reset RNG seed and states diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ebd0e44201..11fd1acfbc 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -441,7 +441,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && - !(sm_arch_ >= 100 && is_training && deterministic && (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { + !(sm_arch_ >= 100 && is_training && deterministic && + (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || + dropout != 0.0))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -670,7 +672,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, deterministic); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, + deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -983,10 +986,10 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, - softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph, deterministic); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, + cuda_graph, deterministic); if (fused_attention_backend == 
NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) From 70fc94b8fa6aa81349ada6b87459745a8b05a94a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 12 Jan 2026 01:09:50 -0800 Subject: [PATCH 07/30] Update qa/L0_pytorch_unittest/test.sh make .xml file specific to deterministic tests in qa/ Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 01a84b23e7..cf97d891c4 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,7 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" From e82bd96498a74abc84e0ff758040f59e1bef0eaa Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Jan 2026 04:50:33 -0800 Subject: [PATCH 08/30] add determinism to Jax extension Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L0_jax_unittest/test.sh | 1 + transformer_engine/jax/cpp_extensions/attention.py | 1 + transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/attention.cpp | 4 ++-- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index ee9ce130aa..2ab04eae39 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" mkdir -p "$XML_LOG_DIR" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v 
--junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 0cdfcebf38..887ca16e56 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -144,6 +144,7 @@ def get_fused_attn_backend(self): self.head_dim_v, self.window_size[0], self.window_size[1], + not self.is_non_deterministic_allowed(), ) @staticmethod diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index a83a1e0a80..5f93392633 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -113,7 +113,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index af76fd8a58..9276adb453 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool deterministic) { auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, true); + false, false, deterministic); return backend; } From 8365962309aef50f70cb9c4c32e41604b50e3f68 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Jan 2026 05:43:55 -0800 Subject: [PATCH 09/30] add determinism to Jax tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 10 ++++++---- transformer_engine/jax/cpp_extensions/attention.py | 10 ++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index a0aee50430..68fca5878e 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -413,14 +413,16 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support if ( 
get_device_compute_capability(0) >= 100 - and self.dropout_prob == 0.1 - and self.attn_bias_type is not AttnBiasType.NO_BIAS + and self.is_training + and not FusedAttnHelper.is_non_deterministic_allowed() + and (self.dropout_prob != 0.0 + or self.attn_bias_type != AttnBiasType.NO_BIAS + or get_cudn_version() < (9, 18, 0)) ): pytest.skip( - "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" + "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and dropout = 0.0" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 887ca16e56..a186c357d0 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3564,13 +3564,11 @@ def fused_attn_bwd( softmax_offset, (None, HEAD_AXES, None, None) ) - # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on - # sm100+ compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities): - assert not ( - attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 - ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" + if any(x >= 100 for x in compute_capabilities) and is_training and not FusedAttnHelper.is_non_deterministic_allowed(): + assert ( + attn_bias_type == AttnBiasType.NO_BIAS and dropout_probability == 0 and get_cudnn_version() >= (9, 18, 0) + ), "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and dropout = 0.0" fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, From c7db02ba305e9f43ed83cf21455df7adeafd92b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 13:44:44 +0000 Subject: [PATCH 10/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 11 +++++++---- .../jax/cpp_extensions/attention.py | 15 ++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 68fca5878e..2c0de9329d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -417,12 +417,15 @@ def _check_configs(self): get_device_compute_capability(0) >= 100 and self.is_training and not FusedAttnHelper.is_non_deterministic_allowed() - and (self.dropout_prob != 0.0 - or self.attn_bias_type != AttnBiasType.NO_BIAS - or get_cudn_version() < (9, 18, 0)) + and ( + self.dropout_prob != 0.0 + or self.attn_bias_type != AttnBiasType.NO_BIAS + or get_cudn_version() < (9, 18, 0) + ) ): pytest.skip( - "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and dropout = 0.0" + "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and" + " dropout = 0.0" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index a186c357d0..a680fac57b 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3565,10 +3565,19 
@@ def fused_attn_bwd( ) compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities) and is_training and not FusedAttnHelper.is_non_deterministic_allowed(): + if ( + any(x >= 100 for x in compute_capabilities) + and is_training + and not FusedAttnHelper.is_non_deterministic_allowed() + ): assert ( - attn_bias_type == AttnBiasType.NO_BIAS and dropout_probability == 0 and get_cudnn_version() >= (9, 18, 0) - ), "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and dropout = 0.0" + attn_bias_type == AttnBiasType.NO_BIAS + and dropout_probability == 0 + and get_cudnn_version() >= (9, 18, 0) + ), ( + "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and" + " dropout = 0.0" + ) fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, From 4aaa627184f481c9ab629971498e376c8695127d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Jan 2026 06:00:42 -0800 Subject: [PATCH 11/30] Update tests/jax/test_fused_attn.py fix typo Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 2c0de9329d..89b9003843 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -420,7 +420,7 @@ def _check_configs(self): and ( self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS - or get_cudn_version() < (9, 18, 0) + or get_cudnn_version() < (9, 18, 0) ) ): pytest.skip( From 0ee6b878f7f38b3e30ecbad6c648e685943201d2 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Jan 2026 06:02:50 -0800 Subject: [PATCH 12/30] Update transformer_engine/common/fused_attn/fused_attn.cpp fix indentation Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 11fd1acfbc..197d7309a4 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -441,7 +441,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && - !(sm_arch_ >= 100 && is_training && deterministic && + !(sm_arch_ >= 100 && is_training && deterministic && (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { flag_arb = true; From 4bd5e95102a58f0b910a1ac31080e375bfb1b72d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:03:32 +0000 Subject: [PATCH 13/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 1595 +++++++++-------- 1 file changed, 800 insertions(+), 795 deletions(-) diff --git 
a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 197d7309a4..747a747343 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -441,897 +441,902 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && - !(sm_arch_ >= 100 && is_training && deterministic && (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { + !(sm_arch_ >= 100 && is_training && deterministic && (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { - flag_arb = true; - } - if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } - if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - if (flag_arb == true) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } else if ((flag_arb == false) && (flag_m512 == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; - } - int env_backend = static_cast(backend); - env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); - if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && - flag_m512) || - ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && - flag_arb)) { - backend = static_cast(env_backend); - } - } - if (cudnn_runtime_version < 8901 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - if (cudnn_runtime_version < 8900 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." 
- << std::endl; - } - if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-causal) and " + (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || + dropout != 0.0))) { + flag_arb = true; + } + if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } + if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + if (flag_arb == true) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } else if ((flag_arb == false) && (flag_m512 == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; + } + int env_backend = static_cast(backend); + env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); + if (((env_backend == + static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && + flag_m512) || + ((env_backend == + static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && + flag_arb)) { + backend = static_cast(env_backend); + } + } + if (cudnn_runtime_version < 8901 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + if (cudnn_runtime_version < 8900 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && + (window_size_left != -1) && (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout + << "Warning: Given combination of attention mask (non-causal) and " "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. " " Please upgrade your cuDNN version if possible." << std::endl; - } - if ((cudnn_runtime_version <= 91500) && is_training && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - (max_seqlen_kv % 128 != 0) && cuda_graph && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { + } + if ((cudnn_runtime_version <= 91500) && is_training && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + (max_seqlen_kv % 128 != 0) && cuda_graph && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: Given combination of attention mask (non-padding)," + " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" + " backward fused attention with graph capture requires cuDNN 9.15.1+. 
" + "Please upgrade your cuDNN version if possible." + << std::endl; + } + } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-padding)," - " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" - " backward fused attention with graph capture requires cuDNN 9.15.1+. " - "Please upgrade your cuDNN version if possible." - << std::endl; } - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + return backend; } - return backend; -} -// NVTE fused attention FWD with packed QKV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); - using namespace transformer_engine; - - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; - } + // NVTE fused attention FWD with packed QKV + // DEPRECATED: This API is deprecated. + // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. 
+ void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); + using namespace transformer_engine; + + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_QKV->data.shape.size(); + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + h = input_QKV->data.shape[ndim - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + h = input_QKV->data.shape[ndim - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + } + size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph, true); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, + dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, + return_max_logit, cuda_graph, true); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_max_512_fwd(b, h, max_seqlen, 
max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens, input_rng_state, wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_arbitrary_seqlen_fwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, + wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens, input_rng_state, wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } -} -// NVTE fused attention BWD with packed QKV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
-void nvte_fused_attn_bwd_qkvpacked( - const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, - NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, - NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); - using namespace transformer_engine; - - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQKV = convertNVTETensorCheck(dQKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } } + // NVTE fused attention BWD with packed QKV + // DEPRECATED: This API is deprecated. + // Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
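The packed-QKV paths above rebuild Q, K and V as views at fixed offsets into the single packed buffer; calculate_qkv_stride and make_tensor_view are internal helpers whose definitions are not shown in this hunk. The standalone sketch below only illustrates the layout arithmetic they presumably rely on, assuming the Q-to-K (and K-to-V) stride is h*d elements for a 3HD layout ([..., 3, h, d]) and d elements for an H3D layout ([..., h, 3, d]):

#include <cstddef>
#include <cstdint>

enum class PackedQkvLayout { k3HD, kH3D };  // stand-ins for the NVTE layout groups

// Element offset of tensor `idx` (0 = Q, 1 = K, 2 = V) from the packed base.
size_t packed_qkv_offset_elems(PackedQkvLayout layout, size_t idx, size_t h, size_t d) {
  // 3HD: per token, the Q/K/V blocks are h * d elements apart.
  // H3D: per head, the Q/K/V rows are d elements apart.
  const size_t stride = (layout == PackedQkvLayout::k3HD) ? h * d : d;
  return idx * stride;
}

// Byte address of a view's base pointer for a given element size
// (e.g. 2 bytes for FP16/BF16).
uint8_t *packed_qkv_view_base(uint8_t *qkv_base, PackedQkvLayout layout, size_t idx,
                              size_t h, size_t d, size_t elem_size) {
  return qkv_base + packed_qkv_offset_elems(layout, idx, h, d) * elem_size;
}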
+ void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); + using namespace transformer_engine; + + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQKV = convertNVTETensorCheck(dQKV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_QKV->data.shape.size(); + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + h = input_QKV->data.shape[ndim - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + h = input_QKV->data.shape[ndim - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + } + size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, - deterministic); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, + h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, + deterministic); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor 
K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, - &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, - input_cu_seqlens, wkspace, stream, handle); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, + output_S, &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, + input_cu_seqlens, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_arbitrary_seqlen_bwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, - &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, - &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, + &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, + &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, 
input_cu_seqlens, + input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, + wkspace, stream, handle); #else - const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, - input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, - input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, - handle); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, + &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, + stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); - } -} -// NVTE fused attention FWD with packed KV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; - } - int64_t num_pages_k = 0; - int64_t num_pages_v = 0; - int64_t page_size_k = 0; - int64_t page_size_v = 0; - int64_t max_pages_per_seq_k = 0; - int64_t max_pages_per_seq_v = 0; - if (input_page_table_k->data.dptr != nullptr) { - max_pages_per_seq_k = input_page_table_k->data.shape[1]; - } - if (input_page_table_v->data.dptr != nullptr) { - max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + } } - if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + // NVTE fused attention FWD with packed KV + // DEPRECATED: This API is deprecated. + // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. + void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); + + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + auto ndim = input_Q->data.shape.size(); + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + auto ndim_kv = input_KV->data.shape.size(); + size_t h_kv = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + h_kv = input_KV->data.shape[ndim_kv - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + h_kv = input_KV->data.shape[ndim_kv - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { - num_pages_k = input_KV->data.shape[0]; - page_size_k = input_KV->data.shape[1]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; - } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { - num_pages_k = input_KV->data.shape[1]; - page_size_k = input_KV->data.shape[0]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_KV->data.shape[0]; + } + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t 
max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if (input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_KV->data.shape[0]; + page_size_k = input_KV->data.shape[1]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_KV->data.shape[1]; + page_size_k = input_KV->data.shape[0]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; + } } - } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph, true); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_logit, cuda_graph, true); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - // Unpack KV and call the non-packed function - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, + &V_view, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention 
with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - // Unpack KV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + NVTE_ERROR( + "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
" + "\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - // Unpack KV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } -} -// NVTE fused attention BWD with packed KV -// DEPRECATED: This API is deprecated. -// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
-void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dKV = convertNVTETensorCheck(dKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } } + // NVTE fused attention BWD with packed KV + // DEPRECATED: This API is deprecated. + // Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
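The paged-KV branches in the forward paths read the page geometry straight from the cache and page-table shapes. The sketch below restates that convention with local stand-in types (KvCacheFormat and PagedKvGeometry are illustrative, not library types): a BSHD cache is laid out as [num_pages, page_size, h, d], an SBHD cache as [page_size, num_pages, h, d], and max_pages_per_seq is the second dimension of the page table.

#include <cstddef>
#include <cstdint>
#include <vector>

enum class KvCacheFormat { kBSHD, kSBHD };

struct PagedKvGeometry {
  int64_t num_pages = 0;
  int64_t page_size = 0;
  int64_t max_pages_per_seq = 0;
};

PagedKvGeometry infer_paged_kv_geometry(const std::vector<size_t> &cache_shape,
                                        const std::vector<size_t> &page_table_shape,
                                        KvCacheFormat format) {
  PagedKvGeometry g;
  if (format == KvCacheFormat::kBSHD) {
    g.num_pages = static_cast<int64_t>(cache_shape[0]);  // [num_pages, page_size, h, d]
    g.page_size = static_cast<int64_t>(cache_shape[1]);
  } else {
    g.num_pages = static_cast<int64_t>(cache_shape[1]);  // [page_size, num_pages, h, d]
    g.page_size = static_cast<int64_t>(cache_shape[0]);
  }
  g.max_pages_per_seq = static_cast<int64_t>(page_table_shape[1]);  // [batch, max_pages_per_seq]
  return g;
}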
+ void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dKV = convertNVTETensorCheck(dKV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); + Tensor *wkspace = convertNVTETensor(workspace); + + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + auto ndim = input_Q->data.shape.size(); + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + auto ndim_kv = input_KV->data.shape.size(); + size_t h_kv = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + h_kv = input_KV->data.shape[ndim_kv - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + h_kv = input_KV->data.shape[ndim_kv - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_KV->data.shape[0]; + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, - cuda_graph, deterministic); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, 
attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, + cuda_graph, deterministic); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - // Unpack KV and dKV and call the non-packed function - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + // Unpack KV and dKV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q, - input_cu_seqlens_kv, wkspace, stream, handle); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_dO, output_S, output_dQ, &dK_view, &dV_view, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } - // Unpack KV and dKV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - // Create tensor views for dK, dV - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - - fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + // Create tensor views for dK, dV + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + deterministic, input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, 
output_dQ, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else - const char *err_msg = - "cuDNN 8.9.3 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); + const char *err_msg = + "cuDNN 8.9.3 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - - // Unpack KV and dKV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, - &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, - stream, handle); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); - } -} -// NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_K = convertNVTETensorCheck(K); - const Tensor *input_V = convertNVTETensorCheck(V); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; - } - int64_t num_pages_k = 0; - int64_t num_pages_v = 0; - int64_t page_size_k = 0; - int64_t page_size_v = 0; - int64_t max_pages_per_seq_k = 0; - int64_t max_pages_per_seq_v = 0; - if (input_page_table_k->data.dptr != nullptr) { - max_pages_per_seq_k = input_page_table_k->data.shape[1]; - } - if (input_page_table_v->data.dptr != nullptr) { - max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + } } - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + // NVTE fused attention FWD with separate Q, K and V + void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { - num_pages_k = input_K->data.shape[0]; - page_size_k = input_K->data.shape[1]; - num_pages_v = input_V->data.shape[0]; - page_size_v = input_V->data.shape[1]; - } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { - num_pages_k = input_K->data.shape[1]; - page_size_k = input_K->data.shape[0]; - num_pages_v = input_V->data.shape[1]; - page_size_v = input_V->data.shape[0]; + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_K->data.shape[0]; + } + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if (input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = 
input_page_table_v->data.shape[1]; + } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_K->data.shape[0]; + page_size_k = input_K->data.shape[1]; + num_pages_v = input_V->data.shape[0]; + page_size_v = input_V->data.shape[1]; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_K->data.shape[1]; + page_size_k = input_K->data.shape[0]; + num_pages_v = input_V->data.shape[1]; + page_size_v = input_V->data.shape[0]; + } } - } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_K->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, true); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_logit, cuda_graph, true); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, - input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, + input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); - } -} -// NVTE fused attention BWD with separate Q, K and V -void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, - size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_K = convertNVTETensorCheck(K); - const Tensor *input_V = convertNVTETensorCheck(V); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dK = convertNVTETensorCheck(dK); - Tensor *output_dV = convertNVTETensorCheck(dV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + } } + // NVTE fused attention BWD with separate Q, K and V + void nvte_fused_attn_bwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, + const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, + NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dK = convertNVTETensorCheck(dK); + Tensor *output_dV = convertNVTETensorCheck(dV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_K->data.shape[0]; + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_K->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph, deterministic); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, 
false, + cuda_graph, deterministic); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else - const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary 
sequence length. \n"; - NVTE_ERROR(err_msg); + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + } } -} -uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, - cudaStream_t stream) { - NVTE_API_CALL(nvte_get_runtime_num_segments); - using namespace transformer_engine::fused_attn; - return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream); -} + uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, + cudaStream_t stream) { + NVTE_API_CALL(nvte_get_runtime_num_segments); + using namespace transformer_engine::fused_attn; + return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream); + } -void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed, - size_t q_max_seqlen, size_t kv_max_seqlen, - NVTE_Fused_Attn_Backend backend, cudaStream_t stream) { - NVTE_API_CALL(nvte_populate_rng_state_async); - using namespace transformer_engine::fused_attn; - PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream); -} + void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed, + size_t q_max_seqlen, size_t kv_max_seqlen, + NVTE_Fused_Attn_Backend backend, cudaStream_t stream) { + NVTE_API_CALL(nvte_populate_rng_state_async); + using namespace transformer_engine::fused_attn; + PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream); + } From bd31e01447d19d870d6c7cc73d9886d23a29a758 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Jan 2026 06:12:18 -0800 Subject: [PATCH 14/30] fix the AI fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 2 +- transformer_engine/common/fused_attn/fused_attn.cpp | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index cf97d891c4..a13dfada79 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,7 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto 
--junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 747a747343..6f1fe269b7 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -441,9 +441,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && + // determinism on Blackwell !(sm_arch_ >= 100 && is_training && deterministic && - (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || - dropout != 0.0))) { (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { flag_arb = true; From eb2e05548a7b3ed6050f768e3a25725018b9b3af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:13:05 +0000 Subject: [PATCH 15/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 1594 ++++++++--------- 1 file changed, 797 insertions(+), 797 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6f1fe269b7..3f561fc189 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -445,897 +445,897 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !(sm_arch_ >= 100 && is_training && deterministic && (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || dropout != 0.0))) { - flag_arb = true; - } - if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } - if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - if (flag_arb == true) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } else if ((flag_arb == false) && (flag_m512 == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; - } - int env_backend = static_cast(backend); - env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); - if (((env_backend == - static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && - flag_m512) || - ((env_backend == - static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && - flag_arb)) { - backend = static_cast(env_backend); - } - } - if (cudnn_runtime_version < 8901 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - if (cudnn_runtime_version < 8900 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." 
- << std::endl; - } - if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && - (window_size_left != -1) && (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout - << "Warning: Given combination of attention mask (non-causal) and " + flag_arb = true; + } + if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } + if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + if (flag_arb == true) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } else if ((flag_arb == false) && (flag_m512 == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; + } + int env_backend = static_cast(backend); + env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); + if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && + flag_m512) || + ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && + flag_arb)) { + backend = static_cast(env_backend); + } + } + if (cudnn_runtime_version < 8901 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + if (cudnn_runtime_version < 8900 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) && + (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: Given combination of attention mask (non-causal) and " "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. " " Please upgrade your cuDNN version if possible." << std::endl; - } - if ((cudnn_runtime_version <= 91500) && is_training && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - (max_seqlen_kv % 128 != 0) && cuda_graph && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-padding)," - " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" - " backward fused attention with graph capture requires cuDNN 9.15.1+. " - "Please upgrade your cuDNN version if possible." 
- << std::endl; - } - } else { + } + if ((cudnn_runtime_version <= 91500) && is_training && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + (max_seqlen_kv % 128 != 0) && cuda_graph && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: Given combination of attention mask (non-padding)," + " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" + " backward fused attention with graph capture requires cuDNN 9.15.1+. " + "Please upgrade your cuDNN version if possible." + << std::endl; } - return backend; + } else { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } + return backend; +} - // NVTE fused attention FWD with packed QKV - // DEPRECATED: This API is deprecated. - // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. - void nvte_fused_attn_fwd_qkvpacked( - const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, - bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); - using namespace transformer_engine; - - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; - } +// NVTE fused attention FWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. 
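The backend-selection hunks above carry the gate that the determinism commits add for the arbitrary-seqlen path: on Blackwell (sm_arch_ >= 100) a deterministic training backward is only kept when cuDNN is at least 91800 (presumably 9.18.0), the bias type is NVTE_NO_BIAS, and dropout is zero. Restated here as a standalone predicate for reference; this is a sketch only, and the function and parameter names below are illustrative rather than part of the Transformer Engine API.

// Sketch: mirrors the "determinism on Blackwell" condition from the hunk above.
// sm_arch follows the sm_arch_ convention (100 => Blackwell); cudnn_version uses
// cuDNN's integer encoding (91800 => 9.18.0). Illustrative names only.
bool arbitrary_seqlen_allows_determinism(int sm_arch, bool is_training, bool deterministic,
                                         int cudnn_version, bool has_bias, float dropout) {
  // The backend is rejected only when every clause holds: a deterministic training
  // backward on Blackwell with either an old cuDNN, a bias, or non-zero dropout.
  const bool blocked = (sm_arch >= 100) && is_training && deterministic &&
                       (cudnn_version < 91800 || has_bias || dropout != 0.0f);
  return !blocked;
}

Under that reading, the NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 run in qa/L0_pytorch_unittest/test.sh (which presumably maps to deterministic=true here) keeps the fused arbitrary-seqlen backward on Blackwell only for the bias-free, dropout-free, cuDNN 9.18+ case.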
+void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); + using namespace transformer_engine; + + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_QKV->data.shape.size(); + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + h = input_QKV->data.shape[ndim - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + h = input_QKV->data.shape[ndim - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + } + size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, - dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph, true); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, + cuda_graph, true); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_max_512_fwd(b, h, max_seqlen, 
max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens, input_rng_state, wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_arbitrary_seqlen_fwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, + wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
" - "\n"); + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - // Unpack QKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens, input_rng_state, wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } - // NVTE fused attention BWD with packed QKV - // DEPRECATED: This API is deprecated. - // Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
- void nvte_fused_attn_bwd_qkvpacked( - const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, - NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, - NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); - using namespace transformer_engine; - - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQKV = convertNVTETensorCheck(dQKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; - } +} +// NVTE fused attention BWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
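In the arbitrary-seqlen branch, both backward wrappers (the separate-tensor nvte_fused_attn_bwd earlier in this file and the packed variant below) read the tensors saved in Aux_CTX_Tensors positionally: softmax stats first, then the RNG state, then a bias tensor only when bias_type is neither NVTE_NO_BIAS nor NVTE_ALIBI, and a softmax offset only when softmax_type is not NVTE_VANILLA_SOFTMAX. A sketch of that positional contract follows, with illustrative names and void* standing in for the real tensor handles.

#include <cstddef>
#include <vector>

// Sketch of the positional Aux_CTX_Tensors layout consumed by the backward wrappers.
struct AuxViews {
  void *softmax_stats = nullptr;
  void *rng_state = nullptr;
  void *bias = nullptr;            // only saved for learnable bias types
  void *softmax_offset = nullptr;  // only saved for non-vanilla softmax
};

AuxViews unpack_aux(const std::vector<void *> &aux, bool has_learnable_bias,
                    bool has_softmax_offset) {
  AuxViews views;
  std::size_t i = 0;
  views.softmax_stats = aux.at(i++);
  views.rng_state = aux.at(i++);
  if (has_learnable_bias) views.bias = aux.at(i++);
  if (has_softmax_offset) views.softmax_offset = aux.at(i++);
  return views;
}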
+void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); + using namespace transformer_engine; - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); + const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); + const Tensor *input_QKV = convertNVTETensorCheck(QKV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQKV = convertNVTETensorCheck(dQKV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_QKV->data.shape.size(); + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + h = input_QKV->data.shape[ndim - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + h = input_QKV->data.shape[ndim - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + } + size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } + + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, - h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, - deterministic); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, + deterministic); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor 
K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, - output_S, &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, - input_cu_seqlens, wkspace, stream, handle); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, + &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, + input_cu_seqlens, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_arbitrary_seqlen_bwd( - b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, - &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, - &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, - input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, - wkspace, stream, handle); + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, + &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, + &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, 
input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else - const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - - // Unpack QKV and dQKV and call the non-packed function - const auto QKV_type = input_QKV->data.dtype; - size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); - std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); - - // Create tensor views for Q, K, V and dQ, dK, dV - Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); - Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); - Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); - - Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); - Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); - Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); - - fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, - input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, - &dV_view, input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, - stream, handle); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, + input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, + handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); - } + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } - // NVTE fused attention FWD with packed KV - // DEPRECATED: This API is deprecated. - // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. - void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); +} +// NVTE fused attention FWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. 
+void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); + + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + auto ndim = input_Q->data.shape.size(); + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + auto ndim_kv = input_KV->data.shape.size(); + size_t h_kv = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + h_kv = input_KV->data.shape[ndim_kv - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + h_kv = input_KV->data.shape[ndim_kv - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_KV->data.shape[0]; + } + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if (input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; - } - int64_t num_pages_k = 0; - int64_t num_pages_v = 0; 
- int64_t page_size_k = 0; - int64_t page_size_v = 0; - int64_t max_pages_per_seq_k = 0; - int64_t max_pages_per_seq_v = 0; - if (input_page_table_k->data.dptr != nullptr) { - max_pages_per_seq_k = input_page_table_k->data.shape[1]; - } - if (input_page_table_v->data.dptr != nullptr) { - max_pages_per_seq_v = input_page_table_v->data.shape[1]; - } - if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { - num_pages_k = input_KV->data.shape[0]; - page_size_k = input_KV->data.shape[1]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; - } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { - num_pages_k = input_KV->data.shape[1]; - page_size_k = input_KV->data.shape[0]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; - } + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_KV->data.shape[0]; + page_size_k = input_KV->data.shape[1]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_KV->data.shape[1]; + page_size_k = input_KV->data.shape[0]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; } + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph, true); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_logit, cuda_graph, true); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - // Unpack KV and call the non-packed function - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, - &V_view, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, 
unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - // Unpack KV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. " - "\n"); + NVTE_ERROR( + "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
" + "\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - // Unpack KV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +} +// NVTE fused attention BWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
+void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_KV = convertNVTETensorCheck(KV); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dKV = convertNVTETensorCheck(dKV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); + Tensor *wkspace = convertNVTETensor(workspace); + + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + auto ndim = input_Q->data.shape.size(); + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + auto ndim_kv = input_KV->data.shape.size(); + size_t h_kv = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + h_kv = input_KV->data.shape[ndim_kv - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + h_kv = input_KV->data.shape[ndim_kv - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_KV->data.shape[0]; } - // NVTE fused attention BWD with packed KV - // DEPRECATED: This API is deprecated. - // Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
- void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dKV = convertNVTETensorCheck(dKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; - } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, - cuda_graph, deterministic); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, 
attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, + cuda_graph, deterministic); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - // Unpack KV and dKV and call the non-packed function - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + // Unpack KV and dKV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, - input_dO, output_S, output_dQ, &dK_view, &dV_view, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q, + input_cu_seqlens_kv, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } - // Unpack KV and dKV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - // Create tensor views for dK, dV - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - - fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + // Create tensor views for dK, dV + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, + input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, + 
output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else - const char *err_msg = - "cuDNN 8.9.3 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); + const char *err_msg = + "cuDNN 8.9.3 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - - // Unpack KV and dKV and call the non-packed function - const auto Q_type = input_Q->data.dtype; - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); - std::vector unpacked_kv_shape = - calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); - - Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); - Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); - - Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); - Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); - - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - &dK_view, &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, + &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, + stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + } +} +// NVTE fused attention FWD with separate Q, K and V +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); + Tensor *input_output_S = convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_K->data.shape[0]; } - // NVTE fused attention FWD with separate Q, K and V - void nvte_fused_attn_fwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd); - using namespace transformer_engine; - const Tensor 
*input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_K = convertNVTETensorCheck(K); - const Tensor *input_V = convertNVTETensorCheck(V); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if (input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; - } - int64_t num_pages_k = 0; - int64_t num_pages_v = 0; - int64_t page_size_k = 0; - int64_t page_size_v = 0; - int64_t max_pages_per_seq_k = 0; - int64_t max_pages_per_seq_v = 0; - if (input_page_table_k->data.dptr != nullptr) { - max_pages_per_seq_k = input_page_table_k->data.shape[1]; - } - if (input_page_table_v->data.dptr != nullptr) { - max_pages_per_seq_v = input_page_table_v->data.shape[1]; - } - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { - num_pages_k = input_K->data.shape[0]; - page_size_k = input_K->data.shape[1]; - num_pages_v = input_V->data.shape[0]; - page_size_v = input_V->data.shape[1]; - } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { - num_pages_k = input_K->data.shape[1]; - page_size_k = input_K->data.shape[0]; - num_pages_v = input_V->data.shape[1]; - page_size_v = input_V->data.shape[0]; - } + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_K->data.shape[0]; + page_size_k = input_K->data.shape[1]; + num_pages_v = input_V->data.shape[0]; + page_size_v = input_V->data.shape[1]; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_K->data.shape[1]; + 
page_size_k = input_K->data.shape[0]; + num_pages_v = input_V->data.shape[1]; + page_size_v = input_V->data.shape[0]; } + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_K->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, true); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_logit, cuda_graph, true); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, - input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, + input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " - "\n"); + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + } +} +// NVTE fused attention BWD with separate Q, K and V +void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, + size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + const Tensor *input_S = convertNVTETensorCheck(S); + Tensor *input_output_dP = convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dK = convertNVTETensorCheck(dK); + Tensor *output_dV = convertNVTETensorCheck(dV); + Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); + Tensor *wkspace = convertNVTETensor(workspace); + + auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + t_kv = input_K->data.shape[0]; } - // NVTE fused attention BWD with separate Q, K and V - void nvte_fused_attn_bwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, - const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, - NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd); - using namespace transformer_engine; - const Tensor 
*input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_K = convertNVTETensorCheck(K); - const Tensor *input_V = convertNVTETensorCheck(V); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dK = convertNVTETensorCheck(dK); - Tensor *output_dV = convertNVTETensorCheck(dV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); - - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; - } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_K->data.dtype); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph, deterministic); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, + cuda_graph, deterministic); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, 
input_cu_seqlens_kv, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } - fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else - const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. 
\n"; + NVTE_ERROR(err_msg); #endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); #else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); - } + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } +} - uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, - cudaStream_t stream) { - NVTE_API_CALL(nvte_get_runtime_num_segments); - using namespace transformer_engine::fused_attn; - return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream); - } +uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, + cudaStream_t stream) { + NVTE_API_CALL(nvte_get_runtime_num_segments); + using namespace transformer_engine::fused_attn; + return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream); +} - void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed, - size_t q_max_seqlen, size_t kv_max_seqlen, - NVTE_Fused_Attn_Backend backend, cudaStream_t stream) { - NVTE_API_CALL(nvte_populate_rng_state_async); - using namespace transformer_engine::fused_attn; - PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - } +void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed, + size_t q_max_seqlen, size_t kv_max_seqlen, + NVTE_Fused_Attn_Backend backend, cudaStream_t stream) { + NVTE_API_CALL(nvte_populate_rng_state_async); + using namespace transformer_engine::fused_attn; + PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream); +} From 6f0e5150ab3506277aa0853726150d85ea0037cf Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Jan 2026 06:22:04 -0800 Subject: [PATCH 16/30] fix Jax extension call Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/jax/csrc/extensions/attention.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 9276adb453..4fe8e728a3 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -266,7 +266,7 @@ static void FusedAttnForwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, true); + false, false, deterministic); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -522,7 +522,7 @@ static void FusedAttnBackwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, true); + false, false, deterministic); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); From 2c22cbf8cce64e6b4e7446fa0aea6c8e4780ac14 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 13 Jan 2026 06:25:20 -0800 Subject: [PATCH 17/30] minor fixes based on comments Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/jax/cpp_extensions/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index a680fac57b..fdddd0abcf 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3572,7 +3572,7 @@ def fused_attn_bwd( ): assert ( attn_bias_type == AttnBiasType.NO_BIAS - and dropout_probability == 0 + and dropout_probability == 0.0 and get_cudnn_version() >= (9, 18, 0) ), ( "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and" From aae98f3e4af0a838edf5ee611fcf8c8f9cc1b1a4 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 14 Jan 2026 03:53:03 -0800 Subject: [PATCH 18/30] fix selection logic and fwd arg Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3f561fc189..aec4c09e3e 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -442,9 +442,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && // determinism on Blackwell - !(sm_arch_ >= 100 && is_training && deterministic && - (cudnn_runtime_version < 91800 || bias_type != NVTE_Bias_Type::NVTE_NO_BIAS || - dropout != 0.0))) { + // pre-9.18.0: fwd: deterministic; bwd: non-deterministic + // 9.18.0+: fwd: deterministic; bwd: non-deterministic/deterministic + (sm_arch_ < 100 || + (sm_arch_ >=100 && (!is_training || (is_training && !deterministic) || + 
(is_training && deterministic && cudnn_runtime_version >= 91800 && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -557,7 +559,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph, true); + cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -861,7 +863,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph, true); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1173,7 +1175,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, true); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) From b962d3269c135d694fb2ef71a7b2b89bffad0f29 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:54:00 +0000 Subject: [PATCH 19/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/fused_attn/fused_attn.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index aec4c09e3e..679856e861 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -445,8 +445,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // pre-9.18.0: fwd: deterministic; bwd: non-deterministic // 9.18.0+: fwd: deterministic; bwd: non-deterministic/deterministic (sm_arch_ < 100 || - (sm_arch_ >=100 && (!is_training || (is_training && !deterministic) || - (is_training && deterministic && cudnn_runtime_version >= 91800 && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { + (sm_arch_ >= 100 && (!is_training || (is_training && !deterministic) || + (is_training && deterministic && cudnn_runtime_version >= 91800 && + dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { From 10685946485200e6a783c4660f3493452818c92b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 14 Jan 2026 05:15:10 -0800 Subject: [PATCH 20/30] fix version check in Jax test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- 
tests/jax/test_fused_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 89b9003843..49e6cbf34d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -420,7 +420,7 @@ def _check_configs(self): and ( self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS - or get_cudnn_version() < (9, 18, 0) + or get_cudnn_version() < 91800 ) ): pytest.skip( From c51cf4436848a221f8037d4b89942daa27939ea3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 15 Jan 2026 00:28:09 -0800 Subject: [PATCH 21/30] fix pytorch CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f34baad41c..bfa804d0f5 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2404,7 +2404,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): pytest.skip("Not enough backends to run this test with.") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") - unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedDotProductAttention") atol = 5e-1 rtol = 5e-1 From 38856847c64f02d84a6d0ebb69dbce6539e9f12f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 15 Jan 2026 00:35:40 -0800 Subject: [PATCH 22/30] fix Jax CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 6 ++---- transformer_engine/jax/cpp_extensions/attention.py | 8 ++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 49e6cbf34d..8bbe1bfb2d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -416,16 +416,14 @@ def _check_configs(self): if ( get_device_compute_capability(0) >= 100 and self.is_training - and not FusedAttnHelper.is_non_deterministic_allowed() and ( self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS - or get_cudnn_version() < 91800 + or (not FusedAttnHelper.is_non_deterministic_allowed() and get_cudnn_version() < 91800) ) ): pytest.skip( - "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and" - " dropout = 0.0" + "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and bprop determinism is only supported by cuDNN 9.18.0+" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index fdddd0abcf..2d1c4928ee 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3568,15 +3568,15 @@ def fused_attn_bwd( if ( any(x >= 100 for x in compute_capabilities) and is_training - and not FusedAttnHelper.is_non_deterministic_allowed() ): assert ( attn_bias_type == AttnBiasType.NO_BIAS and dropout_probability == 0.0 - and get_cudnn_version() >= (9, 18, 0) + and 
(FusedAttnHelper.is_non_deterministic_allowed() + or (not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 18, 0))) ), ( - "For sm100+, bprop determinism is only supported by cuDNN 9.18+ and for no_bias and" - " dropout = 0.0" + "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and bprop determinism is only supported by cuDNN 9.18+" ) fused_config = _FusedAttnConfig( From 8bf3a0fbd4739646d7404b93365924875266a6eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 08:36:45 +0000 Subject: [PATCH 23/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 8 ++++++-- tests/pytorch/attention/test_attention.py | 4 +++- .../jax/cpp_extensions/attention.py | 18 ++++++++++-------- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 8bbe1bfb2d..54a489c374 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -419,11 +419,15 @@ def _check_configs(self): and ( self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS - or (not FusedAttnHelper.is_non_deterministic_allowed() and get_cudnn_version() < 91800) + or ( + not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() < 91800 + ) ) ): pytest.skip( - "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and bprop determinism is only supported by cuDNN 9.18.0+" + "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and bprop" + " determinism is only supported by cuDNN 9.18.0+" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index bfa804d0f5..9111d3511c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2404,7 +2404,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model): pytest.skip("Not enough backends to run this test with.") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") - unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedDotProductAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16( + dtype, config, "UnfusedDotProductAttention" + ) atol = 5e-1 rtol = 5e-1 diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 2d1c4928ee..2cf3100bd0 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3565,18 +3565,20 @@ def fused_attn_bwd( ) compute_capabilities = get_all_device_compute_capability() - if ( - any(x >= 100 for x in compute_capabilities) - and is_training - ): + if any(x >= 100 for x in compute_capabilities) and is_training: assert ( attn_bias_type == AttnBiasType.NO_BIAS and dropout_probability == 0.0 - and (FusedAttnHelper.is_non_deterministic_allowed() - or (not FusedAttnHelper.is_non_deterministic_allowed() - and get_cudnn_version() >= (9, 18, 0))) + and ( + FusedAttnHelper.is_non_deterministic_allowed() + or ( + not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 18, 0) + ) + ) ), ( - "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and 
bprop determinism is only supported by cuDNN 9.18+" + "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and bprop" + " determinism is only supported by cuDNN 9.18+" ) fused_config = _FusedAttnConfig( From f5265696980790ee8c6e8ece1d613bf81802b7dd Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 15 Jan 2026 06:50:50 -0800 Subject: [PATCH 24/30] fix non-/determinism logic and CI Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 30 +++++++++---------- .../common/fused_attn/fused_attn.cpp | 2 +- .../jax/cpp_extensions/attention.py | 20 +++++++------ 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 54a489c374..c36aba845b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -413,22 +413,22 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - if ( - get_device_compute_capability(0) >= 100 - and self.is_training - and ( - self.dropout_prob != 0.0 + + if get_device_compute_capability(0) >= 100 and self.is_training: + if FusedAttnHelper.is_non_deterministic_allowed() + and ((self.dropout_prob != 0.0 + and self.attn_bias_type != AttnBiasType.NO_BIAS) + or get_cudnn_version() < 90700): + pytest.skip( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout" + ) + if not FusedAttnHelper.is_non_deterministic_allowed() + and (self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS - or ( - not FusedAttnHelper.is_non_deterministic_allowed() - and get_cudnn_version() < 91800 - ) - ) - ): - pytest.skip( - "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and bprop" - " determinism is only supported by cuDNN 9.18.0+" - ) + or get_cudnn_version() < 91800): + pytest.skip( + "For sm100+, deterministic bprop (cuDNN 9.18+) does not support bias or dropout" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 679856e861..b021716559 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -445,7 +445,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // pre-9.18.0: fwd: deterministic; bwd: non-deterministic // 9.18.0+: fwd: deterministic; bwd: non-deterministic/deterministic (sm_arch_ < 100 || - (sm_arch_ >= 100 && (!is_training || (is_training && !deterministic) || + (sm_arch_ >= 100 && (!is_training || (is_training && !deterministic && !(dropout != 0.0 && bias_type != NVTE_Bias_Type::NVTE_NO_BIAS)) || (is_training && deterministic && cudnn_runtime_version >= 91800 && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 2cf3100bd0..7aaa67ac55 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3567,18 +3567,20 @@ def fused_attn_bwd( compute_capabilities = get_all_device_compute_capability() if any(x >= 100 for x in compute_capabilities) and is_training: 
assert ( - attn_bias_type == AttnBiasType.NO_BIAS - and dropout_probability == 0.0 - and ( + ( FusedAttnHelper.is_non_deterministic_allowed() - or ( - not FusedAttnHelper.is_non_deterministic_allowed() - and get_cudnn_version() >= (9, 18, 0) - ) + and get_cudnn_version() >= (9, 7, 0) + and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0) + ) + or ( + not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 18, 0) + and attn_bias_type == AttnBiasType.NO_BIAS + and dropout_probability == 0.0 ) ), ( - "For sm100+, bprop is only supported for no_bias and dropout = 0.0, and bprop" - " determinism is only supported by cuDNN 9.18+" + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout," + " and deterministic bprop (cuDNN 9.18+) does not support bias or dropout" ) fused_config = _FusedAttnConfig( From 0cb374ab5af6667061596e21f8cf8637c51b913a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 14:52:48 +0000 Subject: [PATCH 25/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 4 +++- .../jax/cpp_extensions/attention.py | 19 ++++++++----------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index b021716559..1bc3c6ed48 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -445,7 +445,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // pre-9.18.0: fwd: deterministic; bwd: non-deterministic // 9.18.0+: fwd: deterministic; bwd: non-deterministic/deterministic (sm_arch_ < 100 || - (sm_arch_ >= 100 && (!is_training || (is_training && !deterministic && !(dropout != 0.0 && bias_type != NVTE_Bias_Type::NVTE_NO_BIAS)) || + (sm_arch_ >= 100 && (!is_training || + (is_training && !deterministic && + !(dropout != 0.0 && bias_type != NVTE_Bias_Type::NVTE_NO_BIAS)) || (is_training && deterministic && cudnn_runtime_version >= 91800 && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 7aaa67ac55..d33b43aa17 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3567,17 +3567,14 @@ def fused_attn_bwd( compute_capabilities = get_all_device_compute_capability() if any(x >= 100 for x in compute_capabilities) and is_training: assert ( - ( - FusedAttnHelper.is_non_deterministic_allowed() - and get_cudnn_version() >= (9, 7, 0) - and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0) - ) - or ( - not FusedAttnHelper.is_non_deterministic_allowed() - and get_cudnn_version() >= (9, 18, 0) - and attn_bias_type == AttnBiasType.NO_BIAS - and dropout_probability == 0.0 - ) + FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 7, 0) + and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0) + ) or ( + not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 18, 0) + and attn_bias_type == AttnBiasType.NO_BIAS + and dropout_probability == 0.0 ), ( "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout," " and deterministic bprop (cuDNN 
9.18+) does not support bias or dropout" From 9162ff63d72462b358d185b0b229ca245b510d4c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 15 Jan 2026 06:57:47 -0800 Subject: [PATCH 26/30] fix formatting Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index c36aba845b..adfd04542a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -415,17 +415,11 @@ def _check_configs(self): ) if get_device_compute_capability(0) >= 100 and self.is_training: - if FusedAttnHelper.is_non_deterministic_allowed() - and ((self.dropout_prob != 0.0 - and self.attn_bias_type != AttnBiasType.NO_BIAS) - or get_cudnn_version() < 90700): + if FusedAttnHelper.is_non_deterministic_allowed() and ((self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) or get_cudnn_version() < 90700): pytest.skip( "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout" ) - if not FusedAttnHelper.is_non_deterministic_allowed() - and (self.dropout_prob != 0.0 - or self.attn_bias_type != AttnBiasType.NO_BIAS - or get_cudnn_version() < 91800): + if not FusedAttnHelper.is_non_deterministic_allowed() and (self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS or get_cudnn_version() < 91800): pytest.skip( "For sm100+, deterministic bprop (cuDNN 9.18+) does not support bias or dropout" ) From ee90c5aaf99aa70084e898451e597c52c5165455 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 14:59:18 +0000 Subject: [PATCH 27/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index adfd04542a..3ac9d693a3 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -415,14 +415,22 @@ def _check_configs(self): ) if get_device_compute_capability(0) >= 100 and self.is_training: - if FusedAttnHelper.is_non_deterministic_allowed() and ((self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) or get_cudnn_version() < 90700): + if FusedAttnHelper.is_non_deterministic_allowed() and ( + (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) + or get_cudnn_version() < 90700 + ): pytest.skip( - "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout" - ) - if not FusedAttnHelper.is_non_deterministic_allowed() and (self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS or get_cudnn_version() < 91800): + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with" + " dropout" + ) + if not FusedAttnHelper.is_non_deterministic_allowed() and ( + self.dropout_prob != 0.0 + or self.attn_bias_type != AttnBiasType.NO_BIAS + or get_cudnn_version() < 91800 + ): pytest.skip( "For sm100+, deterministic bprop (cuDNN 9.18+) does not support bias or dropout" - ) + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): From 
77d0f2ae26816582846ec49e3326b5656534f853 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 15 Jan 2026 07:00:04 -0800 Subject: [PATCH 28/30] Update transformer_engine/common/fused_attn/fused_attn.cpp fix and/or logic Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1bc3c6ed48..7d90b9b687 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -447,7 +447,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (sm_arch_ < 100 || (sm_arch_ >= 100 && (!is_training || (is_training && !deterministic && - !(dropout != 0.0 && bias_type != NVTE_Bias_Type::NVTE_NO_BIAS)) || + (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || (is_training && deterministic && cudnn_runtime_version >= 91800 && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; From 65a67c6bf66b89c456bc92c99e2daf78bba43beb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sun, 18 Jan 2026 16:32:56 -0800 Subject: [PATCH 29/30] update to 9.18.1 for requirement Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 4 ++-- transformer_engine/common/fused_attn/fused_attn.cpp | 6 +++--- transformer_engine/jax/cpp_extensions/attention.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 3ac9d693a3..f80c31da0b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -426,10 +426,10 @@ def _check_configs(self): if not FusedAttnHelper.is_non_deterministic_allowed() and ( self.dropout_prob != 0.0 or self.attn_bias_type != AttnBiasType.NO_BIAS - or get_cudnn_version() < 91800 + or get_cudnn_version() < 91801 ): pytest.skip( - "For sm100+, deterministic bprop (cuDNN 9.18+) does not support bias or dropout" + "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 7d90b9b687..415bfae063 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -442,13 +442,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && // determinism on Blackwell - // pre-9.18.0: fwd: deterministic; bwd: non-deterministic - // 9.18.0+: fwd: deterministic; bwd: non-deterministic/deterministic + // pre-9.18.1: fwd: deterministic; bwd: non-deterministic + // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic (sm_arch_ < 100 || (sm_arch_ >= 100 && (!is_training || (is_training && !deterministic && (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || - (is_training && deterministic && cudnn_runtime_version >= 91800 && + (is_training && deterministic && cudnn_runtime_version >= 
91801 && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; } diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index d33b43aa17..ee10115aa1 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3572,12 +3572,12 @@ def fused_attn_bwd( and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0) ) or ( not FusedAttnHelper.is_non_deterministic_allowed() - and get_cudnn_version() >= (9, 18, 0) + and get_cudnn_version() >= (9, 18, 1) and attn_bias_type == AttnBiasType.NO_BIAS and dropout_probability == 0.0 ), ( "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout," - " and deterministic bprop (cuDNN 9.18+) does not support bias or dropout" + " and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout" ) fused_config = _FusedAttnConfig( From 7187d02f925c81cfe921dd2a1869e28360866e87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 00:34:08 +0000 Subject: [PATCH 30/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f80c31da0b..e5c0a66735 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -429,7 +429,8 @@ def _check_configs(self): or get_cudnn_version() < 91801 ): pytest.skip( - "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout" + "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or" + " dropout" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
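
The net effect of the backend-selection changes above can be summarized with a short standalone sketch. The helper below is illustrative only and is not part of the patch series: the function name blackwell_bprop_supported, its argument names, and the boolean has_bias flag are assumptions made for clarity, while the integer cuDNN version encoding (e.g. 91801 for 9.18.1) follows the convention already used in fused_attn.cpp. It mirrors the sm100+ clause that PATCH 28/30 and 29/30 leave in nvte_get_fused_attn_backend; note that the extra cuDNN 9.7+ requirement for the non-deterministic path is enforced separately by the JAX-side assert in fused_attn_bwd rather than by this clause.

# Illustrative sketch only -- restates the sm100+ (Blackwell) gate for the
# arbitrary-seqlen backend as it stands after PATCH 28/30 and 29/30:
#   * inference (no training) is unrestricted here,
#   * non-deterministic backward only rules out bias combined with dropout,
#   * deterministic backward needs cuDNN 9.18.1+, no bias, and no dropout.
def blackwell_bprop_supported(
    is_training: bool,
    deterministic: bool,
    dropout: float,
    has_bias: bool,
    cudnn_version: int,  # e.g. 91801 for cuDNN 9.18.1
) -> bool:
    if not is_training:
        # Forward-only use is deterministic and unrestricted by this clause.
        return True
    if not deterministic:
        # Non-deterministic backward: bias and dropout may not be combined.
        return dropout == 0.0 or not has_bias
    # Deterministic backward: cuDNN 9.18.1+, no bias, no dropout.
    return cudnn_version >= 91801 and dropout == 0.0 and not has_bias


# Examples, matching the pytest.skip conditions in tests/jax/test_fused_attn.py:
# non-deterministic backward with dropout but no bias is allowed,
assert blackwell_bprop_supported(True, False, 0.1, False, 91801)
# while deterministic backward with dropout is not.
assert not blackwell_bprop_supported(True, True, 0.1, False, 91801)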