From fbb0702939387d7766dd0b7359511a38eef18d89 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Mon, 5 Jan 2026 11:10:11 -0800
Subject: [PATCH 1/2] Update THD sink attention logic for newer cudnn versions

THD Sink attention is supported in 9.18.0

Signed-off-by: Chen Cui
---
 .../pytorch/attention/dot_product_attention/utils.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index bf19388d7e..3749d40e37 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -716,10 +716,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
             )
             use_unfused_attention = False
         if qkv_format == "thd":
-            logger.debug(
-                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
-            )
-            use_fused_attention = False
+            if cudnn_version < (9, 18, 0):
+                logger.debug(
+                    "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
+                )
+                use_fused_attention = False
             logger.debug(
                 "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
                 softmax_type,

From 01848c0687b89ec87d586d5e1070772d7cf68ea8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 5 Jan 2026 19:11:33 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pytorch/attention/dot_product_attention/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 3749d40e37..fce04bfa2d 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -718,7 +718,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         if qkv_format == "thd":
             if cudnn_version < (9, 18, 0):
                 logger.debug(
-                    "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
+                    "Disabling FusedAttention for softmax_type = %s and qkv_format = thd",
+                    softmax_type,
                 )
                 use_fused_attention = False
             logger.debug(
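
Reviewer note: for reference, below is a minimal, self-contained sketch of how the THD sink-attention gating behaves once both patches are applied. The helper function thd_sink_attention_backends and the example softmax_type value are hypothetical illustrations, not part of the patch; only the names cudnn_version, softmax_type, qkv_format, use_fused_attention, and use_unfused_attention are taken from the diff, and the real checks live inline in transformer_engine/pytorch/attention/dot_product_attention/utils.py.

import logging

logger = logging.getLogger(__name__)


# Hypothetical wrapper used only for illustration; in the patch the same
# checks appear inline in utils.py rather than as a standalone function.
def thd_sink_attention_backends(
    cudnn_version, softmax_type, qkv_format, use_fused_attention, use_unfused_attention
):
    if qkv_format == "thd":
        # After PATCH 1/2: FusedAttention is only disabled for THD sink
        # attention when cuDNN is older than 9.18.0.
        if cudnn_version < (9, 18, 0):
            logger.debug(
                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd",
                softmax_type,
            )
            use_fused_attention = False
        # The diff context suggests the unfused path stays disabled for THD
        # sink attention regardless of the cuDNN version (unchanged by the patches).
        logger.debug(
            "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
            softmax_type,
        )
        use_unfused_attention = False
    return use_fused_attention, use_unfused_attention


# Example: with cuDNN 9.18.0 or newer, FusedAttention stays available for
# THD sink attention, while the unfused path is still ruled out.
fused, unfused = thd_sink_attention_backends(
    cudnn_version=(9, 18, 0),
    softmax_type="learnable",  # illustrative non-"vanilla" value
    qkv_format="thd",
    use_fused_attention=True,
    use_unfused_attention=True,
)
assert fused and not unfused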