From f967a26797c77f408e016638c2a02b45e67e04f9 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 05:15:43 +0000 Subject: [PATCH 01/11] Add JAX fused attention score_mod support Signed-off-by: Vladimir Cherepanov --- tests/jax/test_fused_attn.py | 359 ++++++++++- transformer_engine/jax/attention.py | 183 +++++- .../jax/cpp_extensions/attention.py | 574 +++++++++++++++++- transformer_engine/jax/csrc/extensions.h | 11 + .../jax/csrc/extensions/attention.cpp | 224 +++++++ .../jax/csrc/extensions/pybind.cpp | 7 + 6 files changed, 1354 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1fb0108068..f12435fe5b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Tests for fused attention""" +import importlib.util import os from enum import Enum, auto from dataclasses import dataclass, field @@ -40,7 +41,7 @@ CPStrategy, ReorderStrategy, ) -from transformer_engine.jax.cpp_extensions import FusedAttnHelper +from transformer_engine.jax.cpp_extensions import FusedAttnHelper, make_fused_attn_score_mod_config from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, @@ -54,6 +55,149 @@ _deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) +def _has_cudnn_frontend_python(): + return importlib.util.find_spec("cudnn") is not None + + +def _score_mod_causal(graph, score, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=score, input1=tensors["neg_inf"], mask=keep) + + +def _score_mod_causal_bprop(graph, dscore, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + row_index = graph.gen_index( + input=dscore, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=dscore, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=dscore, input1=tensors["zero"], mask=keep) + + +def _score_mod_relative_position(graph, score, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + relative_position = graph.sub( + a=row_index, + b=col_index, + compute_data_type=cudnn.data_type.FLOAT, + ) + relative_position.set_data_type(cudnn.data_type.FLOAT) + return graph.add( + a=score, + b=relative_position, + compute_data_type=cudnn.data_type.FLOAT, + ) + + +class _ScoreModSoftcap: + """cuDNN frontend score_mod implementing softcapping.""" + + def __init__(self): + self.before_tanh_activation = None + + def forward(self, graph, score, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + self.before_tanh_activation = graph.div( + a=score, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + tanh_out = graph.tanh(input=self.before_tanh_activation) + tanh_out.set_data_type(cudnn.data_type.FLOAT) + return graph.mul( + a=tanh_out, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def backward(self, graph, dscore, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + d_tanh_out = graph.mul( + a=dscore, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + d_tanh_out.set_data_type(cudnn.data_type.FLOAT) + d_before_tanh_activation = graph.tanh_backward( + loss=d_tanh_out, + input=self.before_tanh_activation, + compute_data_type=cudnn.data_type.FLOAT, + ) + d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + return graph.div( + a=d_before_tanh_activation, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + +def _reference_attention( + query, key, value, scale, *, causal=False, relative_position=False, softcap=None +): + scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale + if causal: + q_pos = jnp.arange(query.shape[1])[:, None] + kv_pos = jnp.arange(key.shape[1])[None, :] + scores = jnp.where(q_pos >= kv_pos, scores, -1e9) + if relative_position: + q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] + kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] + scores = scores + q_pos - kv_pos + if softcap is not None: + scores = softcap * jnp.tanh(scores / softcap) + probs = jax.nn.softmax(scores, axis=-1) + return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype) + + @pytest.fixture(autouse=True, scope="module") def init(): """ @@ -138,6 +282,219 @@ def general_dot_product_attention( return context +def _require_cudnn_frontend_score_mod(): + cudnn = pytest.importorskip("cudnn", reason="cuDNN Python frontend is required for score_mod") + version = tuple(int(part) for part in cudnn.backend_version_string().split(".")[:2]) + if version < (9, 6): + pytest.skip("cuDNN score_mod SDPA requires cuDNN frontend 9.6 or newer") + + +def _identity_score_mod(_graph, score, _tensors): + return score + + +def test_fused_attn_score_mod_validation_rejects_masks_without_cudnn_frontend(): + q = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + k = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + v = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + + with pytest.raises(ValueError, match="mutually exclusive with attention masks"): + fused_attn( + (q, k, v), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.CAUSAL_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + 1.0, + 0.0, + True, + score_mod=_identity_score_mod, + ) + + +def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): + tensor = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + config, tensor_operands, bprop_tensor_operands = make_fused_attn_score_mod_config( + _identity_score_mod, + None, + {"tensor": tensor, "neg_inf": -1e9}, + None, + 0.125, + True, + ) + + assert config.score_mod_tensor_names == ("tensor",) + assert len(tensor_operands) == 1 + assert tensor_operands[0].shape == (1, 1, 1, 1) + assert len(bprop_tensor_operands) == 0 + assert len(config.score_mod_scalars) == 1 + assert config.score_mod_scalars[0].name == "neg_inf" + assert config.score_mod_scalars[0].dtype == "float32" + assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_relative_position_optional_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(0) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_relative_position, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, relative_position=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad( + ref_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_causal_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(1) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_causal, + score_mod_bprop=_score_mod_causal_bprop, + score_mod_tensors={"neg_inf": -1e9}, + score_mod_bprop_tensors={"zero": 0.0}, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, causal=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad( + ref_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_softcap_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(2) + q_key, k_key, v_key, d_out_key = jax.random.split(key, 4) + q = jax.random.normal(q_key, (1, 16, 2, 64), dtype=jnp.float16) + k = jax.random.normal(k_key, (1, 16, 2, 64), dtype=jnp.float16) + v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype( + jnp.float16 + ) + d_out = jax.random.normal(d_out_key, (1, 16, 2, 64), dtype=jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + softcap = 0.8 + softcap_score_mod = _ScoreModSoftcap() + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap}, + score_mod_bprop_tensors={"softcap": softcap}, + ) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, softcap=softcap) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad( + ref_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + + assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) + assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) + + @jax.jit def make_causal_mask( segment_ids_q: ArrayLike, diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index f54a043fd2..c418dce4de 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import Enum from functools import partial -from typing import Optional, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Tuple, Union import warnings from jax.ad_checkpoint import checkpoint_name @@ -1391,10 +1391,129 @@ def _fused_attn_bwd_rule( _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +def _fused_attn_score_mod( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + score_mod_tensors: Tuple[jnp.ndarray, ...], + score_mod_bprop_tensors: Tuple[jnp.ndarray, ...], + config, + context_checkpoint_name: str, +): + output, _ = _fused_attn_score_mod_fwd_rule( + qkv, + score_mod_tensors, + score_mod_bprop_tensors, + config, + context_checkpoint_name, + ) + return output + + +def _fused_attn_score_mod_fwd_rule( + qkv, + score_mod_tensors, + score_mod_bprop_tensors, + config, + context_checkpoint_name, +): + output, softmax_stats = tex.fused_attn_score_mod_fwd(qkv, score_mod_tensors, config) + output = checkpoint_name(output, context_checkpoint_name) + softmax_stats = checkpoint_name(softmax_stats, context_checkpoint_name) + return output, (qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats) + + +def _fused_attn_score_mod_bwd_rule(config, context_checkpoint_name, ctx, dz): + del context_checkpoint_name + qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats = ctx + grad_qkv = tex.fused_attn_score_mod_bwd( + qkv, + output, + dz, + softmax_stats, + score_mod_tensors, + score_mod_bprop_tensors, + config, + ) + return ( + grad_qkv, + tuple(None for _ in score_mod_tensors), + tuple(None for _ in score_mod_bprop_tensors), + ) + + +_fused_attn_score_mod.defvjp( + _fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule +) + + +def _validate_fused_attn_score_mod( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: Optional[SequenceDescriptor], + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + softmax_type: AttnSoftmaxType, + dropout_probability: float, + max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]], + context_parallel_strategy: CPStrategy, + context_parallel_causal_load_balanced: bool, + context_parallel_axis: str, + softmax_offset: Optional[jnp.ndarray], + stripe_size: int | None, +): + """Validate arguments for the cuDNN frontend score_mod path.""" + header = "score_mod fused_attn" + if qkv_layout is not QKVLayout.BSHD_BSHD_BSHD: + raise ValueError(f"{header} currently only supports QKVLayout.BSHD_BSHD_BSHD.") + if len(qkv) != 3: + raise ValueError(f"{header} requires separate query, key and value tensors.") + if any(tensor.ndim != 4 for tensor in qkv): + raise ValueError(f"{header} requires rank-4 BSHD query/key/value tensors.") + q, k, v = qkv + if q.dtype != k.dtype or q.dtype != v.dtype: + raise ValueError(f"{header} requires query, key and value to have the same dtype.") + if q.dtype not in (jnp.float16, jnp.bfloat16): + raise ValueError(f"{header} only supports FP16/BF16 query, key and value tensors.") + if q.shape[0] != k.shape[0] or q.shape[0] != v.shape[0]: + raise ValueError(f"{header} requires matching batch dimensions.") + if k.shape[1] != v.shape[1]: + raise ValueError(f"{header} requires key and value sequence lengths to match.") + if k.shape[2] != v.shape[2]: + raise ValueError(f"{header} requires key and value head counts to match.") + if q.shape[3] != k.shape[3]: + raise ValueError(f"{header} requires query/key head dimensions to match.") + + if bias is not None or attn_bias_type is not AttnBiasType.NO_BIAS: + raise ValueError(f"{header} is mutually exclusive with attention bias.") + if sequence_descriptor is not None: + raise ValueError(f"{header} is mutually exclusive with padding/sequence descriptors.") + if seed is not None: + raise ValueError(f"{header} is mutually exclusive with dropout seed.") + if attn_mask_type is not AttnMaskType.NO_MASK: + raise ValueError(f"{header} is mutually exclusive with attention masks.") + if softmax_type is not AttnSoftmaxType.VANILLA_SOFTMAX or softmax_offset is not None: + raise ValueError(f"{header} only supports vanilla softmax without softmax_offset.") + if dropout_probability != 0.0: + raise ValueError(f"{header} is mutually exclusive with dropout.") + if max_segments_per_seq != 1: + raise ValueError(f"{header} is mutually exclusive with packed/ragged sequence metadata.") + if window_size not in (None, (-1, -1)): + raise ValueError(f"{header} is mutually exclusive with sliding-window attention.") + if context_parallel_strategy is not CPStrategy.DEFAULT: + raise ValueError(f"{header} is mutually exclusive with context parallelism.") + if context_parallel_causal_load_balanced or context_parallel_axis: + raise ValueError(f"{header} is mutually exclusive with context parallelism.") + if stripe_size is not None: + raise ValueError(f"{header} is mutually exclusive with striped context parallelism.") + + def fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - sequence_descriptor: SequenceDescriptor, + sequence_descriptor: Optional[SequenceDescriptor], seed: Optional[jnp.ndarray], attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, @@ -1411,6 +1530,10 @@ def fused_attn( context_checkpoint_name: str = "context", softmax_offset: Optional[jnp.ndarray] = None, stripe_size: int | None = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ): """ Perform cuDNN fused attention. @@ -1453,6 +1576,20 @@ def fused_attn( Currently, a stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas a stripe_size=1 is supported for both, CP + THD + Striped + AG and CP + THD + Striped + P2P(Ring) None indicates no striping strategy + score_mod (Optional[Callable]): Optional cuDNN frontend score modification callback. + The callback is called as `score_mod(graph, score, tensors)` while building a + cuDNN frontend graph. When provided, this path only supports BSHD_BSHD_BSHD + layout and is mutually exclusive with masks, padding, bias, dropout, context + parallelism, sliding windows, and non-vanilla softmax. + score_mod_bprop (Optional[Callable]): Optional score modification backward callback, + called as `score_mod_bprop(graph, dscore, tensors)`. If omitted, cuDNN uses the + default backward behavior for the forward score modification graph. + score_mod_tensors (Optional[Mapping[str, Any]]): Additional tensors or Python/NumPy + scalars made available to `score_mod` through its `tensors` dictionary. Scalars + are represented as cuDNN pass-by-value tensors. Tensor entries are treated as + non-differentiable auxiliary inputs. + score_mod_bprop_tensors (Optional[Mapping[str, Any]]): Additional tensors or + Python/NumPy scalars made available to `score_mod_bprop`. Returns: (jnp.ndarray): The output tensor from the fused attention. @@ -1485,6 +1622,48 @@ def fused_attn( AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, QKVLayout.T3HD, 0.125, 0, True, 3) """ + if score_mod is None: + if score_mod_bprop is not None: + raise ValueError("score_mod_bprop requires score_mod to be provided.") + if score_mod_tensors is not None: + raise ValueError("score_mod_tensors requires score_mod to be provided.") + if score_mod_bprop_tensors is not None: + raise ValueError("score_mod_bprop_tensors requires score_mod to be provided.") + else: + _validate_fused_attn_score_mod( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type, + attn_mask_type, + qkv_layout, + softmax_type, + dropout_probability, + max_segments_per_seq, + window_size, + context_parallel_strategy, + context_parallel_causal_load_balanced, + context_parallel_axis, + softmax_offset, + stripe_size, + ) + config, tensor_operands, bprop_tensor_operands = tex.make_fused_attn_score_mod_config( + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + scaling_factor, + is_training, + ) + return _fused_attn_score_mod( + qkv, + tensor_operands, + bprop_tensor_operands, + config, + context_checkpoint_name, + ) + if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray): warnings.warn( "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. " diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 489bfde997..f90efec2fa 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -7,10 +7,11 @@ import warnings from dataclasses import dataclass, replace from functools import partial, reduce -from typing import Optional, Tuple +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple import jax import jax.numpy as jnp +import numpy as np from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding from jax.experimental.custom_partitioning import SdyShardingRule @@ -54,6 +55,9 @@ "FusedAttnHelper", "fused_attn_fwd", "fused_attn_bwd", + "make_fused_attn_score_mod_config", + "fused_attn_score_mod_fwd", + "fused_attn_score_mod_bwd", ] @@ -267,6 +271,574 @@ def check_seed(self, seed, dropout_probability, is_training): return seed +@dataclass(frozen=True) +class _ScoreModScalarSpec: + """Static pass-by-value scalar used when building a cuDNN frontend graph.""" + + name: str + dtype: str + value: bytes + dim: Tuple[int, ...] = (1, 1, 1, 1) + stride: Tuple[int, ...] = (1, 1, 1, 1) + + +@dataclass(frozen=True) +class _FusedAttnScoreModConfig: + """Static configuration for cuDNN frontend score_mod SDPA graphs.""" + + score_mod: Callable + score_mod_bprop: Optional[Callable] + score_mod_tensor_names: Tuple[str, ...] + score_mod_bprop_tensor_names: Tuple[str, ...] + score_mod_scalars: Tuple[_ScoreModScalarSpec, ...] + score_mod_bprop_scalars: Tuple[_ScoreModScalarSpec, ...] + scaling_factor: float + is_training: bool + deterministic: bool + + def __hash__(self): + return hash( + ( + id(self.score_mod), + id(self.score_mod_bprop) if self.score_mod_bprop is not None else None, + self.score_mod_tensor_names, + self.score_mod_bprop_tensor_names, + self.score_mod_scalars, + self.score_mod_bprop_scalars, + self.scaling_factor, + self.is_training, + self.deterministic, + ) + ) + + def __eq__(self, other): + if not isinstance(other, _FusedAttnScoreModConfig): + return False + return ( + self.score_mod is other.score_mod + and self.score_mod_bprop is other.score_mod_bprop + and self.score_mod_tensor_names == other.score_mod_tensor_names + and self.score_mod_bprop_tensor_names == other.score_mod_bprop_tensor_names + and self.score_mod_scalars == other.score_mod_scalars + and self.score_mod_bprop_scalars == other.score_mod_bprop_scalars + and self.scaling_factor == other.scaling_factor + and self.is_training == other.is_training + and self.deterministic == other.deterministic + ) + + +_SCORE_MOD_UID_Q = 1 +_SCORE_MOD_UID_K = 2 +_SCORE_MOD_UID_V = 3 +_SCORE_MOD_UID_O = 4 +_SCORE_MOD_UID_STATS = 5 +_SCORE_MOD_UID_DO = 6 +_SCORE_MOD_UID_DQ = 7 +_SCORE_MOD_UID_DK = 8 +_SCORE_MOD_UID_DV = 9 +_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000 +_SCORE_MOD_BPROP_TENSOR_UID_BASE = 2000 +_SCORE_MOD_FWD_SCALAR_UID_BASE = 3000 +_SCORE_MOD_BPROP_SCALAR_UID_BASE = 4000 + +_score_mod_graph_cache: Dict[Tuple[Any, ...], Tuple[int, int]] = {} + + +def _row_major_stride(shape: Sequence[int]) -> Tuple[int, ...]: + stride = [] + running = 1 + for dim in reversed(tuple(shape)): + stride.append(running) + running *= dim + return tuple(reversed(stride)) + + +def _bshd_as_bhsd_dim_stride(shape: Sequence[int]) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + if len(shape) != 4: + raise ValueError(f"score_mod requires rank-4 BSHD tensors, got shape={shape}.") + batch, seqlen, heads, head_dim = tuple(shape) + return ( + (batch, heads, seqlen, head_dim), + (seqlen * heads * head_dim, head_dim, heads * head_dim, 1), + ) + + +def _dtype_name(dtype) -> str: + return str(jnp.dtype(dtype)) + + +def _is_array_operand(value: Any) -> bool: + return hasattr(value, "shape") and hasattr(value, "dtype") and not isinstance( + value, (bool, int, float, complex, np.generic) + ) + + +def _scalar_to_spec(name: str, value: Any) -> _ScoreModScalarSpec: + if isinstance(value, bool): + dtype = np.bool_ + elif isinstance(value, int): + dtype = np.int32 + elif isinstance(value, float): + dtype = np.float32 + elif isinstance(value, np.generic): + dtype = value.dtype + else: + scalar = np.asarray(value) + if scalar.shape != (): + raise ValueError( + f"score_mod tensor '{name}' is neither a JAX array nor a scalar pass-by-value." + ) + dtype = scalar.dtype + + scalar = np.full((1, 1, 1, 1), value, dtype=dtype) + return _ScoreModScalarSpec(name=name, dtype=str(scalar.dtype), value=scalar.tobytes()) + + +def _split_score_mod_tensors( + tensors: Optional[Mapping[str, Any]], *, argument_name: str +) -> Tuple[Tuple[str, ...], Tuple[jnp.ndarray, ...], Tuple[_ScoreModScalarSpec, ...]]: + if tensors is None: + return (), (), () + if not isinstance(tensors, Mapping): + raise TypeError(f"{argument_name} must be a mapping from string names to tensors/scalars.") + + names = [] + operands = [] + scalars = [] + for name, value in tensors.items(): + if not isinstance(name, str): + raise TypeError(f"{argument_name} keys must be strings, got {type(name).__name__}.") + if _is_array_operand(value): + if len(value.shape) == 0: + raise ValueError( + f"{argument_name}['{name}'] is a rank-0 array. Use a Python/NumPy scalar " + "for cuDNN pass-by-value scalars, or reshape it to a tensor." + ) + names.append(name) + operands.append(jnp.asarray(value)) + else: + scalars.append(_scalar_to_spec(name, value)) + return tuple(names), tuple(operands), tuple(scalars) + + +def make_fused_attn_score_mod_config( + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Mapping[str, Any]], + score_mod_bprop_tensors: Optional[Mapping[str, Any]], + scaling_factor: float, + is_training: bool, +) -> Tuple[_FusedAttnScoreModConfig, Tuple[jnp.ndarray, ...], Tuple[jnp.ndarray, ...]]: + """Normalize score_mod operands and create a static graph-build config.""" + if not callable(score_mod): + raise TypeError("score_mod must be callable.") + if score_mod_bprop is not None and not callable(score_mod_bprop): + raise TypeError("score_mod_bprop must be callable when provided.") + if score_mod_bprop is None and score_mod_bprop_tensors: + raise ValueError("score_mod_bprop_tensors requires score_mod_bprop to be provided.") + + tensor_names, tensor_operands, scalars = _split_score_mod_tensors( + score_mod_tensors, argument_name="score_mod_tensors" + ) + bprop_tensor_names, bprop_tensor_operands, bprop_scalars = _split_score_mod_tensors( + score_mod_bprop_tensors, argument_name="score_mod_bprop_tensors" + ) + config = _FusedAttnScoreModConfig( + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensor_names=tensor_names, + score_mod_bprop_tensor_names=bprop_tensor_names, + score_mod_scalars=scalars, + score_mod_bprop_scalars=bprop_scalars, + scaling_factor=float(scaling_factor), + is_training=bool(is_training), + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + ) + return config, tensor_operands, bprop_tensor_operands + + +def _cudnn_data_type(cudnn, dtype): + dtype = jnp.dtype(dtype) + if dtype == jnp.float16: + return cudnn.data_type.HALF + if dtype == jnp.bfloat16: + return cudnn.data_type.BFLOAT16 + if dtype == jnp.float32: + return cudnn.data_type.FLOAT + if dtype == jnp.float64: + return cudnn.data_type.DOUBLE + if dtype == jnp.int32: + return cudnn.data_type.INT32 + if dtype == jnp.int64: + return cudnn.data_type.INT64 + if dtype == jnp.uint8: + return cudnn.data_type.UINT8 + if dtype == jnp.bool_: + return cudnn.data_type.BOOLEAN + raise ValueError(f"Unsupported score_mod tensor dtype: {dtype}.") + + +def _cudnn_data_type_from_name(cudnn, dtype_name: str): + if dtype_name == "bfloat16": + return cudnn.data_type.BFLOAT16 + return _cudnn_data_type(cudnn, np.dtype(dtype_name)) + + +def _graph_tensor_from_aval(cudnn, graph, name: str, aval, uid: int): + shape = tuple(int(dim) for dim in aval.shape) + return graph.tensor( + name=name, + dim=shape, + stride=_row_major_stride(shape), + data_type=_cudnn_data_type(cudnn, aval.dtype), + uid=uid, + ) + + +def _score_mod_graph_tensors( + cudnn, + graph, + names: Tuple[str, ...], + avals: Sequence[Any], + scalars: Tuple[_ScoreModScalarSpec, ...], + tensor_uid_base: int, + scalar_uid_base: int, +): + graph_tensors = {} + tensor_uids = [] + for index, (name, aval) in enumerate(zip(names, avals)): + uid = tensor_uid_base + index + graph_tensors[name] = _graph_tensor_from_aval(cudnn, graph, name, aval, uid) + tensor_uids.append(uid) + + scalar_uids = [] + scalar_values = [] + for index, scalar in enumerate(scalars): + uid = scalar_uid_base + index + graph_tensors[scalar.name] = graph.tensor( + name=scalar.name, + dim=scalar.dim, + stride=scalar.stride, + is_pass_by_value=True, + data_type=_cudnn_data_type_from_name(cudnn, scalar.dtype), + uid=uid, + ) + scalar_uids.append(uid) + scalar_values.append(scalar.value) + + return graph_tensors, tuple(tensor_uids), tuple(scalar_uids), tuple(scalar_values) + + +def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): + if score_mod is None: + return None + + def wrapped_score_mod(sdpa_graph, score_tensor): + return score_mod(sdpa_graph, score_tensor, graph_tensors) + + return wrapped_score_mod + + +def _finalize_score_mod_graph(cudnn, graph) -> int: + graph.validate() + graph.build_operation_graph() + try: + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + except cudnn.cudnnGraphNotSupportedError as exc: + raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc + graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + return max(int(graph.get_workspace_size()), 1) + + +def _graph_cache_key( + direction: str, + config: _FusedAttnScoreModConfig, + avals: Sequence[Any], +) -> Tuple[Any, ...]: + return ( + direction, + config, + tuple((tuple(aval.shape), _dtype_name(aval.dtype)) for aval in avals), + ) + + +def _shape_dtype(value) -> jax.ShapeDtypeStruct: + return jax.ShapeDtypeStruct(tuple(value.shape), value.dtype) + + +def _import_cudnn_for_score_mod(): + try: + import cudnn # pylint: disable=import-outside-toplevel + except ImportError as exc: + raise ImportError( + "score_mod fused_attn requires the cuDNN frontend Python package (`cudnn`)." + ) from exc + return cudnn + + +def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): + cudnn = _import_cudnn_for_score_mod() + + io_data_type = _cudnn_data_type(cudnn, q_aval.dtype) + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) + k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) + v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) + q = graph.tensor( + name="q", dim=q_dim, stride=q_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_Q + ) + k = graph.tensor( + name="k", dim=k_dim, stride=k_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_K + ) + v = graph.tensor( + name="v", dim=v_dim, stride=v_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_V + ) + + score_mod_graph_tensors, tensor_uids, scalar_uids, scalar_values = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_tensor_names, + score_mod_avals, + config.score_mod_scalars, + _SCORE_MOD_FWD_TENSOR_UID_BASE, + _SCORE_MOD_FWD_SCALAR_UID_BASE, + ) + + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=config.is_training, + attn_scale=config.scaling_factor, + use_causal_mask=False, + score_mod=_wrap_score_mod(config.score_mod, score_mod_graph_tensors), + ) + + batch, q_seqlen, q_heads, _ = q_aval.shape + _, _, _, v_head_dim = v_aval.shape + output_dim, output_stride = _bshd_as_bhsd_dim_stride((batch, q_seqlen, q_heads, v_head_dim)) + output.set_output(True).set_uid(_SCORE_MOD_UID_O).set_dim(output_dim).set_stride( + output_stride + ) + output.set_data_type(io_data_type) + + output_uids = [_SCORE_MOD_UID_O] + if config.is_training: + stats_shape = (batch, q_heads, q_seqlen, 1) + stats.set_output(True).set_uid(_SCORE_MOD_UID_STATS).set_dim(stats_shape).set_stride( + _row_major_stride(stats_shape) + ) + stats.set_data_type(cudnn.data_type.FLOAT) + output_uids.append(_SCORE_MOD_UID_STATS) + + workspace_size = _finalize_score_mod_graph(cudnn, graph) + graph_id = transformer_engine_jax.register_fused_attn_score_mod_graph( + graph, + [int(uid) for uid in graph._get_variant_pack_uids_sorted()], + [_SCORE_MOD_UID_Q, _SCORE_MOD_UID_K, _SCORE_MOD_UID_V, *tensor_uids], + output_uids, + list(scalar_uids), + list(scalar_values), + ) + return graph_id, workspace_size + + +def _build_score_mod_bwd_graph( + q_aval, + k_aval, + v_aval, + output_aval, + doutput_aval, + stats_aval, + score_mod_avals, + score_mod_bprop_avals, + config, +): + cudnn = _import_cudnn_for_score_mod() + + io_data_type = _cudnn_data_type(cudnn, q_aval.dtype) + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) + k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) + v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) + o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape) + do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape) + q = graph.tensor( + name="q", dim=q_dim, stride=q_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_Q + ) + k = graph.tensor( + name="k", dim=k_dim, stride=k_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_K + ) + v = graph.tensor( + name="v", dim=v_dim, stride=v_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_V + ) + output = graph.tensor( + name="o", dim=o_dim, stride=o_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_O + ) + doutput = graph.tensor( + name="dO", dim=do_dim, stride=do_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_DO + ) + stats = graph.tensor( + name="stats", + dim=tuple(int(dim) for dim in stats_aval.shape), + stride=_row_major_stride(stats_aval.shape), + data_type=cudnn.data_type.FLOAT, + uid=_SCORE_MOD_UID_STATS, + ) + + score_mod_graph_tensors, tensor_uids, scalar_uids, scalar_values = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_tensor_names, + score_mod_avals, + config.score_mod_scalars, + _SCORE_MOD_FWD_TENSOR_UID_BASE, + _SCORE_MOD_FWD_SCALAR_UID_BASE, + ) + ( + score_mod_bprop_graph_tensors, + bprop_tensor_uids, + bprop_scalar_uids, + bprop_scalar_values, + ) = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_bprop_tensor_names, + score_mod_bprop_avals, + config.score_mod_bprop_scalars, + _SCORE_MOD_BPROP_TENSOR_UID_BASE, + _SCORE_MOD_BPROP_SCALAR_UID_BASE, + ) + + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=doutput, + stats=stats, + attn_scale=config.scaling_factor, + use_causal_mask=False, + score_mod=_wrap_score_mod(config.score_mod, score_mod_graph_tensors), + score_mod_bprop=_wrap_score_mod(config.score_mod_bprop, score_mod_bprop_graph_tensors), + use_deterministic_algorithm=config.deterministic, + ) + + dq.set_output(True).set_uid(_SCORE_MOD_UID_DQ).set_dim(q_dim).set_stride(q_stride) + dk.set_output(True).set_uid(_SCORE_MOD_UID_DK).set_dim(k_dim).set_stride(k_stride) + dv.set_output(True).set_uid(_SCORE_MOD_UID_DV).set_dim(v_dim).set_stride(v_stride) + + workspace_size = _finalize_score_mod_graph(cudnn, graph) + graph_id = transformer_engine_jax.register_fused_attn_score_mod_graph( + graph, + [int(uid) for uid in graph._get_variant_pack_uids_sorted()], + [ + _SCORE_MOD_UID_Q, + _SCORE_MOD_UID_K, + _SCORE_MOD_UID_V, + _SCORE_MOD_UID_O, + _SCORE_MOD_UID_DO, + _SCORE_MOD_UID_STATS, + *tensor_uids, + *bprop_tensor_uids, + ], + [_SCORE_MOD_UID_DQ, _SCORE_MOD_UID_DK, _SCORE_MOD_UID_DV], + [*scalar_uids, *bprop_scalar_uids], + [*scalar_values, *bprop_scalar_values], + ) + return graph_id, workspace_size + + +def fused_attn_score_mod_fwd( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + score_mod_tensors: Tuple[jnp.ndarray, ...], + config: _FusedAttnScoreModConfig, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Run cuDNN frontend SDPA forward with a score_mod callback.""" + q, k, v = qkv + q_aval, k_aval, v_aval = map(_shape_dtype, (q, k, v)) + score_mod_avals = tuple(_shape_dtype(arg) for arg in score_mod_tensors) + key = _graph_cache_key("fwd", config, (q_aval, k_aval, v_aval, *score_mod_avals)) + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_fwd_graph( + q_aval, k_aval, v_aval, score_mod_avals, config + ) + graph_id, workspace_size = _score_mod_graph_cache[key] + + batch, q_seqlen, q_heads, _ = q.shape + _, _, _, v_head_dim = v.shape + output_shape = jax.ShapeDtypeStruct((batch, q_seqlen, q_heads, v_head_dim), q.dtype) + stats_shape = (batch, q_heads, q_seqlen, 1) if config.is_training else (0,) + stats = jax.ShapeDtypeStruct(stats_shape, jnp.float32) + workspace = jax.ShapeDtypeStruct((workspace_size,), jnp.uint8) + output, softmax_stats, _ = ffi.ffi_call( + "te_fused_attn_score_mod_forward_ffi", + (output_shape, stats, workspace), + )(q, k, v, *score_mod_tensors, graph_id=graph_id) + return output, softmax_stats + + +def fused_attn_score_mod_bwd( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + output: jnp.ndarray, + doutput: jnp.ndarray, + softmax_stats: jnp.ndarray, + score_mod_tensors: Tuple[jnp.ndarray, ...], + score_mod_bprop_tensors: Tuple[jnp.ndarray, ...], + config: _FusedAttnScoreModConfig, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Run cuDNN frontend SDPA backward with score_mod callbacks.""" + if not config.is_training: + raise RuntimeError("score_mod backward requires fused_attn(..., is_training=True).") + + q, k, v = qkv + all_inputs = (q, k, v, output, doutput, softmax_stats, *score_mod_tensors) + all_inputs = (*all_inputs, *score_mod_bprop_tensors) + avals = tuple(_shape_dtype(arg) for arg in all_inputs) + key = _graph_cache_key("bwd", config, avals) + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_bwd_graph( + *avals[:6], + avals[6 : 6 + len(score_mod_tensors)], + avals[6 + len(score_mod_tensors) :], + config, + ) + graph_id, workspace_size = _score_mod_graph_cache[key] + + dq = jax.ShapeDtypeStruct(q.shape, q.dtype) + dk = jax.ShapeDtypeStruct(k.shape, k.dtype) + dv = jax.ShapeDtypeStruct(v.shape, v.dtype) + workspace = jax.ShapeDtypeStruct((workspace_size,), jnp.uint8) + dq, dk, dv, _ = ffi.ffi_call( + "te_fused_attn_score_mod_backward_ffi", + (dq, dk, dv, workspace), + )( + q, + k, + v, + output, + doutput, + softmax_stats, + *score_mod_tensors, + *score_mod_bprop_tensors, + graph_id=graph_id, + ) + return dq, dk, dv + + def generate_cu_seqlen(actual_seqlen): """ Generating cumsum seqlen for a batch diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..f17788a068 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -146,6 +146,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler); + NVTE_Fused_Attn_Backend GetFusedAttnBackend( bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, @@ -169,6 +173,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal); +int64_t RegisterFusedAttnScoreModGraph(pybind11::object graph, + const std::vector& user_uids, + const std::vector& input_uids, + const std::vector& output_uids, + const std::vector& scalar_uids, + const std::vector& scalar_values); + // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmV2Handler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index ed136d7b9e..26a3e5e8c0 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -8,6 +8,13 @@ #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" +#include +#include +#include +#include +#include +#include + namespace transformer_engine { namespace jax { @@ -689,5 +696,222 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Attrs(), FFI_CudaGraph_Traits); +namespace { + +struct ScoreModScalarStorage { + alignas(16) std::array data{}; + size_t size = 0; +}; + +struct ScoreModGraphEntry { + PyObject *py_graph = nullptr; + std::vector user_uids; + std::vector input_uids; + std::vector output_uids; + std::vector scalar_uids; + std::vector scalar_values; +}; + +std::unordered_map> &ScoreModGraphRegistry() { + static std::unordered_map> registry; + return registry; +} + +std::mutex &ScoreModGraphRegistryMutex() { + static std::mutex mutex; + return mutex; +} + +std::atomic &NextScoreModGraphId() { + static std::atomic next_id{1}; + return next_id; +} + +struct ScoreModCudnnHandleCache { + std::unordered_map handles; + + cudnnHandle_t GetHandle() { + int device_id = 0; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + auto it = handles.find(device_id); + if (it == handles.end()) { + cudnnHandle_t handle = nullptr; + NVTE_CHECK_CUDNN(cudnnCreate(&handle)); + it = handles.emplace(device_id, handle).first; + } + return it->second; + } + + ~ScoreModCudnnHandleCache() { + for (auto &[_, handle] : handles) { + cudnnDestroy(handle); + } + } +}; + +cudnnHandle_t GetScoreModCudnnHandle() { + static thread_local ScoreModCudnnHandleCache cache; + return cache.GetHandle(); +} + +std::shared_ptr GetScoreModGraphEntry(int64_t graph_id) { + std::lock_guard lock(ScoreModGraphRegistryMutex()); + auto ®istry = ScoreModGraphRegistry(); + auto it = registry.find(graph_id); + NVTE_CHECK(it != registry.end(), "Unknown cuDNN score_mod graph id: ", graph_id); + return it->second; +} + +Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, + const std::vector &input_ptrs, + const std::vector &output_ptrs, void *workspace) { + auto entry = GetScoreModGraphEntry(graph_id); + NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), + "cuDNN score_mod graph expected ", entry->input_uids.size(), " inputs but got ", + input_ptrs.size()); + NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(), + "cuDNN score_mod graph expected at least ", entry->output_uids.size(), + " outputs but got ", output_ptrs.size()); + + std::unordered_map variant_pack; + for (size_t i = 0; i < entry->input_uids.size(); ++i) { + variant_pack.emplace(entry->input_uids[i], input_ptrs[i]); + } + for (size_t i = 0; i < entry->output_uids.size(); ++i) { + variant_pack.emplace(entry->output_uids[i], output_ptrs[i]); + } + for (size_t i = 0; i < entry->scalar_uids.size(); ++i) { + variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data()); + } + + std::vector user_ptrs; + user_ptrs.reserve(entry->user_uids.size()); + for (const auto uid : entry->user_uids) { + auto it = variant_pack.find(uid); + NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", + uid); + user_ptrs.push_back(reinterpret_cast(it->second)); + } + + auto handle = GetScoreModCudnnHandle(); + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + { + pybind11::gil_scoped_acquire gil; + try { + auto graph = pybind11::reinterpret_borrow(entry->py_graph); + graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast(workspace), + reinterpret_cast(handle)); + } catch (const pybind11::error_already_set &exc) { + NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what()); + } + } + return ffi_with_cuda_error_check(); +} + +void AppendRemainingBuffers(Variadic_Buffer_Type args, std::vector *ptrs) { + ptrs->reserve(ptrs->size() + args.size()); + for (size_t i = 0; i < args.size(); ++i) { + auto maybe_buf = args.get(i); + NVTE_CHECK(!maybe_buf.has_error(), "Failed to decode variadic score_mod input buffer."); + ptrs->push_back(maybe_buf.value().untyped_data()); + } +} + +} // namespace + +int64_t RegisterFusedAttnScoreModGraph(pybind11::object graph, + const std::vector &user_uids, + const std::vector &input_uids, + const std::vector &output_uids, + const std::vector &scalar_uids, + const std::vector &scalar_values) { + NVTE_CHECK(!graph.is_none(), "Cannot register an empty cuDNN score_mod graph."); + NVTE_CHECK(!user_uids.empty(), "Cannot register a cuDNN score_mod graph without variant UIDs."); + NVTE_CHECK(scalar_uids.size() == scalar_values.size(), + "Mismatched score_mod scalar uid/value counts."); + + auto entry = std::make_shared(); + entry->py_graph = graph.ptr(); + Py_INCREF(entry->py_graph); + entry->user_uids = user_uids; + entry->input_uids = input_uids; + entry->output_uids = output_uids; + entry->scalar_uids = scalar_uids; + entry->scalar_values.reserve(scalar_values.size()); + for (const auto &value : scalar_values) { + NVTE_CHECK(value.size() <= 16, "score_mod pass-by-value scalars must be at most 16 bytes."); + ScoreModScalarStorage storage; + storage.size = value.size(); + std::copy(value.begin(), value.end(), storage.data.begin()); + entry->scalar_values.push_back(storage); + } + + const int64_t graph_id = NextScoreModGraphId().fetch_add(1); + { + std::lock_guard lock(ScoreModGraphRegistryMutex()); + ScoreModGraphRegistry().emplace(graph_id, std::move(entry)); + } + return graph_id; +} + +Error_Type FusedAttnScoreModForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Variadic_Buffer_Type score_mod_args, + Result_Type output_buf, Result_Type stats_buf, + Result_Type workspace_buf, Dictionary attrs) { + int64_t graph_id = get_attr_value(attrs, "graph_id"); + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data()}; + AppendRemainingBuffers(score_mod_args, &input_ptrs); + + std::vector output_ptrs = {output_buf->untyped_data(), stats_buf->untyped_data()}; + return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler, FusedAttnScoreModForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .RemainingArgs() // score_mod tensor operands + .Ret() // output + .Ret() // stats + .Ret() // workspace + .Attrs()); + +Error_Type FusedAttnScoreModBackwardFFI( + cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, + Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type stats_buf, + Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, Result_Type dk_buf, + Result_Type dv_buf, Result_Type workspace_buf, Dictionary attrs) { + int64_t graph_id = get_attr_value(attrs, "graph_id"); + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data(), output_buf.untyped_data(), + doutput_buf.untyped_data(), stats_buf.untyped_data()}; + AppendRemainingBuffers(score_mod_args, &input_ptrs); + + std::vector output_ptrs = {dq_buf->untyped_data(), dk_buf->untyped_data(), + dv_buf->untyped_data()}; + return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler, FusedAttnScoreModBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // output + .Arg() // doutput + .Arg() // stats + .RemainingArgs() // score_mod tensor operands + .Ret() // dq + .Ret() // dk + .Ret() // dv + .Ret() // workspace + .Attrs()); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..bdb4507323 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -64,6 +64,12 @@ pybind11::dict Registrations() { dict["te_fused_attn_backward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + dict["te_fused_attn_score_mod_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnScoreModForwardHandler)); + dict["te_fused_attn_score_mod_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnScoreModBackwardHandler)); // GEMM dict["te_gemm_ffi"] = @@ -121,6 +127,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); + m.def("register_fused_attn_score_mod_graph", &RegisterFusedAttnScoreModGraph); m.def("get_topk_workspace_sizes", &GetTopkWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); From 6b0532824215979ba7ef2dff64aae5ec637ef984 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 05:38:04 +0000 Subject: [PATCH 02/11] Stabilize score_mod callback cache keys Signed-off-by: Vladimir Cherepanov --- tests/jax/test_fused_attn.py | 41 +++++++++++++++++++ .../jax/cpp_extensions/attention.py | 23 +++++++++-- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f12435fe5b..001fa049ce 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -337,6 +337,47 @@ def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize +def test_fused_attn_score_mod_config_stabilizes_bound_method_cache_keys(): + softcap_score_mod = _ScoreModSoftcap() + first_forward = softcap_score_mod.forward + second_forward = softcap_score_mod.forward + first_backward = softcap_score_mod.backward + second_backward = softcap_score_mod.backward + + assert first_forward is not second_forward + assert first_backward is not second_backward + + config_1, _, _ = make_fused_attn_score_mod_config( + first_forward, + first_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + config_2, _, _ = make_fused_attn_score_mod_config( + second_forward, + second_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + other_softcap_score_mod = _ScoreModSoftcap() + config_3, _, _ = make_fused_attn_score_mod_config( + other_softcap_score_mod.forward, + other_softcap_score_mod.backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + + assert config_1 == config_2 + assert hash(config_1) == hash(config_2) + assert config_1 != config_3 + + @pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") def test_fused_attn_score_mod_relative_position_optional_bprop(): _require_cudnn_frontend_score_mod() diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index f90efec2fa..6b4d9f8025 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -282,12 +282,25 @@ class _ScoreModScalarSpec: stride: Tuple[int, ...] = (1, 1, 1, 1) +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]: + """Return a stable cache key for callbacks that may be bound methods.""" + if callback is None: + return None + self_obj = getattr(callback, "__self__", None) + func = getattr(callback, "__func__", None) + if self_obj is not None and func is not None: + return ("bound_method", id(self_obj), id(func)) + return ("callable", id(callback)) + + @dataclass(frozen=True) class _FusedAttnScoreModConfig: """Static configuration for cuDNN frontend score_mod SDPA graphs.""" score_mod: Callable score_mod_bprop: Optional[Callable] + score_mod_key: Tuple[Any, ...] + score_mod_bprop_key: Optional[Tuple[Any, ...]] score_mod_tensor_names: Tuple[str, ...] score_mod_bprop_tensor_names: Tuple[str, ...] score_mod_scalars: Tuple[_ScoreModScalarSpec, ...] @@ -299,8 +312,8 @@ class _FusedAttnScoreModConfig: def __hash__(self): return hash( ( - id(self.score_mod), - id(self.score_mod_bprop) if self.score_mod_bprop is not None else None, + self.score_mod_key, + self.score_mod_bprop_key, self.score_mod_tensor_names, self.score_mod_bprop_tensor_names, self.score_mod_scalars, @@ -315,8 +328,8 @@ def __eq__(self, other): if not isinstance(other, _FusedAttnScoreModConfig): return False return ( - self.score_mod is other.score_mod - and self.score_mod_bprop is other.score_mod_bprop + self.score_mod_key == other.score_mod_key + and self.score_mod_bprop_key == other.score_mod_bprop_key and self.score_mod_tensor_names == other.score_mod_tensor_names and self.score_mod_bprop_tensor_names == other.score_mod_bprop_tensor_names and self.score_mod_scalars == other.score_mod_scalars @@ -446,6 +459,8 @@ def make_fused_attn_score_mod_config( config = _FusedAttnScoreModConfig( score_mod=score_mod, score_mod_bprop=score_mod_bprop, + score_mod_key=_score_mod_callback_cache_key(score_mod), + score_mod_bprop_key=_score_mod_callback_cache_key(score_mod_bprop), score_mod_tensor_names=tensor_names, score_mod_bprop_tensor_names=bprop_tensor_names, score_mod_scalars=scalars, From 1a9635201ffaeb9463956867cc7cfd903eab620b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 03:19:59 +0000 Subject: [PATCH 03/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 46 +++++++------------ transformer_engine/jax/attention.py | 4 +- .../jax/cpp_extensions/attention.py | 10 ++-- .../jax/csrc/extensions/attention.cpp | 37 ++++++++------- 4 files changed, 40 insertions(+), 57 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 001fa049ce..1a7658fb8a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -384,15 +384,9 @@ def test_fused_attn_score_mod_relative_position_optional_bprop(): key = jax.random.key(0) q_key, k_key, v_key = jax.random.split(key, 3) - q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) scale = 1.0 / sqrt(q.shape[-1]) def score_mod_loss(query, key_, value): @@ -419,9 +413,9 @@ def ref_loss(query, key_, value): (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( score_mod_loss, argnums=(0, 1, 2), has_aux=True )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad( - ref_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) @@ -435,15 +429,9 @@ def test_fused_attn_score_mod_causal_with_bprop(): key = jax.random.key(1) q_key, k_key, v_key = jax.random.split(key, 3) - q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) scale = 1.0 / sqrt(q.shape[-1]) def score_mod_loss(query, key_, value): @@ -473,9 +461,9 @@ def ref_loss(query, key_, value): (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( score_mod_loss, argnums=(0, 1, 2), has_aux=True )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad( - ref_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) @@ -491,9 +479,7 @@ def test_fused_attn_score_mod_softcap_with_bprop(): q_key, k_key, v_key, d_out_key = jax.random.split(key, 4) q = jax.random.normal(q_key, (1, 16, 2, 64), dtype=jnp.float16) k = jax.random.normal(k_key, (1, 16, 2, 64), dtype=jnp.float16) - v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype( - jnp.float16 - ) + v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype(jnp.float16) d_out = jax.random.normal(d_out_key, (1, 16, 2, 64), dtype=jnp.float16) scale = 1.0 / sqrt(q.shape[-1]) softcap = 0.8 @@ -526,9 +512,9 @@ def ref_loss(query, key_, value): (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( score_mod_loss, argnums=(0, 1, 2), has_aux=True )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad( - ref_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index c418dce4de..adbbcd02fa 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1441,9 +1441,7 @@ def _fused_attn_score_mod_bwd_rule(config, context_checkpoint_name, ctx, dz): ) -_fused_attn_score_mod.defvjp( - _fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule -) +_fused_attn_score_mod.defvjp(_fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule) def _validate_fused_attn_score_mod( diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6b4d9f8025..fa9d9f1467 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -381,8 +381,10 @@ def _dtype_name(dtype) -> str: def _is_array_operand(value: Any) -> bool: - return hasattr(value, "shape") and hasattr(value, "dtype") and not isinstance( - value, (bool, int, float, complex, np.generic) + return ( + hasattr(value, "shape") + and hasattr(value, "dtype") + and not isinstance(value, (bool, int, float, complex, np.generic)) ) @@ -639,9 +641,7 @@ def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): batch, q_seqlen, q_heads, _ = q_aval.shape _, _, _, v_head_dim = v_aval.shape output_dim, output_stride = _bshd_as_bhsd_dim_stride((batch, q_seqlen, q_heads, v_head_dim)) - output.set_output(True).set_uid(_SCORE_MOD_UID_O).set_dim(output_dim).set_stride( - output_stride - ) + output.set_output(True).set_uid(_SCORE_MOD_UID_O).set_dim(output_dim).set_stride(output_stride) output.set_data_type(io_data_type) output_uids = [_SCORE_MOD_UID_O] diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 26a3e5e8c0..665ac40277 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -4,10 +4,6 @@ * See LICENSE for license information. ************************************************************************/ -#include "../extensions.h" -#include "transformer_engine/fused_attn.h" -#include "transformer_engine/transformer_engine.h" - #include #include #include @@ -15,6 +11,10 @@ #include #include +#include "../extensions.h" +#include "transformer_engine/fused_attn.h" +#include "transformer_engine/transformer_engine.h" + namespace transformer_engine { namespace jax { @@ -766,9 +766,8 @@ Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, const std::vector &input_ptrs, const std::vector &output_ptrs, void *workspace) { auto entry = GetScoreModGraphEntry(graph_id); - NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), - "cuDNN score_mod graph expected ", entry->input_uids.size(), " inputs but got ", - input_ptrs.size()); + NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ", + entry->input_uids.size(), " inputs but got ", input_ptrs.size()); NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(), "cuDNN score_mod graph expected at least ", entry->output_uids.size(), " outputs but got ", output_ptrs.size()); @@ -788,8 +787,7 @@ Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, user_ptrs.reserve(entry->user_uids.size()); for (const auto uid : entry->user_uids) { auto it = variant_pack.find(uid); - NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", - uid); + NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid); user_ptrs.push_back(reinterpret_cast(it->second)); } @@ -860,7 +858,7 @@ Error_Type FusedAttnScoreModForwardFFI(cudaStream_t stream, Buffer_Type q_buf, B Result_Type workspace_buf, Dictionary attrs) { int64_t graph_id = get_attr_value(attrs, "graph_id"); std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), - v_buf.untyped_data()}; + v_buf.untyped_data()}; AppendRemainingBuffers(score_mod_args, &input_ptrs); std::vector output_ptrs = {output_buf->untyped_data(), stats_buf->untyped_data()}; @@ -880,19 +878,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler, FusedAttnScoreMod .Ret() // workspace .Attrs()); -Error_Type FusedAttnScoreModBackwardFFI( - cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, - Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type stats_buf, - Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, Result_Type dk_buf, - Result_Type dv_buf, Result_Type workspace_buf, Dictionary attrs) { +Error_Type FusedAttnScoreModBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type output_buf, + Buffer_Type doutput_buf, Buffer_Type stats_buf, + Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, + Result_Type dk_buf, Result_Type dv_buf, + Result_Type workspace_buf, Dictionary attrs) { int64_t graph_id = get_attr_value(attrs, "graph_id"); - std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), - v_buf.untyped_data(), output_buf.untyped_data(), - doutput_buf.untyped_data(), stats_buf.untyped_data()}; + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data(), output_buf.untyped_data(), + doutput_buf.untyped_data(), stats_buf.untyped_data()}; AppendRemainingBuffers(score_mod_args, &input_ptrs); std::vector output_ptrs = {dq_buf->untyped_data(), dk_buf->untyped_data(), - dv_buf->untyped_data()}; + dv_buf->untyped_data()}; return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, workspace_buf->untyped_data()); } From 3bf9e97e88af0d6208ca4961cc42db229f2f60e3 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 13 May 2026 23:17:31 +0000 Subject: [PATCH 04/11] Add distributed JAX score mod attention test Signed-off-by: Vladimir Cherepanov --- tests/jax/test_distributed_fused_attn.py | 124 ++++++++++++++++++++++- 1 file changed, 122 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 50c5de1db7..7a61b208ce 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -6,14 +6,25 @@ import pytest import jax import jax.numpy as jnp +import numpy as np from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec from distributed_test_base import ( generate_configs, generate_context_parallel_configs_for_attn, generate_collectives_count, ) -from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat -from utils import pytest_parametrize_wrapper +from test_fused_attn import ( + FusedAttnRunner, + BiasShape, + SeqDescFormat, + _ScoreModSoftcap, + _has_cudnn_frontend_python, + _reference_attention, + _require_cudnn_frontend_score_mod, +) +from utils import assert_allclose, pytest_parametrize_wrapper +from transformer_engine.jax import autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, AttnBiasType, @@ -25,6 +36,7 @@ inverse_reorder_causal_load_balancing, CPStrategy, ReorderStrategy, + fused_attn, ) @@ -272,6 +284,114 @@ def test_cross_attn( runner.test_backward() +DISTRIBUTED_SCORE_MOD_DATA_SHAPES = { + "L0": [], + "L1": [(4, 16, 4, 64)], + "L2": [(4, 16, 4, 64)], +} + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +class TestDistributedScoreModSelfAttn: + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES) + @pytest.mark.parametrize("dtype", DTYPES) + def test_softcap_score_mod_with_aux_params_backward( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + dtype, + ): + _require_cudnn_frontend_score_mod() + batch, seqlen, num_heads, head_dim = data_shape + dp_axis = mesh_resource.dp_resource + tp_axis = mesh_resource.tpsp_resource + + if dp_axis is not None: + dp_size = mesh_shape[mesh_axes.index(dp_axis)] + if batch % dp_size != 0: + pytest.skip(f"{batch=} must be divisible by {dp_size=}") + if tp_axis is not None: + tp_size = mesh_shape[mesh_axes.index(tp_axis)] + if num_heads % tp_size != 0: + pytest.skip(f"{num_heads=} must be divisible by {tp_size=}") + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + qkv_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, tp_axis, None)) + + key = random.PRNGKey(2025) + q_key, k_key, v_key, dout_key = random.split(key, 4) + query = (0.125 * random.normal(q_key, data_shape, dtype=dtype)).astype(dtype) + key_tensor = (0.125 * random.normal(k_key, data_shape, dtype=dtype)).astype(dtype) + value = (0.125 * random.normal(v_key, data_shape, dtype=dtype)).astype(dtype) + doutput = random.normal(dout_key, data_shape, dtype=dtype) + + scaling_factor = head_dim**-0.5 + softcap = 0.8 + softcap_score_mod = _ScoreModSoftcap() + + def score_mod_loss(q, k, v, dout): + out = fused_attn( + (q, k, v), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scaling_factor, + 0.0, + True, + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap}, + score_mod_bprop_tensors={"softcap": softcap}, + ) + loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) + return loss, out + + def ref_loss(q, k, v, dout): + out = _reference_attention(q, k, v, scaling_factor, softcap=softcap) + loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) + return loss, out + + jitted_score_mod = jax.jit( + jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True), + in_shardings=( + qkv_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ), + out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)), + ) + jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)) + + sharded_args = ( + jax.device_put(query, qkv_sharding), + jax.device_put(key_tensor, qkv_sharding), + jax.device_put(value, qkv_sharding), + jax.device_put(doutput, qkv_sharding), + ) + with mesh, autocast(mesh_resource=mesh_resource): + (score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args) + (ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput) + + assert score_mod_out.sharding == qkv_sharding + for grad in score_mod_grads: + assert grad.sharding == qkv_sharding + + assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) + assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) + + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [ pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), From 29bbac74eebd3af726bfe1b3cc727aa5699cd7c6 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 15 May 2026 03:35:17 +0000 Subject: [PATCH 05/11] Address JAX score_mod review items Signed-off-by: Vladimir Cherepanov --- tests/jax/test_fused_attn.py | 64 +++++--- transformer_engine/jax/attention.py | 2 +- .../jax/cpp_extensions/attention.py | 149 ++++++++++++++++-- 3 files changed, 179 insertions(+), 36 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1a7658fb8a..73a5404a98 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. """Tests for fused attention""" -import importlib.util import os from enum import Enum, auto from dataclasses import dataclass, field @@ -24,6 +23,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike +import transformer_engine.jax.cpp_extensions.attention as tex_attention from transformer_engine.jax import autocast from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( @@ -56,11 +56,15 @@ def _has_cudnn_frontend_python(): - return importlib.util.find_spec("cudnn") is not None + try: + tex_attention._import_cudnn_for_score_mod() + except ImportError: + return False + return True def _score_mod_causal(graph, score, tensors): - import cudnn # pylint: disable=import-outside-toplevel + cudnn = tex_attention._import_cudnn_for_score_mod() row_index = graph.gen_index( input=score, @@ -84,7 +88,7 @@ def _score_mod_causal(graph, score, tensors): def _score_mod_causal_bprop(graph, dscore, tensors): - import cudnn # pylint: disable=import-outside-toplevel + cudnn = tex_attention._import_cudnn_for_score_mod() row_index = graph.gen_index( input=dscore, @@ -107,8 +111,8 @@ def _score_mod_causal_bprop(graph, dscore, tensors): return graph.binary_select(input0=dscore, input1=tensors["zero"], mask=keep) -def _score_mod_relative_position(graph, score, tensors): - import cudnn # pylint: disable=import-outside-toplevel +def _score_mod_post_scale_bias(graph, score, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() row_index = graph.gen_index( input=score, @@ -122,15 +126,15 @@ def _score_mod_relative_position(graph, score, tensors): compute_data_type=cudnn.data_type.INT32, ) col_index.set_data_type(cudnn.data_type.INT32) - relative_position = graph.sub( + post_scale_bias = graph.sub( a=row_index, b=col_index, compute_data_type=cudnn.data_type.FLOAT, ) - relative_position.set_data_type(cudnn.data_type.FLOAT) + post_scale_bias.set_data_type(cudnn.data_type.FLOAT) return graph.add( a=score, - b=relative_position, + b=post_scale_bias, compute_data_type=cudnn.data_type.FLOAT, ) @@ -141,8 +145,12 @@ class _ScoreModSoftcap: def __init__(self): self.before_tanh_activation = None + def score_mod_graph_cache_key(self): + """Graph topology key for softcap score_mod.""" + return ("softcap",) + def forward(self, graph, score, tensors): - import cudnn # pylint: disable=import-outside-toplevel + cudnn = tex_attention._import_cudnn_for_score_mod() self.before_tanh_activation = graph.div( a=score, @@ -159,7 +167,7 @@ def forward(self, graph, score, tensors): ) def backward(self, graph, dscore, tensors): - import cudnn # pylint: disable=import-outside-toplevel + cudnn = tex_attention._import_cudnn_for_score_mod() d_tanh_out = graph.mul( a=dscore, @@ -181,14 +189,14 @@ def backward(self, graph, dscore, tensors): def _reference_attention( - query, key, value, scale, *, causal=False, relative_position=False, softcap=None + query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None ): scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale if causal: q_pos = jnp.arange(query.shape[1])[:, None] kv_pos = jnp.arange(key.shape[1])[None, :] scores = jnp.where(q_pos >= kv_pos, scores, -1e9) - if relative_position: + if post_scale_bias: q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] scores = scores + q_pos - kv_pos @@ -283,7 +291,10 @@ def general_dot_product_attention( def _require_cudnn_frontend_score_mod(): - cudnn = pytest.importorskip("cudnn", reason="cuDNN Python frontend is required for score_mod") + try: + cudnn = tex_attention._import_cudnn_for_score_mod() + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod") version = tuple(int(part) for part in cudnn.backend_version_string().split(".")[:2]) if version < (9, 6): pytest.skip("cuDNN score_mod SDPA requires cuDNN frontend 9.6 or newer") @@ -375,11 +386,28 @@ def test_fused_attn_score_mod_config_stabilizes_bound_method_cache_keys(): assert config_1 == config_2 assert hash(config_1) == hash(config_2) - assert config_1 != config_3 + assert config_1 == config_3 + + +def test_fused_attn_score_mod_config_leaves_unkeyed_bound_methods_uncached(): + class UnkeyedScoreMod: + def forward(self, _graph, score, _tensors): + return score + + score_mod = UnkeyedScoreMod() + config_1, _, _ = make_fused_attn_score_mod_config( + score_mod.forward, None, None, None, 0.125, True + ) + config_2, _, _ = make_fused_attn_score_mod_config( + score_mod.forward, None, None, None, 0.125, True + ) + + assert config_1 != config_2 + assert tex_attention._graph_cache_key("fwd", config_1, ()) is None @pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") -def test_fused_attn_score_mod_relative_position_optional_bprop(): +def test_fused_attn_score_mod_post_scale_bias_optional_bprop(): _require_cudnn_frontend_score_mod() key = jax.random.key(0) @@ -402,12 +430,12 @@ def score_mod_loss(query, key_, value): scale, 0.0, True, - score_mod=_score_mod_relative_position, + score_mod=_score_mod_post_scale_bias, ) return jnp.sum(out.astype(jnp.float32)), out def ref_loss(query, key_, value): - out = _reference_attention(query, key_, value, scale, relative_position=True) + out = _reference_attention(query, key_, value, scale, post_scale_bias=True) return jnp.sum(out.astype(jnp.float32)), out (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index adbbcd02fa..dbee272046 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1574,7 +1574,7 @@ def fused_attn( Currently, a stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas a stripe_size=1 is supported for both, CP + THD + Striped + AG and CP + THD + Striped + P2P(Ring) None indicates no striping strategy - score_mod (Optional[Callable]): Optional cuDNN frontend score modification callback. + score_mod (Optional[Callable]): Experimental cuDNN frontend score modification callback. The callback is called as `score_mod(graph, score, tensors)` while building a cuDNN frontend graph. When provided, this path only supports BSHD_BSHD_BSHD layout and is mutually exclusive with masks, padding, bias, dropout, context diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index fa9d9f1467..c7ef11688e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2,8 +2,12 @@ # # See LICENSE for license information. """JAX/TE custom ops for attention""" +import importlib +import inspect import operator import os +from pathlib import Path +import sys import warnings from dataclasses import dataclass, replace from functools import partial, reduce @@ -60,6 +64,10 @@ "fused_attn_score_mod_bwd", ] +_CUDNN_FRONTEND_PYTHON_PATH = ( + Path(__file__).resolve().parents[3] / "3rdparty" / "cudnn-frontend" / "python" +) + @partial( jax.tree_util.register_dataclass, @@ -282,15 +290,101 @@ class _ScoreModScalarSpec: stride: Tuple[int, ...] = (1, 1, 1, 1) -def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]: - """Return a stable cache key for callbacks that may be bound methods.""" +class _UncacheableScoreModKey: + """Unique static key for callbacks that must not share compiled score_mod graphs.""" + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self is other + + +def _score_mod_key_is_uncacheable(key: Any) -> bool: + return isinstance(key, _UncacheableScoreModKey) + + +def _freeze_score_mod_cache_key(value: Any) -> Any: + """Convert a user-provided score_mod graph key into a hashable structure.""" + if _is_array_operand(value): + raise TypeError( + "score_mod_graph_cache_key() must not include tensors. Pass runtime tensors " + "through score_mod_tensors or score_mod_bprop_tensors instead." + ) + if isinstance(value, Mapping): + items = ( + ( + _freeze_score_mod_cache_key(key), + _freeze_score_mod_cache_key(val), + ) + for key, val in value.items() + ) + return tuple(sorted(items, key=repr)) + if isinstance(value, (list, tuple)): + return tuple(_freeze_score_mod_cache_key(item) for item in value) + if isinstance(value, (set, frozenset)): + items = (_freeze_score_mod_cache_key(item) for item in value) + return tuple(sorted(items, key=repr)) + try: + hash(value) + except TypeError as exc: + raise TypeError( + "score_mod_graph_cache_key() must return a hashable value or a nested " + "combination of mapping/list/tuple/set values." + ) from exc + return value + + +def _score_mod_explicit_cache_key(callback_owner: Any) -> Optional[Any]: + """Return a user-provided structural graph key for a score_mod callback.""" + explicit_key = getattr(callback_owner, "score_mod_graph_cache_key", None) + if explicit_key is None: + return None + explicit_key = explicit_key() if callable(explicit_key) else explicit_key + return _freeze_score_mod_cache_key(explicit_key) + + +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Any: + """Create a stable graph cache key for a score_mod callable. + + Module-level functions are assumed to have stable topology. Stateful bound methods and + callable instances need an explicit score_mod_graph_cache_key(); otherwise their graphs + are left uncached to avoid reusing stale graphs after Python object address reuse. + """ if callback is None: return None self_obj = getattr(callback, "__self__", None) - func = getattr(callback, "__func__", None) - if self_obj is not None and func is not None: - return ("bound_method", id(self_obj), id(func)) - return ("callable", id(callback)) + func_obj = getattr(callback, "__func__", None) + if self_obj is not None and func_obj is not None: + explicit_key = _score_mod_explicit_cache_key(self_obj) + if explicit_key is None: + return _UncacheableScoreModKey() + return ( + "bound_method", + type(self_obj), + func_obj.__module__, + func_obj.__qualname__, + explicit_key, + ) + + explicit_key = _score_mod_explicit_cache_key(callback) + if explicit_key is not None: + return ( + "callable", + type(callback), + getattr(callback, "__module__", None), + getattr(callback, "__qualname__", None), + explicit_key, + ) + + if ( + inspect.isfunction(callback) + and callback.__closure__ is None + and "" not in callback.__qualname__ + ): + return ("function", callback.__module__, callback.__qualname__) + + return _UncacheableScoreModKey() @dataclass(frozen=True) @@ -299,8 +393,8 @@ class _FusedAttnScoreModConfig: score_mod: Callable score_mod_bprop: Optional[Callable] - score_mod_key: Tuple[Any, ...] - score_mod_bprop_key: Optional[Tuple[Any, ...]] + score_mod_key: Any + score_mod_bprop_key: Any score_mod_tensor_names: Tuple[str, ...] score_mod_bprop_tensor_names: Tuple[str, ...] score_mod_scalars: Tuple[_ScoreModScalarSpec, ...] @@ -572,7 +666,12 @@ def _graph_cache_key( direction: str, config: _FusedAttnScoreModConfig, avals: Sequence[Any], -) -> Tuple[Any, ...]: +) -> Optional[Tuple[Any, ...]]: + if ( + _score_mod_key_is_uncacheable(config.score_mod_key) + or _score_mod_key_is_uncacheable(config.score_mod_bprop_key) + ): + return None return ( direction, config, @@ -585,13 +684,16 @@ def _shape_dtype(value) -> jax.ShapeDtypeStruct: def _import_cudnn_for_score_mod(): + cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) + cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" + if any(cudnn_frontend_package.glob("_compiled_module*")) and cudnn_frontend_path not in sys.path: + sys.path.insert(0, cudnn_frontend_path) try: - import cudnn # pylint: disable=import-outside-toplevel + return importlib.import_module("cudnn") except ImportError as exc: raise ImportError( "score_mod fused_attn requires the cuDNN frontend Python package (`cudnn`)." ) from exc - return cudnn def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): @@ -787,11 +889,16 @@ def fused_attn_score_mod_fwd( q_aval, k_aval, v_aval = map(_shape_dtype, (q, k, v)) score_mod_avals = tuple(_shape_dtype(arg) for arg in score_mod_tensors) key = _graph_cache_key("fwd", config, (q_aval, k_aval, v_aval, *score_mod_avals)) - if key not in _score_mod_graph_cache: - _score_mod_graph_cache[key] = _build_score_mod_fwd_graph( + if key is None: + graph_id, workspace_size = _build_score_mod_fwd_graph( q_aval, k_aval, v_aval, score_mod_avals, config ) - graph_id, workspace_size = _score_mod_graph_cache[key] + else: + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_fwd_graph( + q_aval, k_aval, v_aval, score_mod_avals, config + ) + graph_id, workspace_size = _score_mod_graph_cache[key] batch, q_seqlen, q_heads, _ = q.shape _, _, _, v_head_dim = v.shape @@ -824,14 +931,22 @@ def fused_attn_score_mod_bwd( all_inputs = (*all_inputs, *score_mod_bprop_tensors) avals = tuple(_shape_dtype(arg) for arg in all_inputs) key = _graph_cache_key("bwd", config, avals) - if key not in _score_mod_graph_cache: - _score_mod_graph_cache[key] = _build_score_mod_bwd_graph( + if key is None: + graph_id, workspace_size = _build_score_mod_bwd_graph( *avals[:6], avals[6 : 6 + len(score_mod_tensors)], avals[6 + len(score_mod_tensors) :], config, ) - graph_id, workspace_size = _score_mod_graph_cache[key] + else: + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_bwd_graph( + *avals[:6], + avals[6 : 6 + len(score_mod_tensors)], + avals[6 + len(score_mod_tensors) :], + config, + ) + graph_id, workspace_size = _score_mod_graph_cache[key] dq = jax.ShapeDtypeStruct(q.shape, q.dtype) dk = jax.ShapeDtypeStruct(k.shape, k.dtype) From c597af52d5eb7ff37204efd6778d68030bca9784 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 03:38:51 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/attention.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index c7ef11688e..970f7267b1 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -667,9 +667,8 @@ def _graph_cache_key( config: _FusedAttnScoreModConfig, avals: Sequence[Any], ) -> Optional[Tuple[Any, ...]]: - if ( - _score_mod_key_is_uncacheable(config.score_mod_key) - or _score_mod_key_is_uncacheable(config.score_mod_bprop_key) + if _score_mod_key_is_uncacheable(config.score_mod_key) or _score_mod_key_is_uncacheable( + config.score_mod_bprop_key ): return None return ( @@ -686,7 +685,10 @@ def _shape_dtype(value) -> jax.ShapeDtypeStruct: def _import_cudnn_for_score_mod(): cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" - if any(cudnn_frontend_package.glob("_compiled_module*")) and cudnn_frontend_path not in sys.path: + if ( + any(cudnn_frontend_package.glob("_compiled_module*")) + and cudnn_frontend_path not in sys.path + ): sys.path.insert(0, cudnn_frontend_path) try: return importlib.import_module("cudnn") From f8bd844bbb8fc56f17661a0e169101b5bd3f1978 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 18 May 2026 23:54:22 +0000 Subject: [PATCH 07/11] Use serialized cuDNN graphs for score_mod attention Signed-off-by: Vladimir Cherepanov --- build_tools/jax.py | 14 ++ tests/jax/test_fused_attn.py | 20 ++ .../jax/cpp_extensions/attention.py | 176 ++++++++++++--- transformer_engine/jax/csrc/extensions.h | 7 - .../jax/csrc/extensions/attention.cpp | 211 ++++++++++-------- .../jax/csrc/extensions/pybind.cpp | 2 +- .../jax/csrc/extensions/utils.cpp | 3 + .../jax/csrc/extensions/utils.h | 1 + 8 files changed, 301 insertions(+), 133 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index a7b200f915..f8cf695c16 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -83,6 +83,20 @@ def setup_jax_extension( # Header files include_dirs = get_cuda_include_dirs() + cudnn_frontend_include_dir = None + for base_path in (Path(common_header_files), *Path(common_header_files).parents): + candidate = base_path / "3rdparty" / "cudnn-frontend" / "include" + if candidate.exists(): + cudnn_frontend_include_dir = candidate + break + if cudnn_frontend_include_dir is None: + for base_path in Path(__file__).resolve().parents: + candidate = base_path / "3rdparty" / "cudnn-frontend" / "include" + if candidate.exists(): + cudnn_frontend_include_dir = candidate + break + if cudnn_frontend_include_dir is not None: + include_dirs.append(cudnn_frontend_include_dir) include_dirs.extend( [ common_header_files, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 73a5404a98..2e4640a470 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -348,6 +348,26 @@ def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize +def test_fused_attn_score_mod_cudnn_frontend_version_check(monkeypatch): + class FakeCudnn: + __version__ = "1.22.0" + + monkeypatch.setattr( + tex_attention.transformer_engine_jax, + "get_cudnn_frontend_version", + lambda: 12200, + ) + assert tex_attention._check_cudnn_frontend_version_match(FakeCudnn) == 12200 + + monkeypatch.setattr( + tex_attention.transformer_engine_jax, + "get_cudnn_frontend_version", + lambda: 12100, + ) + with pytest.raises(RuntimeError, match="Python/C\\+\\+ version mismatch"): + tex_attention._check_cudnn_frontend_version_match(FakeCudnn) + + def test_fused_attn_score_mod_config_stabilizes_bound_method_cache_keys(): softcap_score_mod = _ScoreModSoftcap() first_forward = softcap_score_mod.forward diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 970f7267b1..8ff41ba3e7 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """JAX/TE custom ops for attention""" +import hashlib import importlib import inspect import operator @@ -434,6 +435,21 @@ def __eq__(self, other): ) +@dataclass(frozen=True) +class _SerializedScoreModGraph: + """Serialized cuDNN frontend graph and static metadata for C++ execution.""" + + serialized_graph: bytes + graph_hash: Tuple[int, int] + cudnn_frontend_version: int + workspace_size: int + input_uids: np.ndarray + output_uids: np.ndarray + scalar_uids: np.ndarray + scalar_sizes: np.ndarray + scalar_values: np.ndarray + + _SCORE_MOD_UID_Q = 1 _SCORE_MOD_UID_K = 2 _SCORE_MOD_UID_V = 3 @@ -448,7 +464,7 @@ def __eq__(self, other): _SCORE_MOD_FWD_SCALAR_UID_BASE = 3000 _SCORE_MOD_BPROP_SCALAR_UID_BASE = 4000 -_score_mod_graph_cache: Dict[Tuple[Any, ...], Tuple[int, int]] = {} +_score_mod_graph_cache: Dict[Tuple[Any, ...], _SerializedScoreModGraph] = {} def _row_major_stride(shape: Sequence[int]) -> Tuple[int, ...]: @@ -640,6 +656,75 @@ def _score_mod_graph_tensors( return graph_tensors, tuple(tensor_uids), tuple(scalar_uids), tuple(scalar_values) +def _encode_cudnn_frontend_version(version: str) -> int: + public_version = version.split("+", 1)[0].split("-", 1)[0] + parts = public_version.split(".") + if len(parts) < 3: + raise RuntimeError(f"Could not parse cuDNN frontend Python version: {version!r}.") + major, minor, patch = (int(part) for part in parts[:3]) + return major * 10000 + minor * 100 + patch + + +def _check_cudnn_frontend_version_match(cudnn) -> int: + python_version_string = getattr(cudnn, "__version__", None) + if python_version_string is None: + raise RuntimeError("cuDNN frontend Python package does not expose __version__.") + python_version = _encode_cudnn_frontend_version(python_version_string) + cpp_version = int(transformer_engine_jax.get_cudnn_frontend_version()) + if python_version != cpp_version: + raise RuntimeError( + "cuDNN frontend Python/C++ version mismatch for score_mod graph serialization: " + f"Python cudnn.__version__={python_version_string!r} encodes to {python_version}, " + f"but Transformer Engine C++ was built with CUDNN_FRONTEND_VERSION={cpp_version}. " + "Use matching cuDNN frontend Python package and C++ headers." + ) + return python_version + + +def _score_mod_graph_hash(serialized_graph: bytes) -> Tuple[int, int]: + digest = hashlib.sha256(serialized_graph).digest() + return ( + int.from_bytes(digest[0:8], byteorder="little", signed=True), + int.from_bytes(digest[8:16], byteorder="little", signed=True), + ) + + +def _pack_score_mod_scalar_values( + scalar_values: Sequence[bytes], +) -> Tuple[np.ndarray, np.ndarray]: + scalar_sizes = np.asarray([len(value) for value in scalar_values], dtype=np.int64) + packed_values = np.zeros((len(scalar_values), 16), dtype=np.uint8) + for index, value in enumerate(scalar_values): + if len(value) > 16: + raise ValueError("score_mod pass-by-value scalars must be at most 16 bytes.") + packed_values[index, : len(value)] = np.frombuffer(value, dtype=np.uint8) + return scalar_sizes, packed_values.reshape(-1) + + +def _serialized_score_mod_graph( + *, + serialized_graph: bytes, + cudnn_frontend_version: int, + workspace_size: int, + input_uids: Sequence[int], + output_uids: Sequence[int], + scalar_uids: Sequence[int], + scalar_values: Sequence[bytes], +) -> _SerializedScoreModGraph: + scalar_sizes, packed_scalar_values = _pack_score_mod_scalar_values(scalar_values) + return _SerializedScoreModGraph( + serialized_graph=serialized_graph, + graph_hash=_score_mod_graph_hash(serialized_graph), + cudnn_frontend_version=int(cudnn_frontend_version), + workspace_size=int(workspace_size), + input_uids=np.asarray(input_uids, dtype=np.int64), + output_uids=np.asarray(output_uids, dtype=np.int64), + scalar_uids=np.asarray(scalar_uids, dtype=np.int64), + scalar_sizes=scalar_sizes, + scalar_values=packed_scalar_values, + ) + + def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): if score_mod is None: return None @@ -650,7 +735,7 @@ def wrapped_score_mod(sdpa_graph, score_tensor): return wrapped_score_mod -def _finalize_score_mod_graph(cudnn, graph) -> int: +def _finalize_score_mod_graph(cudnn, graph) -> Tuple[int, bytes, int]: graph.validate() graph.build_operation_graph() try: @@ -659,7 +744,12 @@ def _finalize_score_mod_graph(cudnn, graph) -> int: except cudnn.cudnnGraphNotSupportedError as exc: raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) - return max(int(graph.get_workspace_size()), 1) + serialized_graph = bytes(graph.serialize()) + return ( + max(int(graph.get_workspace_size()), 1), + serialized_graph, + _check_cudnn_frontend_version_match(cudnn), + ) def _graph_cache_key( @@ -691,11 +781,13 @@ def _import_cudnn_for_score_mod(): ): sys.path.insert(0, cudnn_frontend_path) try: - return importlib.import_module("cudnn") + cudnn = importlib.import_module("cudnn") except ImportError as exc: raise ImportError( "score_mod fused_attn requires the cuDNN frontend Python package (`cudnn`)." ) from exc + _check_cudnn_frontend_version_match(cudnn) + return cudnn def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): @@ -757,16 +849,16 @@ def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): stats.set_data_type(cudnn.data_type.FLOAT) output_uids.append(_SCORE_MOD_UID_STATS) - workspace_size = _finalize_score_mod_graph(cudnn, graph) - graph_id = transformer_engine_jax.register_fused_attn_score_mod_graph( - graph, - [int(uid) for uid in graph._get_variant_pack_uids_sorted()], - [_SCORE_MOD_UID_Q, _SCORE_MOD_UID_K, _SCORE_MOD_UID_V, *tensor_uids], - output_uids, - list(scalar_uids), - list(scalar_values), + workspace_size, serialized_graph, frontend_version = _finalize_score_mod_graph(cudnn, graph) + return _serialized_score_mod_graph( + serialized_graph=serialized_graph, + cudnn_frontend_version=frontend_version, + workspace_size=workspace_size, + input_uids=[_SCORE_MOD_UID_Q, _SCORE_MOD_UID_K, _SCORE_MOD_UID_V, *tensor_uids], + output_uids=output_uids, + scalar_uids=scalar_uids, + scalar_values=scalar_values, ) - return graph_id, workspace_size def _build_score_mod_bwd_graph( @@ -860,11 +952,12 @@ def _build_score_mod_bwd_graph( dk.set_output(True).set_uid(_SCORE_MOD_UID_DK).set_dim(k_dim).set_stride(k_stride) dv.set_output(True).set_uid(_SCORE_MOD_UID_DV).set_dim(v_dim).set_stride(v_stride) - workspace_size = _finalize_score_mod_graph(cudnn, graph) - graph_id = transformer_engine_jax.register_fused_attn_score_mod_graph( - graph, - [int(uid) for uid in graph._get_variant_pack_uids_sorted()], - [ + workspace_size, serialized_graph, frontend_version = _finalize_score_mod_graph(cudnn, graph) + return _serialized_score_mod_graph( + serialized_graph=serialized_graph, + cudnn_frontend_version=frontend_version, + workspace_size=workspace_size, + input_uids=[ _SCORE_MOD_UID_Q, _SCORE_MOD_UID_K, _SCORE_MOD_UID_V, @@ -874,11 +967,10 @@ def _build_score_mod_bwd_graph( *tensor_uids, *bprop_tensor_uids, ], - [_SCORE_MOD_UID_DQ, _SCORE_MOD_UID_DK, _SCORE_MOD_UID_DV], - [*scalar_uids, *bprop_scalar_uids], - [*scalar_values, *bprop_scalar_values], + output_uids=[_SCORE_MOD_UID_DQ, _SCORE_MOD_UID_DK, _SCORE_MOD_UID_DV], + scalar_uids=[*scalar_uids, *bprop_scalar_uids], + scalar_values=[*scalar_values, *bprop_scalar_values], ) - return graph_id, workspace_size def fused_attn_score_mod_fwd( @@ -892,26 +984,38 @@ def fused_attn_score_mod_fwd( score_mod_avals = tuple(_shape_dtype(arg) for arg in score_mod_tensors) key = _graph_cache_key("fwd", config, (q_aval, k_aval, v_aval, *score_mod_avals)) if key is None: - graph_id, workspace_size = _build_score_mod_fwd_graph( - q_aval, k_aval, v_aval, score_mod_avals, config - ) + graph = _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config) else: if key not in _score_mod_graph_cache: _score_mod_graph_cache[key] = _build_score_mod_fwd_graph( q_aval, k_aval, v_aval, score_mod_avals, config ) - graph_id, workspace_size = _score_mod_graph_cache[key] + graph = _score_mod_graph_cache[key] batch, q_seqlen, q_heads, _ = q.shape _, _, _, v_head_dim = v.shape output_shape = jax.ShapeDtypeStruct((batch, q_seqlen, q_heads, v_head_dim), q.dtype) stats_shape = (batch, q_heads, q_seqlen, 1) if config.is_training else (0,) stats = jax.ShapeDtypeStruct(stats_shape, jnp.float32) - workspace = jax.ShapeDtypeStruct((workspace_size,), jnp.uint8) + workspace = jax.ShapeDtypeStruct((graph.workspace_size,), jnp.uint8) output, softmax_stats, _ = ffi.ffi_call( "te_fused_attn_score_mod_forward_ffi", (output_shape, stats, workspace), - )(q, k, v, *score_mod_tensors, graph_id=graph_id) + )( + q, + k, + v, + *score_mod_tensors, + serialized_graph=graph.serialized_graph, + graph_hash0=graph.graph_hash[0], + graph_hash1=graph.graph_hash[1], + cudnn_frontend_version=graph.cudnn_frontend_version, + input_uids=graph.input_uids, + output_uids=graph.output_uids, + scalar_uids=graph.scalar_uids, + scalar_sizes=graph.scalar_sizes, + scalar_values=graph.scalar_values, + ) return output, softmax_stats @@ -934,7 +1038,7 @@ def fused_attn_score_mod_bwd( avals = tuple(_shape_dtype(arg) for arg in all_inputs) key = _graph_cache_key("bwd", config, avals) if key is None: - graph_id, workspace_size = _build_score_mod_bwd_graph( + graph = _build_score_mod_bwd_graph( *avals[:6], avals[6 : 6 + len(score_mod_tensors)], avals[6 + len(score_mod_tensors) :], @@ -948,12 +1052,12 @@ def fused_attn_score_mod_bwd( avals[6 + len(score_mod_tensors) :], config, ) - graph_id, workspace_size = _score_mod_graph_cache[key] + graph = _score_mod_graph_cache[key] dq = jax.ShapeDtypeStruct(q.shape, q.dtype) dk = jax.ShapeDtypeStruct(k.shape, k.dtype) dv = jax.ShapeDtypeStruct(v.shape, v.dtype) - workspace = jax.ShapeDtypeStruct((workspace_size,), jnp.uint8) + workspace = jax.ShapeDtypeStruct((graph.workspace_size,), jnp.uint8) dq, dk, dv, _ = ffi.ffi_call( "te_fused_attn_score_mod_backward_ffi", (dq, dk, dv, workspace), @@ -966,7 +1070,15 @@ def fused_attn_score_mod_bwd( softmax_stats, *score_mod_tensors, *score_mod_bprop_tensors, - graph_id=graph_id, + serialized_graph=graph.serialized_graph, + graph_hash0=graph.graph_hash[0], + graph_hash1=graph.graph_hash[1], + cudnn_frontend_version=graph.cudnn_frontend_version, + input_uids=graph.input_uids, + output_uids=graph.output_uids, + scalar_uids=graph.scalar_uids, + scalar_sizes=graph.scalar_sizes, + scalar_values=graph.scalar_values, ) return dq, dk, dv diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index f17788a068..f886c4ed4d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -173,13 +173,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal); -int64_t RegisterFusedAttnScoreModGraph(pybind11::object graph, - const std::vector& user_uids, - const std::vector& input_uids, - const std::vector& output_uids, - const std::vector& scalar_uids, - const std::vector& scalar_values); - // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmV2Handler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 665ac40277..5cdf8315bf 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -6,10 +6,16 @@ #include #include -#include +#include +#include #include #include +#include #include +#include +#include + +#include #include "../extensions.h" #include "transformer_engine/fused_attn.h" @@ -703,30 +709,45 @@ struct ScoreModScalarStorage { size_t size = 0; }; -struct ScoreModGraphEntry { - PyObject *py_graph = nullptr; - std::vector user_uids; - std::vector input_uids; - std::vector output_uids; - std::vector scalar_uids; - std::vector scalar_values; +struct ScoreModGraphCacheKey { + int device_id = 0; + int64_t hash0 = 0; + int64_t hash1 = 0; + int64_t frontend_version = 0; + + bool operator==(const ScoreModGraphCacheKey &other) const { + return device_id == other.device_id && hash0 == other.hash0 && hash1 == other.hash1 && + frontend_version == other.frontend_version; + } +}; + +struct ScoreModGraphCacheKeyHash { + size_t operator()(const ScoreModGraphCacheKey &key) const { + size_t seed = std::hash{}(key.device_id); + auto combine = [&seed](int64_t value) { + seed ^= std::hash{}(value) + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); + }; + combine(key.hash0); + combine(key.hash1); + combine(key.frontend_version); + return seed; + } }; -std::unordered_map> &ScoreModGraphRegistry() { - static std::unordered_map> registry; - return registry; +using ScoreModGraphPtr = std::shared_ptr; + +std::unordered_map + &ScoreModGraphCache() { + static std::unordered_map + cache; + return cache; } -std::mutex &ScoreModGraphRegistryMutex() { +std::mutex &ScoreModGraphCacheMutex() { static std::mutex mutex; return mutex; } -std::atomic &NextScoreModGraphId() { - static std::atomic next_id{1}; - return next_id; -} - struct ScoreModCudnnHandleCache { std::unordered_map handles; @@ -754,55 +775,98 @@ cudnnHandle_t GetScoreModCudnnHandle() { return cache.GetHandle(); } -std::shared_ptr GetScoreModGraphEntry(int64_t graph_id) { - std::lock_guard lock(ScoreModGraphRegistryMutex()); - auto ®istry = ScoreModGraphRegistry(); - auto it = registry.find(graph_id); - NVTE_CHECK(it != registry.end(), "Unknown cuDNN score_mod graph id: ", graph_id); - return it->second; +ScoreModGraphCacheKey GetScoreModGraphCacheKey(Dictionary &attrs) { + const int64_t frontend_version = get_attr_value(attrs, "cudnn_frontend_version"); + NVTE_CHECK(frontend_version == CUDNN_FRONTEND_VERSION, + "cuDNN frontend version mismatch for score_mod graph deserialization: graph was " + "serialized with Python cuDNN frontend version ", + frontend_version, ", but Transformer Engine C++ was built with CUDNN_FRONTEND_VERSION ", + CUDNN_FRONTEND_VERSION, "."); + + int device_id = 0; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + return ScoreModGraphCacheKey{ + device_id, + get_attr_value(attrs, "graph_hash0"), + get_attr_value(attrs, "graph_hash1"), + frontend_version, + }; +} + +ScoreModGraphPtr GetScoreModGraph(cudaStream_t stream, Dictionary &attrs) { + const auto key = GetScoreModGraphCacheKey(attrs); + { + std::lock_guard lock(ScoreModGraphCacheMutex()); + auto &cache = ScoreModGraphCache(); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + } + + const auto serialized_graph = get_attr_value(attrs, "serialized_graph"); + std::vector serialized_data(serialized_graph.begin(), serialized_graph.end()); + + auto handle = GetScoreModCudnnHandle(); + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + auto graph = std::make_shared(); + auto status = graph->deserialize(handle, serialized_data); + NVTE_CHECK(status.is_good(), "Failed to deserialize cuDNN score_mod SDPA graph: ", + status.get_message()); + + std::lock_guard lock(ScoreModGraphCacheMutex()); + auto &cache = ScoreModGraphCache(); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + cache.emplace(key, graph); + return graph; } -Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, +Error_Type ExecuteScoreModGraph(cudaStream_t stream, Dictionary &attrs, const std::vector &input_ptrs, const std::vector &output_ptrs, void *workspace) { - auto entry = GetScoreModGraphEntry(graph_id); - NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ", - entry->input_uids.size(), " inputs but got ", input_ptrs.size()); - NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(), - "cuDNN score_mod graph expected at least ", entry->output_uids.size(), + auto graph = GetScoreModGraph(stream, attrs); + auto input_uids = get_attr_value>(attrs, "input_uids"); + auto output_uids = get_attr_value>(attrs, "output_uids"); + auto scalar_uids = get_attr_value>(attrs, "scalar_uids"); + auto scalar_sizes = get_attr_value>(attrs, "scalar_sizes"); + auto scalar_values = get_attr_value>(attrs, "scalar_values"); + + NVTE_CHECK(input_ptrs.size() == input_uids.size(), "cuDNN score_mod graph expected ", + input_uids.size(), " inputs but got ", input_ptrs.size()); + NVTE_CHECK(output_ptrs.size() >= output_uids.size(), + "cuDNN score_mod graph expected at least ", output_uids.size(), " outputs but got ", output_ptrs.size()); + NVTE_CHECK(scalar_uids.size() == scalar_sizes.size(), + "Mismatched score_mod scalar uid/value-size counts."); + NVTE_CHECK(scalar_values.size() == scalar_uids.size() * 16, + "Mismatched score_mod packed scalar value size."); std::unordered_map variant_pack; - for (size_t i = 0; i < entry->input_uids.size(); ++i) { - variant_pack.emplace(entry->input_uids[i], input_ptrs[i]); + for (size_t i = 0; i < input_uids.size(); ++i) { + variant_pack.emplace(input_uids[i], input_ptrs[i]); } - for (size_t i = 0; i < entry->output_uids.size(); ++i) { - variant_pack.emplace(entry->output_uids[i], output_ptrs[i]); - } - for (size_t i = 0; i < entry->scalar_uids.size(); ++i) { - variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data()); + for (size_t i = 0; i < output_uids.size(); ++i) { + variant_pack.emplace(output_uids[i], output_ptrs[i]); } - std::vector user_ptrs; - user_ptrs.reserve(entry->user_uids.size()); - for (const auto uid : entry->user_uids) { - auto it = variant_pack.find(uid); - NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid); - user_ptrs.push_back(reinterpret_cast(it->second)); + std::vector scalar_storage(scalar_uids.size()); + for (size_t i = 0; i < scalar_uids.size(); ++i) { + NVTE_CHECK(scalar_sizes[i] >= 0 && scalar_sizes[i] <= 16, + "score_mod pass-by-value scalars must be at most 16 bytes."); + scalar_storage[i].size = static_cast(scalar_sizes[i]); + std::copy_n(scalar_values.begin() + i * 16, 16, scalar_storage[i].data.begin()); + variant_pack.emplace(scalar_uids[i], scalar_storage[i].data.data()); } auto handle = GetScoreModCudnnHandle(); NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - { - pybind11::gil_scoped_acquire gil; - try { - auto graph = pybind11::reinterpret_borrow(entry->py_graph); - graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast(workspace), - reinterpret_cast(handle)); - } catch (const pybind11::error_already_set &exc) { - NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what()); - } - } + auto status = graph->execute(handle, variant_pack, workspace); + NVTE_CHECK(status.is_good(), "cuDNN score_mod SDPA graph execution failed: ", + status.get_message()); return ffi_with_cuda_error_check(); } @@ -817,53 +881,16 @@ void AppendRemainingBuffers(Variadic_Buffer_Type args, std::vector *ptrs } // namespace -int64_t RegisterFusedAttnScoreModGraph(pybind11::object graph, - const std::vector &user_uids, - const std::vector &input_uids, - const std::vector &output_uids, - const std::vector &scalar_uids, - const std::vector &scalar_values) { - NVTE_CHECK(!graph.is_none(), "Cannot register an empty cuDNN score_mod graph."); - NVTE_CHECK(!user_uids.empty(), "Cannot register a cuDNN score_mod graph without variant UIDs."); - NVTE_CHECK(scalar_uids.size() == scalar_values.size(), - "Mismatched score_mod scalar uid/value counts."); - - auto entry = std::make_shared(); - entry->py_graph = graph.ptr(); - Py_INCREF(entry->py_graph); - entry->user_uids = user_uids; - entry->input_uids = input_uids; - entry->output_uids = output_uids; - entry->scalar_uids = scalar_uids; - entry->scalar_values.reserve(scalar_values.size()); - for (const auto &value : scalar_values) { - NVTE_CHECK(value.size() <= 16, "score_mod pass-by-value scalars must be at most 16 bytes."); - ScoreModScalarStorage storage; - storage.size = value.size(); - std::copy(value.begin(), value.end(), storage.data.begin()); - entry->scalar_values.push_back(storage); - } - - const int64_t graph_id = NextScoreModGraphId().fetch_add(1); - { - std::lock_guard lock(ScoreModGraphRegistryMutex()); - ScoreModGraphRegistry().emplace(graph_id, std::move(entry)); - } - return graph_id; -} - Error_Type FusedAttnScoreModForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, Variadic_Buffer_Type score_mod_args, Result_Type output_buf, Result_Type stats_buf, Result_Type workspace_buf, Dictionary attrs) { - int64_t graph_id = get_attr_value(attrs, "graph_id"); std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data()}; AppendRemainingBuffers(score_mod_args, &input_ptrs); std::vector output_ptrs = {output_buf->untyped_data(), stats_buf->untyped_data()}; - return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, - workspace_buf->untyped_data()); + return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, workspace_buf->untyped_data()); } XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler, FusedAttnScoreModForwardFFI, @@ -884,7 +911,6 @@ Error_Type FusedAttnScoreModBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf, Result_Type workspace_buf, Dictionary attrs) { - int64_t graph_id = get_attr_value(attrs, "graph_id"); std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), output_buf.untyped_data(), doutput_buf.untyped_data(), stats_buf.untyped_data()}; @@ -892,8 +918,7 @@ Error_Type FusedAttnScoreModBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, std::vector output_ptrs = {dq_buf->untyped_data(), dk_buf->untyped_data(), dv_buf->untyped_data()}; - return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, - workspace_buf->untyped_data()); + return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, workspace_buf->untyped_data()); } XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler, FusedAttnScoreModBackwardFFI, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index bdb4507323..4f6cd385be 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -118,6 +118,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); + m.def("get_cudnn_frontend_version", &GetCudnnFrontendVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_num_compute_streams", &nvte_get_num_compute_streams); m.def("get_cublasLt_version", &cublasLtGetVersion); @@ -127,7 +128,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); - m.def("register_fused_attn_score_mod_graph", &RegisterFusedAttnScoreModGraph); m.def("get_topk_workspace_sizes", &GetTopkWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); diff --git a/transformer_engine/jax/csrc/extensions/utils.cpp b/transformer_engine/jax/csrc/extensions/utils.cpp index 52ab2edf0f..d8fe1762dd 100644 --- a/transformer_engine/jax/csrc/extensions/utils.cpp +++ b/transformer_engine/jax/csrc/extensions/utils.cpp @@ -6,6 +6,7 @@ #include "utils.h" #include +#include #include @@ -22,6 +23,8 @@ int GetCudaRuntimeVersion() { size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); } +size_t GetCudnnFrontendVersion() { return CUDNN_FRONTEND_VERSION; } + int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/utils.h b/transformer_engine/jax/csrc/extensions/utils.h index c55c8d86ce..939b0a881e 100644 --- a/transformer_engine/jax/csrc/extensions/utils.h +++ b/transformer_engine/jax/csrc/extensions/utils.h @@ -20,6 +20,7 @@ namespace jax { int GetCudaRuntimeVersion(); size_t GetCudnnRuntimeVersion(); +size_t GetCudnnFrontendVersion(); int GetDeviceComputeCapability(int gpu_id); class cudaDevicePropertiesManager { From 2c01c5e8b2d51d65f4af282d1324e65e49331352 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 May 2026 23:56:58 +0000 Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/csrc/extensions/attention.cpp | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 5cdf8315bf..44e37a664a 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include + #include #include #include @@ -15,8 +17,6 @@ #include #include -#include - #include "../extensions.h" #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -736,8 +736,8 @@ struct ScoreModGraphCacheKeyHash { using ScoreModGraphPtr = std::shared_ptr; -std::unordered_map - &ScoreModGraphCache() { +std::unordered_map & +ScoreModGraphCache() { static std::unordered_map cache; return cache; @@ -780,7 +780,8 @@ ScoreModGraphCacheKey GetScoreModGraphCacheKey(Dictionary &attrs) { NVTE_CHECK(frontend_version == CUDNN_FRONTEND_VERSION, "cuDNN frontend version mismatch for score_mod graph deserialization: graph was " "serialized with Python cuDNN frontend version ", - frontend_version, ", but Transformer Engine C++ was built with CUDNN_FRONTEND_VERSION ", + frontend_version, + ", but Transformer Engine C++ was built with CUDNN_FRONTEND_VERSION ", CUDNN_FRONTEND_VERSION, "."); int device_id = 0; @@ -812,8 +813,8 @@ ScoreModGraphPtr GetScoreModGraph(cudaStream_t stream, Dictionary &attrs) { auto graph = std::make_shared(); auto status = graph->deserialize(handle, serialized_data); - NVTE_CHECK(status.is_good(), "Failed to deserialize cuDNN score_mod SDPA graph: ", - status.get_message()); + NVTE_CHECK(status.is_good(), + "Failed to deserialize cuDNN score_mod SDPA graph: ", status.get_message()); std::lock_guard lock(ScoreModGraphCacheMutex()); auto &cache = ScoreModGraphCache(); @@ -837,9 +838,8 @@ Error_Type ExecuteScoreModGraph(cudaStream_t stream, Dictionary &attrs, NVTE_CHECK(input_ptrs.size() == input_uids.size(), "cuDNN score_mod graph expected ", input_uids.size(), " inputs but got ", input_ptrs.size()); - NVTE_CHECK(output_ptrs.size() >= output_uids.size(), - "cuDNN score_mod graph expected at least ", output_uids.size(), - " outputs but got ", output_ptrs.size()); + NVTE_CHECK(output_ptrs.size() >= output_uids.size(), "cuDNN score_mod graph expected at least ", + output_uids.size(), " outputs but got ", output_ptrs.size()); NVTE_CHECK(scalar_uids.size() == scalar_sizes.size(), "Mismatched score_mod scalar uid/value-size counts."); NVTE_CHECK(scalar_values.size() == scalar_uids.size() * 16, @@ -865,8 +865,8 @@ Error_Type ExecuteScoreModGraph(cudaStream_t stream, Dictionary &attrs, auto handle = GetScoreModCudnnHandle(); NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); auto status = graph->execute(handle, variant_pack, workspace); - NVTE_CHECK(status.is_good(), "cuDNN score_mod SDPA graph execution failed: ", - status.get_message()); + NVTE_CHECK(status.is_good(), + "cuDNN score_mod SDPA graph execution failed: ", status.get_message()); return ffi_with_cuda_error_check(); } @@ -890,7 +890,8 @@ Error_Type FusedAttnScoreModForwardFFI(cudaStream_t stream, Buffer_Type q_buf, B AppendRemainingBuffers(score_mod_args, &input_ptrs); std::vector output_ptrs = {output_buf->untyped_data(), stats_buf->untyped_data()}; - return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, workspace_buf->untyped_data()); + return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); } XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler, FusedAttnScoreModForwardFFI, @@ -918,7 +919,8 @@ Error_Type FusedAttnScoreModBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, std::vector output_ptrs = {dq_buf->untyped_data(), dk_buf->untyped_data(), dv_buf->untyped_data()}; - return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, workspace_buf->untyped_data()); + return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); } XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler, FusedAttnScoreModBackwardFFI, From ba6a1a759388d91890291b634f810599fee29757 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 19 May 2026 00:29:42 +0000 Subject: [PATCH 09/11] Rename score_mod graph cache helpers Signed-off-by: Vladimir Cherepanov --- .../jax/csrc/extensions/attention.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 44e37a664a..1cea12f6a0 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -736,14 +736,14 @@ struct ScoreModGraphCacheKeyHash { using ScoreModGraphPtr = std::shared_ptr; -std::unordered_map & -ScoreModGraphCache() { +std::unordered_map + &getScoreModeGraphCache() { static std::unordered_map cache; return cache; } -std::mutex &ScoreModGraphCacheMutex() { +std::mutex &getScoreModGraphCacheMutex() { static std::mutex mutex; return mutex; } @@ -797,8 +797,8 @@ ScoreModGraphCacheKey GetScoreModGraphCacheKey(Dictionary &attrs) { ScoreModGraphPtr GetScoreModGraph(cudaStream_t stream, Dictionary &attrs) { const auto key = GetScoreModGraphCacheKey(attrs); { - std::lock_guard lock(ScoreModGraphCacheMutex()); - auto &cache = ScoreModGraphCache(); + std::lock_guard lock(getScoreModGraphCacheMutex()); + auto &cache = getScoreModeGraphCache(); auto it = cache.find(key); if (it != cache.end()) { return it->second; @@ -816,8 +816,8 @@ ScoreModGraphPtr GetScoreModGraph(cudaStream_t stream, Dictionary &attrs) { NVTE_CHECK(status.is_good(), "Failed to deserialize cuDNN score_mod SDPA graph: ", status.get_message()); - std::lock_guard lock(ScoreModGraphCacheMutex()); - auto &cache = ScoreModGraphCache(); + std::lock_guard lock(getScoreModGraphCacheMutex()); + auto &cache = getScoreModeGraphCache(); auto it = cache.find(key); if (it != cache.end()) { return it->second; From deebf8ebb94a68cbdc03dc361fbe97e1ace27d9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 May 2026 00:39:18 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/csrc/extensions/attention.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 1cea12f6a0..e74b78df17 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -736,8 +736,8 @@ struct ScoreModGraphCacheKeyHash { using ScoreModGraphPtr = std::shared_ptr; -std::unordered_map - &getScoreModeGraphCache() { +std::unordered_map & +getScoreModeGraphCache() { static std::unordered_map cache; return cache; From 9a92dd2da29155df7aebd43c06958ad1bbe95f78 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 20 May 2026 22:15:39 +0000 Subject: [PATCH 11/11] Add Flax score_mod attention support Signed-off-by: Vladimir Cherepanov --- tests/jax/test_distributed_fused_attn.py | 26 +- tests/jax/test_fused_attn.py | 434 +----------- tests/jax/test_fused_attn_score_mod.py | 666 ++++++++++++++++++ .../jax/cpp_extensions/attention.py | 2 + transformer_engine/jax/flax/transformer.py | 232 ++++-- 5 files changed, 851 insertions(+), 509 deletions(-) create mode 100644 tests/jax/test_fused_attn_score_mod.py diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 7a61b208ce..dfcfab3312 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -18,6 +18,9 @@ FusedAttnRunner, BiasShape, SeqDescFormat, + customcall_fused_dpa, +) +from test_fused_attn_score_mod import ( _ScoreModSoftcap, _has_cudnn_frontend_python, _reference_attention, @@ -36,7 +39,6 @@ inverse_reorder_causal_load_balancing, CPStrategy, ReorderStrategy, - fused_attn, ) @@ -287,7 +289,6 @@ def test_cross_attn( DISTRIBUTED_SCORE_MOD_DATA_SHAPES = { "L0": [], "L1": [(4, 16, 4, 64)], - "L2": [(4, 16, 4, 64)], } @@ -335,18 +336,21 @@ def test_softcap_score_mod_with_aux_params_backward( softcap_score_mod = _ScoreModSoftcap() def score_mod_loss(q, k, v, dout): - out = fused_attn( - (q, k, v), + out = customcall_fused_dpa( + q, + k, + v, + None, None, None, None, - AttnBiasType.NO_BIAS, - AttnMaskType.NO_MASK, - QKVLayout.BSHD_BSHD_BSHD, - AttnSoftmaxType.VANILLA_SOFTMAX, - scaling_factor, - 0.0, - True, + attn_bias_type=AttnBiasType.NO_BIAS, + attn_mask_type=AttnMaskType.NO_MASK, + qkv_layout=QKVLayout.BSHD_BSHD_BSHD, + softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX, + scaling_factor=scaling_factor, + dropout_probability=0.0, + is_training=True, score_mod=softcap_score_mod.forward, score_mod_bprop=softcap_score_mod.backward, score_mod_tensors={"softcap": softcap}, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 2e4640a470..1fb0108068 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -23,7 +23,6 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike -import transformer_engine.jax.cpp_extensions.attention as tex_attention from transformer_engine.jax import autocast from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( @@ -41,7 +40,7 @@ CPStrategy, ReorderStrategy, ) -from transformer_engine.jax.cpp_extensions import FusedAttnHelper, make_fused_attn_score_mod_config +from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, @@ -55,157 +54,6 @@ _deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) -def _has_cudnn_frontend_python(): - try: - tex_attention._import_cudnn_for_score_mod() - except ImportError: - return False - return True - - -def _score_mod_causal(graph, score, tensors): - cudnn = tex_attention._import_cudnn_for_score_mod() - - row_index = graph.gen_index( - input=score, - axis=2, - compute_data_type=cudnn.data_type.INT32, - ) - row_index.set_data_type(cudnn.data_type.INT32) - col_index = graph.gen_index( - input=score, - axis=3, - compute_data_type=cudnn.data_type.INT32, - ) - col_index.set_data_type(cudnn.data_type.INT32) - keep = graph.cmp_ge( - input=row_index, - comparison=col_index, - compute_data_type=cudnn.data_type.BOOLEAN, - ) - keep.set_data_type(cudnn.data_type.BOOLEAN) - return graph.binary_select(input0=score, input1=tensors["neg_inf"], mask=keep) - - -def _score_mod_causal_bprop(graph, dscore, tensors): - cudnn = tex_attention._import_cudnn_for_score_mod() - - row_index = graph.gen_index( - input=dscore, - axis=2, - compute_data_type=cudnn.data_type.INT32, - ) - row_index.set_data_type(cudnn.data_type.INT32) - col_index = graph.gen_index( - input=dscore, - axis=3, - compute_data_type=cudnn.data_type.INT32, - ) - col_index.set_data_type(cudnn.data_type.INT32) - keep = graph.cmp_ge( - input=row_index, - comparison=col_index, - compute_data_type=cudnn.data_type.BOOLEAN, - ) - keep.set_data_type(cudnn.data_type.BOOLEAN) - return graph.binary_select(input0=dscore, input1=tensors["zero"], mask=keep) - - -def _score_mod_post_scale_bias(graph, score, tensors): - cudnn = tex_attention._import_cudnn_for_score_mod() - - row_index = graph.gen_index( - input=score, - axis=2, - compute_data_type=cudnn.data_type.INT32, - ) - row_index.set_data_type(cudnn.data_type.INT32) - col_index = graph.gen_index( - input=score, - axis=3, - compute_data_type=cudnn.data_type.INT32, - ) - col_index.set_data_type(cudnn.data_type.INT32) - post_scale_bias = graph.sub( - a=row_index, - b=col_index, - compute_data_type=cudnn.data_type.FLOAT, - ) - post_scale_bias.set_data_type(cudnn.data_type.FLOAT) - return graph.add( - a=score, - b=post_scale_bias, - compute_data_type=cudnn.data_type.FLOAT, - ) - - -class _ScoreModSoftcap: - """cuDNN frontend score_mod implementing softcapping.""" - - def __init__(self): - self.before_tanh_activation = None - - def score_mod_graph_cache_key(self): - """Graph topology key for softcap score_mod.""" - return ("softcap",) - - def forward(self, graph, score, tensors): - cudnn = tex_attention._import_cudnn_for_score_mod() - - self.before_tanh_activation = graph.div( - a=score, - b=tensors["softcap"], - compute_data_type=cudnn.data_type.FLOAT, - ) - self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) - tanh_out = graph.tanh(input=self.before_tanh_activation) - tanh_out.set_data_type(cudnn.data_type.FLOAT) - return graph.mul( - a=tanh_out, - b=tensors["softcap"], - compute_data_type=cudnn.data_type.FLOAT, - ) - - def backward(self, graph, dscore, tensors): - cudnn = tex_attention._import_cudnn_for_score_mod() - - d_tanh_out = graph.mul( - a=dscore, - b=tensors["softcap"], - compute_data_type=cudnn.data_type.FLOAT, - ) - d_tanh_out.set_data_type(cudnn.data_type.FLOAT) - d_before_tanh_activation = graph.tanh_backward( - loss=d_tanh_out, - input=self.before_tanh_activation, - compute_data_type=cudnn.data_type.FLOAT, - ) - d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) - return graph.div( - a=d_before_tanh_activation, - b=tensors["softcap"], - compute_data_type=cudnn.data_type.FLOAT, - ) - - -def _reference_attention( - query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None -): - scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale - if causal: - q_pos = jnp.arange(query.shape[1])[:, None] - kv_pos = jnp.arange(key.shape[1])[None, :] - scores = jnp.where(q_pos >= kv_pos, scores, -1e9) - if post_scale_bias: - q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] - kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] - scores = scores + q_pos - kv_pos - if softcap is not None: - scores = softcap * jnp.tanh(scores / softcap) - probs = jax.nn.softmax(scores, axis=-1) - return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype) - - @pytest.fixture(autouse=True, scope="module") def init(): """ @@ -290,286 +138,6 @@ def general_dot_product_attention( return context -def _require_cudnn_frontend_score_mod(): - try: - cudnn = tex_attention._import_cudnn_for_score_mod() - except ImportError: - pytest.skip("cuDNN Python frontend is required for score_mod") - version = tuple(int(part) for part in cudnn.backend_version_string().split(".")[:2]) - if version < (9, 6): - pytest.skip("cuDNN score_mod SDPA requires cuDNN frontend 9.6 or newer") - - -def _identity_score_mod(_graph, score, _tensors): - return score - - -def test_fused_attn_score_mod_validation_rejects_masks_without_cudnn_frontend(): - q = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) - k = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) - v = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) - - with pytest.raises(ValueError, match="mutually exclusive with attention masks"): - fused_attn( - (q, k, v), - None, - None, - None, - AttnBiasType.NO_BIAS, - AttnMaskType.CAUSAL_MASK, - QKVLayout.BSHD_BSHD_BSHD, - AttnSoftmaxType.VANILLA_SOFTMAX, - 1.0, - 0.0, - True, - score_mod=_identity_score_mod, - ) - - -def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): - tensor = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) - - config, tensor_operands, bprop_tensor_operands = make_fused_attn_score_mod_config( - _identity_score_mod, - None, - {"tensor": tensor, "neg_inf": -1e9}, - None, - 0.125, - True, - ) - - assert config.score_mod_tensor_names == ("tensor",) - assert len(tensor_operands) == 1 - assert tensor_operands[0].shape == (1, 1, 1, 1) - assert len(bprop_tensor_operands) == 0 - assert len(config.score_mod_scalars) == 1 - assert config.score_mod_scalars[0].name == "neg_inf" - assert config.score_mod_scalars[0].dtype == "float32" - assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize - - -def test_fused_attn_score_mod_cudnn_frontend_version_check(monkeypatch): - class FakeCudnn: - __version__ = "1.22.0" - - monkeypatch.setattr( - tex_attention.transformer_engine_jax, - "get_cudnn_frontend_version", - lambda: 12200, - ) - assert tex_attention._check_cudnn_frontend_version_match(FakeCudnn) == 12200 - - monkeypatch.setattr( - tex_attention.transformer_engine_jax, - "get_cudnn_frontend_version", - lambda: 12100, - ) - with pytest.raises(RuntimeError, match="Python/C\\+\\+ version mismatch"): - tex_attention._check_cudnn_frontend_version_match(FakeCudnn) - - -def test_fused_attn_score_mod_config_stabilizes_bound_method_cache_keys(): - softcap_score_mod = _ScoreModSoftcap() - first_forward = softcap_score_mod.forward - second_forward = softcap_score_mod.forward - first_backward = softcap_score_mod.backward - second_backward = softcap_score_mod.backward - - assert first_forward is not second_forward - assert first_backward is not second_backward - - config_1, _, _ = make_fused_attn_score_mod_config( - first_forward, - first_backward, - {"softcap": 0.8}, - {"softcap": 0.8}, - 0.125, - True, - ) - config_2, _, _ = make_fused_attn_score_mod_config( - second_forward, - second_backward, - {"softcap": 0.8}, - {"softcap": 0.8}, - 0.125, - True, - ) - other_softcap_score_mod = _ScoreModSoftcap() - config_3, _, _ = make_fused_attn_score_mod_config( - other_softcap_score_mod.forward, - other_softcap_score_mod.backward, - {"softcap": 0.8}, - {"softcap": 0.8}, - 0.125, - True, - ) - - assert config_1 == config_2 - assert hash(config_1) == hash(config_2) - assert config_1 == config_3 - - -def test_fused_attn_score_mod_config_leaves_unkeyed_bound_methods_uncached(): - class UnkeyedScoreMod: - def forward(self, _graph, score, _tensors): - return score - - score_mod = UnkeyedScoreMod() - config_1, _, _ = make_fused_attn_score_mod_config( - score_mod.forward, None, None, None, 0.125, True - ) - config_2, _, _ = make_fused_attn_score_mod_config( - score_mod.forward, None, None, None, 0.125, True - ) - - assert config_1 != config_2 - assert tex_attention._graph_cache_key("fwd", config_1, ()) is None - - -@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") -def test_fused_attn_score_mod_post_scale_bias_optional_bprop(): - _require_cudnn_frontend_score_mod() - - key = jax.random.key(0) - q_key, k_key, v_key = jax.random.split(key, 3) - q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) - k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) - v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) - scale = 1.0 / sqrt(q.shape[-1]) - - def score_mod_loss(query, key_, value): - out = fused_attn( - (query, key_, value), - None, - None, - None, - AttnBiasType.NO_BIAS, - AttnMaskType.NO_MASK, - QKVLayout.BSHD_BSHD_BSHD, - AttnSoftmaxType.VANILLA_SOFTMAX, - scale, - 0.0, - True, - score_mod=_score_mod_post_scale_bias, - ) - return jnp.sum(out.astype(jnp.float32)), out - - def ref_loss(query, key_, value): - out = _reference_attention(query, key_, value, scale, post_scale_bias=True) - return jnp.sum(out.astype(jnp.float32)), out - - (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( - score_mod_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( - q, k, v - ) - - assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) - assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) - for grad, ref_grad in zip(score_mod_grads, ref_grads): - assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) - - -@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") -def test_fused_attn_score_mod_causal_with_bprop(): - _require_cudnn_frontend_score_mod() - - key = jax.random.key(1) - q_key, k_key, v_key = jax.random.split(key, 3) - q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) - k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) - v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) - scale = 1.0 / sqrt(q.shape[-1]) - - def score_mod_loss(query, key_, value): - out = fused_attn( - (query, key_, value), - None, - None, - None, - AttnBiasType.NO_BIAS, - AttnMaskType.NO_MASK, - QKVLayout.BSHD_BSHD_BSHD, - AttnSoftmaxType.VANILLA_SOFTMAX, - scale, - 0.0, - True, - score_mod=_score_mod_causal, - score_mod_bprop=_score_mod_causal_bprop, - score_mod_tensors={"neg_inf": -1e9}, - score_mod_bprop_tensors={"zero": 0.0}, - ) - return jnp.sum(out.astype(jnp.float32)), out - - def ref_loss(query, key_, value): - out = _reference_attention(query, key_, value, scale, causal=True) - return jnp.sum(out.astype(jnp.float32)), out - - (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( - score_mod_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( - q, k, v - ) - - assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) - assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) - for grad, ref_grad in zip(score_mod_grads, ref_grads): - assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) - - -@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") -def test_fused_attn_score_mod_softcap_with_bprop(): - _require_cudnn_frontend_score_mod() - - key = jax.random.key(2) - q_key, k_key, v_key, d_out_key = jax.random.split(key, 4) - q = jax.random.normal(q_key, (1, 16, 2, 64), dtype=jnp.float16) - k = jax.random.normal(k_key, (1, 16, 2, 64), dtype=jnp.float16) - v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype(jnp.float16) - d_out = jax.random.normal(d_out_key, (1, 16, 2, 64), dtype=jnp.float16) - scale = 1.0 / sqrt(q.shape[-1]) - softcap = 0.8 - softcap_score_mod = _ScoreModSoftcap() - - def score_mod_loss(query, key_, value): - out = fused_attn( - (query, key_, value), - None, - None, - None, - AttnBiasType.NO_BIAS, - AttnMaskType.NO_MASK, - QKVLayout.BSHD_BSHD_BSHD, - AttnSoftmaxType.VANILLA_SOFTMAX, - scale, - 0.0, - True, - score_mod=softcap_score_mod.forward, - score_mod_bprop=softcap_score_mod.backward, - score_mod_tensors={"softcap": softcap}, - score_mod_bprop_tensors={"softcap": softcap}, - ) - return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out - - def ref_loss(query, key_, value): - out = _reference_attention(query, key_, value, scale, softcap=softcap) - return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out - - (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( - score_mod_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( - q, k, v - ) - - assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) - assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) - for grad, ref_grad in zip(score_mod_grads, ref_grads): - assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) - - @jax.jit def make_causal_mask( segment_ids_q: ArrayLike, diff --git a/tests/jax/test_fused_attn_score_mod.py b/tests/jax/test_fused_attn_score_mod.py new file mode 100644 index 0000000000..c9debde05e --- /dev/null +++ b/tests/jax/test_fused_attn_score_mod.py @@ -0,0 +1,666 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Tests for cuDNN frontend score_mod fused attention.""" +from math import sqrt + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax import value_and_grad + +import transformer_engine.jax.cpp_extensions.attention as tex_attention +from transformer_engine.jax.attention import ( + AttnBiasType, + AttnMaskType, + AttnSoftmaxType, + QKVLayout, + fused_attn, +) +from transformer_engine.jax.cpp_extensions import make_fused_attn_score_mod_config +from transformer_engine.jax.flax import transformer as flax_transformer +from utils import assert_allclose + + +def _has_cudnn_frontend_python(): + try: + tex_attention._import_cudnn_for_score_mod() + except ImportError: + return False + return True + + +def _score_mod_causal(graph, score, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=score, input1=tensors["neg_inf"], mask=keep) + + +def _score_mod_causal_bprop(graph, dscore, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + row_index = graph.gen_index( + input=dscore, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=dscore, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=dscore, input1=tensors["zero"], mask=keep) + + +def _score_mod_post_scale_bias(graph, score, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + post_scale_bias = graph.sub( + a=row_index, + b=col_index, + compute_data_type=cudnn.data_type.FLOAT, + ) + post_scale_bias.set_data_type(cudnn.data_type.FLOAT) + return graph.add( + a=score, + b=post_scale_bias, + compute_data_type=cudnn.data_type.FLOAT, + ) + + +class _ScoreModSoftcap: + """cuDNN frontend score_mod implementing softcapping.""" + + def __init__(self): + self.before_tanh_activation = None + + def score_mod_graph_cache_key(self): + """Graph topology key for softcap score_mod.""" + return ("softcap",) + + def forward(self, graph, score, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + self.before_tanh_activation = graph.div( + a=score, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + tanh_out = graph.tanh(input=self.before_tanh_activation) + tanh_out.set_data_type(cudnn.data_type.FLOAT) + return graph.mul( + a=tanh_out, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def backward(self, graph, dscore, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + d_tanh_out = graph.mul( + a=dscore, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + d_tanh_out.set_data_type(cudnn.data_type.FLOAT) + d_before_tanh_activation = graph.tanh_backward( + loss=d_tanh_out, + input=self.before_tanh_activation, + compute_data_type=cudnn.data_type.FLOAT, + ) + d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + return graph.div( + a=d_before_tanh_activation, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + +def _reference_attention( + query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None +): + scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale + if causal: + q_pos = jnp.arange(query.shape[1])[:, None] + kv_pos = jnp.arange(key.shape[1])[None, :] + scores = jnp.where(q_pos >= kv_pos, scores, -1e9) + if post_scale_bias: + q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] + kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] + scores = scores + q_pos - kv_pos + if softcap is not None: + scores = softcap * jnp.tanh(scores / softcap) + probs = jax.nn.softmax(scores, axis=-1) + return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype) + + +@pytest.fixture(autouse=True, scope="module") +def init(): + """ + WAR for CUDA uninitialize error + """ + # Calling customcalls before jax may cause CUDA uninitialize error + _ = jnp.zeros(0) + yield + + +def _require_cudnn_frontend_score_mod(): + try: + cudnn = tex_attention._import_cudnn_for_score_mod() + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod") + version = tuple(int(part) for part in cudnn.backend_version_string().split(".")[:2]) + if version < (9, 6): + pytest.skip("cuDNN score_mod SDPA requires cuDNN frontend 9.6 or newer") + + +def _identity_score_mod(_graph, score, _tensors): + return score + + +def _install_fake_flax_fused_attn(monkeypatch): + captured = {} + + def fused_attn_kernel_check_should_not_run(*_args, **_kwargs): + raise AssertionError("score_mod path should not use the standard fused-attn kernel check") + + def fake_fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + softmax_type, + scaling_factor, + dropout_probability, + is_training, + max_segments_per_seq=1, + window_size=None, + context_parallel_strategy=None, + context_parallel_causal_load_balanced=False, + context_parallel_axis="", + context_checkpoint_name="context", + softmax_offset=None, + stripe_size=None, + score_mod=None, + score_mod_bprop=None, + score_mod_tensors=None, + score_mod_bprop_tensors=None, + ): + captured.update( + qkv=qkv, + bias=bias, + sequence_descriptor=sequence_descriptor, + seed=seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + softmax_type=softmax_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + context_checkpoint_name=context_checkpoint_name, + softmax_offset=softmax_offset, + stripe_size=stripe_size, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, + ) + return qkv[0] + + monkeypatch.setattr( + flax_transformer, + "is_fused_attn_kernel_available", + fused_attn_kernel_check_should_not_run, + ) + monkeypatch.setattr(flax_transformer, "fused_attn", fake_fused_attn) + return captured + + +def test_dot_product_attention_score_mod_requires_fused_attn(monkeypatch): + monkeypatch.setenv("NVTE_FUSED_ATTN", "0") + query = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + key = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + value = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + + dpa = flax_transformer.DotProductAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + attn_mask_type="no_mask", + qkv_layout="bshd_bshd_bshd", + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + + with pytest.raises(ValueError, match="score_mod requires fused attention"): + dpa.apply({}, query, key, value, deterministic=True) + + +def test_dot_product_attention_plumbs_score_mod_to_fused_attn(monkeypatch): + captured = _install_fake_flax_fused_attn(monkeypatch) + query = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + key = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + value = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + aux = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + dpa = flax_transformer.DotProductAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + attn_mask_type="no_mask", + qkv_layout="bshd_bshd_bshd", + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + out = dpa.apply({}, query, key, value, deterministic=True, score_mod_tensors={"aux": aux}) + + np.testing.assert_array_equal(out, query) + assert captured["score_mod"] is _identity_score_mod + assert captured["score_mod_tensors"]["aux"].shape == aux.shape + assert captured["score_mod_bprop"] is None + assert captured["score_mod_bprop_tensors"] is None + assert captured["attn_mask_type"] is AttnMaskType.NO_MASK + assert captured["attn_bias_type"] is AttnBiasType.NO_BIAS + assert captured["qkv_layout"] is QKVLayout.BSHD_BSHD_BSHD + assert captured["softmax_type"] is AttnSoftmaxType.VANILLA_SOFTMAX + + +def test_dot_product_attention_unpacks_packed_score_mod_to_separate_layout(monkeypatch): + captured = _install_fake_flax_fused_attn(monkeypatch) + qkv = jnp.ones((1, 8, 3, 1, 16), dtype=jnp.float16) + + dpa = flax_transformer.DotProductAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + attn_mask_type="no_mask", + qkv_layout="bs3hd", + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + out = dpa.apply({}, qkv, None, None, deterministic=True) + + assert out.shape == (1, 8, 1, 16) + assert len(captured["qkv"]) == 3 + assert captured["qkv"][0].shape == (1, 8, 1, 16) + assert captured["qkv_layout"] is QKVLayout.BSHD_BSHD_BSHD + assert captured["score_mod"] is _identity_score_mod + + +def test_multi_head_attention_plumbs_score_mod_to_dot_product_attention(monkeypatch): + captured = _install_fake_flax_fused_attn(monkeypatch) + + class FakeLayerNormDenseGeneral: + def __init__(self, *, features, return_layernorm_output=False, **_kwargs): + self.features = features + self.return_layernorm_output = return_layernorm_output + + def __call__(self, inputs): + features = self.features if isinstance(self.features, tuple) else (self.features,) + output = jnp.ones((*inputs.shape[:-1], *features), dtype=inputs.dtype) + ln_out = inputs if self.return_layernorm_output else None + return output, ln_out + + class FakeDenseGeneral: + def __init__(self, *, features, **_kwargs): + self.features = features + + def __call__(self, inputs): + features = self.features if isinstance(self.features, tuple) else (self.features,) + return jnp.ones((*inputs.shape[:-1], *features), dtype=inputs.dtype) + + monkeypatch.setattr(flax_transformer, "LayerNormDenseGeneral", FakeLayerNormDenseGeneral) + monkeypatch.setattr(flax_transformer, "DenseGeneral", FakeDenseGeneral) + + inputs = jnp.ones((1, 8, 16), dtype=jnp.float16) + aux = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + mha = flax_transformer.MultiHeadAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + input_layernorm=False, + attention_dropout=0.0, + attn_mask_type="no_mask", + fuse_qkv_params=True, + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + + variables = mha.init( + jax.random.key(0), + inputs, + inputs, + deterministic=True, + score_mod_tensors={"aux": aux}, + ) + out, ln_out = mha.apply( + variables, + inputs, + inputs, + deterministic=True, + score_mod_tensors={"aux": aux}, + ) + + assert out.shape == inputs.shape + assert ln_out is None + assert len(captured["qkv"]) == 3 + assert captured["qkv_layout"] is QKVLayout.BSHD_BSHD_BSHD + assert captured["score_mod"] is _identity_score_mod + assert captured["score_mod_tensors"]["aux"].shape == aux.shape + + +def test_fused_attn_score_mod_validation_rejects_masks_without_cudnn_frontend(): + q = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + k = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + v = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + + with pytest.raises(ValueError, match="mutually exclusive with attention masks"): + fused_attn( + (q, k, v), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.CAUSAL_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + 1.0, + 0.0, + True, + score_mod=_identity_score_mod, + ) + + +def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): + tensor = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + config, tensor_operands, bprop_tensor_operands = make_fused_attn_score_mod_config( + _identity_score_mod, + None, + {"tensor": tensor, "neg_inf": -1e9}, + None, + 0.125, + True, + ) + + assert config.score_mod_tensor_names == ("tensor",) + assert len(tensor_operands) == 1 + assert tensor_operands[0].shape == (1, 1, 1, 1) + assert len(bprop_tensor_operands) == 0 + assert len(config.score_mod_scalars) == 1 + assert config.score_mod_scalars[0].name == "neg_inf" + assert config.score_mod_scalars[0].dtype == "float32" + assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize + + +def test_fused_attn_score_mod_cudnn_frontend_version_check(monkeypatch): + class FakeCudnn: + __version__ = "1.22.0" + + monkeypatch.setattr( + tex_attention.transformer_engine_jax, + "get_cudnn_frontend_version", + lambda: 12200, + ) + assert tex_attention._check_cudnn_frontend_version_match(FakeCudnn) == 12200 + + monkeypatch.setattr( + tex_attention.transformer_engine_jax, + "get_cudnn_frontend_version", + lambda: 12100, + ) + with pytest.raises(RuntimeError, match="Python/C\\+\\+ version mismatch"): + tex_attention._check_cudnn_frontend_version_match(FakeCudnn) + + +def test_fused_attn_score_mod_config_stabilizes_bound_method_cache_keys(): + softcap_score_mod = _ScoreModSoftcap() + first_forward = softcap_score_mod.forward + second_forward = softcap_score_mod.forward + first_backward = softcap_score_mod.backward + second_backward = softcap_score_mod.backward + + assert first_forward is not second_forward + assert first_backward is not second_backward + + config_1, _, _ = make_fused_attn_score_mod_config( + first_forward, + first_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + config_2, _, _ = make_fused_attn_score_mod_config( + second_forward, + second_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + other_softcap_score_mod = _ScoreModSoftcap() + config_3, _, _ = make_fused_attn_score_mod_config( + other_softcap_score_mod.forward, + other_softcap_score_mod.backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + + assert config_1 == config_2 + assert hash(config_1) == hash(config_2) + assert config_1 == config_3 + + +def test_fused_attn_score_mod_config_leaves_unkeyed_bound_methods_uncached(): + class UnkeyedScoreMod: + def forward(self, _graph, score, _tensors): + return score + + score_mod = UnkeyedScoreMod() + config_1, _, _ = make_fused_attn_score_mod_config( + score_mod.forward, None, None, None, 0.125, True + ) + config_2, _, _ = make_fused_attn_score_mod_config( + score_mod.forward, None, None, None, 0.125, True + ) + + assert config_1 != config_2 + assert tex_attention._graph_cache_key("fwd", config_1, ()) is None + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_post_scale_bias_optional_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(0) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_post_scale_bias, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, post_scale_bias=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_causal_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(1) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_causal, + score_mod_bprop=_score_mod_causal_bprop, + score_mod_tensors={"neg_inf": -1e9}, + score_mod_bprop_tensors={"zero": 0.0}, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, causal=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_softcap_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(2) + q_key, k_key, v_key, d_out_key = jax.random.split(key, 4) + q = jax.random.normal(q_key, (1, 16, 2, 64), dtype=jnp.float16) + k = jax.random.normal(k_key, (1, 16, 2, 64), dtype=jnp.float16) + v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype(jnp.float16) + d_out = jax.random.normal(d_out_key, (1, 16, 2, 64), dtype=jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + softcap = 0.8 + softcap_score_mod = _ScoreModSoftcap() + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap}, + score_mod_bprop_tensors={"softcap": softcap}, + ) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, softcap=softcap) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) + + assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) + assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 8ff41ba3e7..654d3dbf73 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -450,6 +450,8 @@ class _SerializedScoreModGraph: scalar_values: np.ndarray +# cuDNN frontend tensor UIDs are arbitrary, but assigning stable values makes serialized +# graphs deterministic and simplifies debugging across the Python graph builder and C++ executor. _SCORE_MOD_UID_Q = 1 _SCORE_MOD_UID_K = 2 _SCORE_MOD_UID_V = 3 diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index a2e7920843..63b84288e8 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -8,7 +8,7 @@ from enum import Enum from math import sqrt import os -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union import warnings import jax @@ -305,6 +305,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT context_checkpoint_name: str = "context" softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None @nn.compact def __call__( @@ -317,6 +319,8 @@ def __call__( *, dropout_rng: Optional[PRNGKey] = None, deterministic: bool = False, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ) -> Array: seed = None @@ -340,6 +344,32 @@ def __call__( jnp.float32, ) + score_mod_requested = ( + self.score_mod is not None + or self.score_mod_bprop is not None + or score_mod_tensors is not None + or score_mod_bprop_tensors is not None + ) + fused_attn_kwargs = { + "attn_mask_type": self.attn_mask_type, + "attn_bias_type": self.attn_bias_type, + "softmax_type": self.softmax_type, + "scaling_factor": scale_factor, + "dropout_probability": self.attention_dropout, + "is_training": not deterministic, + "window_size": self.window_size, + "max_segments_per_seq": self.max_segments_per_seq, + "context_parallel_causal_load_balanced": self.context_parallel_causal_load_balanced, + "context_parallel_axis": self.context_parallel_axis, + "context_parallel_strategy": self.context_parallel_strategy, + "context_checkpoint_name": self.context_checkpoint_name, + "softmax_offset": softmax_offset, + "score_mod": self.score_mod, + "score_mod_bprop": self.score_mod_bprop, + "score_mod_tensors": score_mod_tensors, + "score_mod_bprop_tensors": score_mod_bprop_tensors, + } + if self.qkv_layout.is_qkvpacked(): """qkvpacked format, treat query: qkvpacked tensor, shape = [..., 3, h, d] @@ -349,25 +379,23 @@ def __call__( qkv_packed = query if self.transpose_batch_sequence: qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) + if score_mod_requested: + query, key, value = jnp.split(qkv_packed, [1, 2], axis=-3) + query, key, value = map( + functools.partial(jnp.squeeze, axis=-3), [query, key, value] + ) + qkv_layout = self.qkv_layout.to_separate() + qkv_args = (query, key, value) + else: + qkv_layout = self.qkv_layout + qkv_args = (qkv_packed,) x = fused_attn( - (qkv_packed,), + qkv_args, bias, sequence_descriptor, seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - qkv_layout=self.qkv_layout, - softmax_type=self.softmax_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic, - window_size=self.window_size, - max_segments_per_seq=self.max_segments_per_seq, - context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, - context_parallel_axis=self.context_parallel_axis, - context_parallel_strategy=self.context_parallel_strategy, - context_checkpoint_name=self.context_checkpoint_name, - softmax_offset=softmax_offset, + qkv_layout=qkv_layout, + **fused_attn_kwargs, ) elif self.qkv_layout.is_kvpacked(): """kvpacked format, treat @@ -379,25 +407,21 @@ def __call__( if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) + if score_mod_requested: + key, value = jnp.split(kv_packed, [1], axis=-3) + key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) + qkv_layout = self.qkv_layout.to_separate() + qkv_args = (query, key, value) + else: + qkv_layout = self.qkv_layout + qkv_args = (query, kv_packed) x = fused_attn( - (query, kv_packed), + qkv_args, bias, sequence_descriptor, seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - qkv_layout=self.qkv_layout, - softmax_type=self.softmax_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic, - window_size=self.window_size, - max_segments_per_seq=self.max_segments_per_seq, - context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, - context_parallel_axis=self.context_parallel_axis, - context_parallel_strategy=self.context_parallel_strategy, - context_checkpoint_name=self.context_checkpoint_name, - softmax_offset=softmax_offset, + qkv_layout=qkv_layout, + **fused_attn_kwargs, ) elif self.qkv_layout.is_separate(): if self.transpose_batch_sequence: @@ -409,20 +433,8 @@ def __call__( bias, sequence_descriptor, seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, qkv_layout=self.qkv_layout, - softmax_type=self.softmax_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic, - window_size=self.window_size, - max_segments_per_seq=self.max_segments_per_seq, - context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, - context_parallel_axis=self.context_parallel_axis, - context_parallel_strategy=self.context_parallel_strategy, - context_checkpoint_name=self.context_checkpoint_name, - softmax_offset=softmax_offset, + **fused_attn_kwargs, ) else: raise ValueError(f"Unsupported {self.qkv_layout=}.") @@ -602,6 +614,16 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention (``'zero sink'`` and ``'learnable sink'``). + score_mod: Optional[Callable], default = None + Experimental cuDNN frontend score modification callback. When set, fused attention is + required and the same score_mod limitations as :func:`fused_attn` apply. + score_mod_bprop: Optional[Callable], default = None + Optional score modification backward callback. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod``. This is a call-time + argument to keep tensor operands as normal JAX inputs. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod_bprop``. Optimization parameters ----------------------- @@ -628,6 +650,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods context_parallel_strategy: str = "DEFAULT" context_checkpoint_name: str = "context" softmax_type: str = "vanilla" + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None def __post_init__(self): # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12 @@ -669,6 +693,8 @@ def __call__( *, deterministic: bool = False, mask: Optional[Union[SequenceDescriptor, Array]] = None, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ) -> Array: """ Parameters @@ -689,6 +715,10 @@ def __call__( Below parameters are keyword only deterministic: bool, default = False Disable dropout layers if set to True. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod``. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod_bprop``. Returns ------- @@ -715,6 +745,12 @@ def __call__( qkv_layout = QKVLayout[self.qkv_layout.upper()] softmax_type = AttnSoftmaxType.from_str(self.softmax_type) del self.attn_bias_type, self.attn_mask_type, self.qkv_layout + score_mod_requested = ( + self.score_mod is not None + or self.score_mod_bprop is not None + or score_mod_tensors is not None + or score_mod_bprop_tensors is not None + ) if attn_bias_type == AttnBiasType.NO_BIAS: assert ( @@ -760,25 +796,30 @@ def __call__( head_dim_qk = self.head_dim head_dim_v = self.head_dim - has_fused_attn_kernel = is_fused_attn_kernel_available( - # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. - not deterministic, - input_dtype, - # self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient - input_dtype, - qkv_layout, - attn_bias_type, - attn_mask_type, - softmax_type, - self.attention_dropout, - self.num_attention_heads, - self.num_gqa_groups, - seqlen_q, - seqlen_kv, - head_dim_qk, - head_dim_v, - self.window_size, - ) + if score_mod_requested: + if not enable_fused_attn: + raise ValueError("score_mod requires fused attention, but NVTE_FUSED_ATTN=0.") + has_fused_attn_kernel = True + else: + has_fused_attn_kernel = is_fused_attn_kernel_available( + # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. + not deterministic, + input_dtype, + # self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient + input_dtype, + qkv_layout, + attn_bias_type, + attn_mask_type, + softmax_type, + self.attention_dropout, + self.num_attention_heads, + self.num_gqa_groups, + seqlen_q, + seqlen_kv, + head_dim_qk, + head_dim_v, + self.window_size, + ) use_fused_attn = enable_fused_attn and has_fused_attn_kernel @@ -872,6 +913,8 @@ def __call__( context_parallel_strategy=context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, softmax_type=softmax_type, + score_mod=self.score_mod, + score_mod_bprop=self.score_mod_bprop, )( query, key, @@ -880,6 +923,8 @@ def __call__( bias, dropout_rng=dropout_rng, deterministic=deterministic, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, ) assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}" return x @@ -1178,6 +1223,16 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention (``'zero sink'`` and ``'learnable sink'``). + score_mod: Optional[Callable], default = None + Experimental cuDNN frontend score modification callback. When set, fused attention is + required and the same score_mod limitations as :func:`fused_attn` apply. + score_mod_bprop: Optional[Callable], default = None + Optional score modification backward callback. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod``. This is a call-time + argument to keep tensor operands as normal JAX inputs. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod_bprop``. """ head_dim: int @@ -1210,6 +1265,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods float32_logits: bool = False window_size: Optional[Tuple[int, int]] = None softmax_type: str = "vanilla" + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None # Deprecated parameters num_heads: Optional[int] = None @@ -1278,6 +1335,8 @@ def __call__( *, decode: bool = False, deterministic: bool = False, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ) -> Array: """ MultiHeadAttention Layer: @@ -1300,6 +1359,10 @@ def __call__( Indicate whether to prepare and use an autoregressive cache. deterministic: bool, default = False Disable dropout layers if set to True. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod``. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod_bprop``. Returns ------- @@ -1631,7 +1694,16 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): transpose_batch_sequence=self.transpose_batch_sequence, window_size=self.window_size, softmax_type=self.softmax_type, - )(*dpa_args, mask, bias, deterministic=deterministic) + score_mod=self.score_mod, + score_mod_bprop=self.score_mod_bprop, + )( + *dpa_args, + mask, + bias, + deterministic=deterministic, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, + ) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES) @@ -1939,6 +2011,17 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention (``'zero sink'`` and ``'learnable sink'``). Only supported for fused attention backend. + score_mod: Optional[Callable], default = None + Experimental cuDNN frontend score modification callback for the self-attention block. + When set, fused attention is required and the same score_mod limitations as + :func:`fused_attn` apply. + score_mod_bprop: Optional[Callable], default = None + Optional score modification backward callback for the self-attention block. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod``. This is a call-time + argument to keep tensor operands as normal JAX inputs. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod_bprop``. Optimization parameters ----------------------- @@ -2005,6 +2088,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods scaled_query_init: bool = True window_size: Optional[Tuple[int, int]] = None softmax_type: str = "vanilla" + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None def __post_init__(self): if self.mha_kernel_init is None: @@ -2029,6 +2114,8 @@ def __call__( deterministic: bool = False, decode: bool = False, max_decode_length: bool = None, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ): """ Transformer Layer: attention block and a feedforward network (MLP) @@ -2057,6 +2144,10 @@ def __call__( The maximum length to generate relative embedding biases when :attr:`layer_type=TransformerLayerType.DECODER` and :attr:`enable_relative_embedding=True`. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to the self-attention ``score_mod``. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to the self-attention ``score_mod_bprop``. Returns ------- @@ -2168,7 +2259,18 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): name=mha_name, window_size=self.window_size, softmax_type=self.softmax_type, - )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) + score_mod=self.score_mod, + score_mod_bprop=self.score_mod_bprop, + )( + inputs, + inputs, + attention_mask, + attn_bias, + deterministic=deterministic, + decode=decode, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, + ) def hidden_dropout(x, deterministic): assert isinstance(