Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
// T3HD and TH3D are not supported by cuDNN on SM120; assert before hitting the path.
const int device_id_fwd = cuda::current_device();
const int sm_arch_fwd = cuda::sm_arch(device_id_fwd);
if (sm_arch_fwd >= 120 &&
(qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) {
NVTE_ERROR(
"T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
"Use thd_thd_thd or other THD layouts instead.");
Comment on lines +639 to +641
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing period in forward error message

The forward error message is missing a period that is present in the corresponding backward error message (line 748). Minor inconsistency but worth fixing for uniformity.

Suggested change
NVTE_ERROR(
"T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
"Use thd_thd_thd or other THD layouts instead.");
"T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
"Use thd_thd_thd or other THD layouts instead.");

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

}
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
Expand Down Expand Up @@ -730,6 +739,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
// T3HD and TH3D are not supported by cuDNN on SM120; assert before hitting the path.
const int device_id_bwd = cuda::current_device();
const int sm_arch_bwd = cuda::sm_arch(device_id_bwd);
if (sm_arch_bwd >= 120 &&
(qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) {
NVTE_ERROR(
"T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
"Use thd_thd_thd or other THD layouts instead.");
}
size_t i = 0;
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && !(sm_arch_ >= 120);

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -96,11 +99,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
// On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3]
// as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build
// so the check passes; ragged offset still provides variable-length boundaries.
if (sm_arch_ < 120) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
}

const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
Expand Down Expand Up @@ -336,7 +344,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}

std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
Expand All @@ -353,7 +361,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_name("Sum_Exp")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Expand Down Expand Up @@ -381,7 +389,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(

if (!return_max_logit) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Stats->set_stride({h * s_q, s_q, 1, 1});
Expand All @@ -407,9 +415,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto offset_s_tuple =
use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
: std::make_tuple(nullptr, nullptr);

Expand Down Expand Up @@ -443,7 +450,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
Expand Down Expand Up @@ -510,7 +517,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
Expand All @@ -529,7 +536,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
variant_pack[offset_stats] = devOffsetsS;
}
}
Expand Down Expand Up @@ -587,6 +594,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && !(sm_arch_ >= 120);

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -598,13 +606,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
// On SM 120, cuDNN support check requires BHSD-like strides with max_seqlen (see fwd).
if (sm_arch_ < 120) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
}

// We choose between 32-bit and 64-bit offsets depending on need.
// This allows us to support older cuDNN runtimes gracefully.
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
Expand Down Expand Up @@ -765,7 +775,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
Expand All @@ -791,7 +801,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);

if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
}
if (is_ragged_kv && cudnn_runtime_version >= 90600) {
Expand Down Expand Up @@ -914,9 +924,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto offset_s_tuple =
use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
: std::make_tuple(nullptr, nullptr);

Expand Down Expand Up @@ -949,7 +958,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
Expand Down Expand Up @@ -1019,7 +1028,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
Expand All @@ -1038,7 +1047,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
variant_pack[offset_stats] = devOffsetsS;
}
}
Expand Down Expand Up @@ -1102,6 +1111,9 @@ void fused_attn_arbitrary_seqlen_fwd(
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}

const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);

void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
Expand All @@ -1128,15 +1140,17 @@ void fused_attn_arbitrary_seqlen_fwd(
if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
!(sm_arch_ >= 120)) {
output_Max->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_Max->data.dtype = DType::kFloat32;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Sum_Exp->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
!(sm_arch_ >= 120)) {
output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
Expand All @@ -1145,7 +1159,8 @@ void fused_attn_arbitrary_seqlen_fwd(
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
!(sm_arch_ >= 120)) {
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,11 @@ def forward(
softmax_lse_in_packed_format = False
if qkv_format == "thd":
if use_fused_attention:
softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
softmax_lse_in_packed_format = get_cudnn_version() >= (
9,
6,
0,
) and get_device_compute_capability() < (12, 0)
else:
softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3

Expand Down
33 changes: 23 additions & 10 deletions transformer_engine/pytorch/attention/dot_product_attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,15 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
# Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version
# until the cuDNN bug is resolved
if device_compute_capability == (8, 9):
logger.debug("Disabling FusedAttention for KV caching for sm89")
# Temporarily disabling fused attention for kv caching for sm89/sm120 irrespective of
# cuDNN version until the cuDNN bug is resolved.
if device_compute_capability in ((8, 9), (12, 0)):
logger.debug("Disabling FusedAttention for KV caching for sm89/sm120")
use_fused_attention = False
# Temporarily disable FlashAttention for KV caching on sm120
if device_compute_capability == (12, 0):
logger.debug("Disabling FlashAttention for KV caching for sm120")
use_flash_attention = False
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
use_flash_attention = False
Expand Down Expand Up @@ -691,12 +695,21 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_flash_attention = False
if device_compute_capability == (12, 0):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_format = thd is"
" not supported for compute capability = sm120"
)
use_fused_attention = False
if cudnn_version < (9, 18, 1):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_format = thd is"
" not supported for compute capability = sm120 and cuDNN version < 9.18.1"
)
use_fused_attention = False
elif qkv_layout in {"t3hd", "th3d"}:
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_layout = %s is not supported for"
" compute capability = sm120",
qkv_layout,
)
use_fused_attention = False

# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention_3:
Expand Down
19 changes: 14 additions & 5 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,22 @@ def fused_attn_fwd(

if return_max_logit:
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
# thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# thd (older cuDNN runtimes or sm120): output_tensors: out [tq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
stats = output_tensors[1] + torch.log(output_tensors[2])
amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
max_tensor = output_tensors[1]
if qkv_format == "thd" and max_tensor.ndim == 4:
# For THD on older cuDNN runtimes or THD on sm120, stats can be [b, h, sq, 1] with padded
# sequence positions. Exclude those padded positions when computing max_logit.
seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device)
sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(1, 1, -1, 1)
valid = sq_idx < seqlens_q.view(-1, 1, 1, 1)
max_tensor = max_tensor.masked_fill(~valid, float("-inf"))
amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3)
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
aux_ctx_tensors.extend(output_tensors[3:])
return output_tensors[0], aux_ctx_tensors, max_logit
Expand Down
Loading