diff --git a/build_tools/jax.py b/build_tools/jax.py index a7b200f915..f8cf695c16 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -83,6 +83,20 @@ def setup_jax_extension( # Header files include_dirs = get_cuda_include_dirs() + cudnn_frontend_include_dir = None + for base_path in (Path(common_header_files), *Path(common_header_files).parents): + candidate = base_path / "3rdparty" / "cudnn-frontend" / "include" + if candidate.exists(): + cudnn_frontend_include_dir = candidate + break + if cudnn_frontend_include_dir is None: + for base_path in Path(__file__).resolve().parents: + candidate = base_path / "3rdparty" / "cudnn-frontend" / "include" + if candidate.exists(): + cudnn_frontend_include_dir = candidate + break + if cudnn_frontend_include_dir is not None: + include_dirs.append(cudnn_frontend_include_dir) include_dirs.extend( [ common_header_files, diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 50c5de1db7..dfcfab3312 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -6,14 +6,28 @@ import pytest import jax import jax.numpy as jnp +import numpy as np from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec from distributed_test_base import ( generate_configs, generate_context_parallel_configs_for_attn, generate_collectives_count, ) -from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat -from utils import pytest_parametrize_wrapper +from test_fused_attn import ( + FusedAttnRunner, + BiasShape, + SeqDescFormat, + customcall_fused_dpa, +) +from test_fused_attn_score_mod import ( + _ScoreModSoftcap, + _has_cudnn_frontend_python, + _reference_attention, + _require_cudnn_frontend_score_mod, +) +from utils import assert_allclose, pytest_parametrize_wrapper +from transformer_engine.jax import autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, AttnBiasType, @@ -272,6 +286,116 @@ def test_cross_attn( runner.test_backward() +DISTRIBUTED_SCORE_MOD_DATA_SHAPES = { + "L0": [], + "L1": [(4, 16, 4, 64)], +} + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +class TestDistributedScoreModSelfAttn: + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES) + @pytest.mark.parametrize("dtype", DTYPES) + def test_softcap_score_mod_with_aux_params_backward( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + dtype, + ): + _require_cudnn_frontend_score_mod() + batch, seqlen, num_heads, head_dim = data_shape + dp_axis = mesh_resource.dp_resource + tp_axis = mesh_resource.tpsp_resource + + if dp_axis is not None: + dp_size = mesh_shape[mesh_axes.index(dp_axis)] + if batch % dp_size != 0: + pytest.skip(f"{batch=} must be divisible by {dp_size=}") + if tp_axis is not None: + tp_size = mesh_shape[mesh_axes.index(tp_axis)] + if num_heads % tp_size != 0: + pytest.skip(f"{num_heads=} must be divisible by {tp_size=}") + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + qkv_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, tp_axis, None)) + + key = random.PRNGKey(2025) + q_key, k_key, v_key, dout_key = random.split(key, 4) + query = (0.125 * random.normal(q_key, data_shape, dtype=dtype)).astype(dtype) + key_tensor = (0.125 * random.normal(k_key, data_shape, dtype=dtype)).astype(dtype) + value = (0.125 * random.normal(v_key, data_shape, dtype=dtype)).astype(dtype) + doutput = random.normal(dout_key, data_shape, dtype=dtype) + + scaling_factor = head_dim**-0.5 + softcap = 0.8 + softcap_score_mod = _ScoreModSoftcap() + + def score_mod_loss(q, k, v, dout): + out = customcall_fused_dpa( + q, + k, + v, + None, + None, + None, + None, + attn_bias_type=AttnBiasType.NO_BIAS, + attn_mask_type=AttnMaskType.NO_MASK, + qkv_layout=QKVLayout.BSHD_BSHD_BSHD, + softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX, + scaling_factor=scaling_factor, + dropout_probability=0.0, + is_training=True, + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap}, + score_mod_bprop_tensors={"softcap": softcap}, + ) + loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) + return loss, out + + def ref_loss(q, k, v, dout): + out = _reference_attention(q, k, v, scaling_factor, softcap=softcap) + loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) + return loss, out + + jitted_score_mod = jax.jit( + jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True), + in_shardings=( + qkv_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ), + out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)), + ) + jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)) + + sharded_args = ( + jax.device_put(query, qkv_sharding), + jax.device_put(key_tensor, qkv_sharding), + jax.device_put(value, qkv_sharding), + jax.device_put(doutput, qkv_sharding), + ) + with mesh, autocast(mesh_resource=mesh_resource): + (score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args) + (ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput) + + assert score_mod_out.sharding == qkv_sharding + for grad in score_mod_grads: + assert grad.sharding == qkv_sharding + + assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) + assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) + + DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [ pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), diff --git a/tests/jax/test_fused_attn_score_mod.py b/tests/jax/test_fused_attn_score_mod.py new file mode 100644 index 0000000000..c9debde05e --- /dev/null +++ b/tests/jax/test_fused_attn_score_mod.py @@ -0,0 +1,666 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Tests for cuDNN frontend score_mod fused attention.""" +from math import sqrt + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax import value_and_grad + +import transformer_engine.jax.cpp_extensions.attention as tex_attention +from transformer_engine.jax.attention import ( + AttnBiasType, + AttnMaskType, + AttnSoftmaxType, + QKVLayout, + fused_attn, +) +from transformer_engine.jax.cpp_extensions import make_fused_attn_score_mod_config +from transformer_engine.jax.flax import transformer as flax_transformer +from utils import assert_allclose + + +def _has_cudnn_frontend_python(): + try: + tex_attention._import_cudnn_for_score_mod() + except ImportError: + return False + return True + + +def _score_mod_causal(graph, score, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=score, input1=tensors["neg_inf"], mask=keep) + + +def _score_mod_causal_bprop(graph, dscore, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + row_index = graph.gen_index( + input=dscore, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=dscore, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=dscore, input1=tensors["zero"], mask=keep) + + +def _score_mod_post_scale_bias(graph, score, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + post_scale_bias = graph.sub( + a=row_index, + b=col_index, + compute_data_type=cudnn.data_type.FLOAT, + ) + post_scale_bias.set_data_type(cudnn.data_type.FLOAT) + return graph.add( + a=score, + b=post_scale_bias, + compute_data_type=cudnn.data_type.FLOAT, + ) + + +class _ScoreModSoftcap: + """cuDNN frontend score_mod implementing softcapping.""" + + def __init__(self): + self.before_tanh_activation = None + + def score_mod_graph_cache_key(self): + """Graph topology key for softcap score_mod.""" + return ("softcap",) + + def forward(self, graph, score, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + self.before_tanh_activation = graph.div( + a=score, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + tanh_out = graph.tanh(input=self.before_tanh_activation) + tanh_out.set_data_type(cudnn.data_type.FLOAT) + return graph.mul( + a=tanh_out, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def backward(self, graph, dscore, tensors): + cudnn = tex_attention._import_cudnn_for_score_mod() + + d_tanh_out = graph.mul( + a=dscore, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + d_tanh_out.set_data_type(cudnn.data_type.FLOAT) + d_before_tanh_activation = graph.tanh_backward( + loss=d_tanh_out, + input=self.before_tanh_activation, + compute_data_type=cudnn.data_type.FLOAT, + ) + d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + return graph.div( + a=d_before_tanh_activation, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + +def _reference_attention( + query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None +): + scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale + if causal: + q_pos = jnp.arange(query.shape[1])[:, None] + kv_pos = jnp.arange(key.shape[1])[None, :] + scores = jnp.where(q_pos >= kv_pos, scores, -1e9) + if post_scale_bias: + q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] + kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] + scores = scores + q_pos - kv_pos + if softcap is not None: + scores = softcap * jnp.tanh(scores / softcap) + probs = jax.nn.softmax(scores, axis=-1) + return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype) + + +@pytest.fixture(autouse=True, scope="module") +def init(): + """ + WAR for CUDA uninitialize error + """ + # Calling customcalls before jax may cause CUDA uninitialize error + _ = jnp.zeros(0) + yield + + +def _require_cudnn_frontend_score_mod(): + try: + cudnn = tex_attention._import_cudnn_for_score_mod() + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod") + version = tuple(int(part) for part in cudnn.backend_version_string().split(".")[:2]) + if version < (9, 6): + pytest.skip("cuDNN score_mod SDPA requires cuDNN frontend 9.6 or newer") + + +def _identity_score_mod(_graph, score, _tensors): + return score + + +def _install_fake_flax_fused_attn(monkeypatch): + captured = {} + + def fused_attn_kernel_check_should_not_run(*_args, **_kwargs): + raise AssertionError("score_mod path should not use the standard fused-attn kernel check") + + def fake_fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + *, + attn_bias_type, + attn_mask_type, + qkv_layout, + softmax_type, + scaling_factor, + dropout_probability, + is_training, + max_segments_per_seq=1, + window_size=None, + context_parallel_strategy=None, + context_parallel_causal_load_balanced=False, + context_parallel_axis="", + context_checkpoint_name="context", + softmax_offset=None, + stripe_size=None, + score_mod=None, + score_mod_bprop=None, + score_mod_tensors=None, + score_mod_bprop_tensors=None, + ): + captured.update( + qkv=qkv, + bias=bias, + sequence_descriptor=sequence_descriptor, + seed=seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + softmax_type=softmax_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + context_checkpoint_name=context_checkpoint_name, + softmax_offset=softmax_offset, + stripe_size=stripe_size, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, + ) + return qkv[0] + + monkeypatch.setattr( + flax_transformer, + "is_fused_attn_kernel_available", + fused_attn_kernel_check_should_not_run, + ) + monkeypatch.setattr(flax_transformer, "fused_attn", fake_fused_attn) + return captured + + +def test_dot_product_attention_score_mod_requires_fused_attn(monkeypatch): + monkeypatch.setenv("NVTE_FUSED_ATTN", "0") + query = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + key = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + value = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + + dpa = flax_transformer.DotProductAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + attn_mask_type="no_mask", + qkv_layout="bshd_bshd_bshd", + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + + with pytest.raises(ValueError, match="score_mod requires fused attention"): + dpa.apply({}, query, key, value, deterministic=True) + + +def test_dot_product_attention_plumbs_score_mod_to_fused_attn(monkeypatch): + captured = _install_fake_flax_fused_attn(monkeypatch) + query = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + key = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + value = jnp.ones((1, 8, 1, 16), dtype=jnp.float16) + aux = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + dpa = flax_transformer.DotProductAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + attn_mask_type="no_mask", + qkv_layout="bshd_bshd_bshd", + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + out = dpa.apply({}, query, key, value, deterministic=True, score_mod_tensors={"aux": aux}) + + np.testing.assert_array_equal(out, query) + assert captured["score_mod"] is _identity_score_mod + assert captured["score_mod_tensors"]["aux"].shape == aux.shape + assert captured["score_mod_bprop"] is None + assert captured["score_mod_bprop_tensors"] is None + assert captured["attn_mask_type"] is AttnMaskType.NO_MASK + assert captured["attn_bias_type"] is AttnBiasType.NO_BIAS + assert captured["qkv_layout"] is QKVLayout.BSHD_BSHD_BSHD + assert captured["softmax_type"] is AttnSoftmaxType.VANILLA_SOFTMAX + + +def test_dot_product_attention_unpacks_packed_score_mod_to_separate_layout(monkeypatch): + captured = _install_fake_flax_fused_attn(monkeypatch) + qkv = jnp.ones((1, 8, 3, 1, 16), dtype=jnp.float16) + + dpa = flax_transformer.DotProductAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + attn_mask_type="no_mask", + qkv_layout="bs3hd", + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + out = dpa.apply({}, qkv, None, None, deterministic=True) + + assert out.shape == (1, 8, 1, 16) + assert len(captured["qkv"]) == 3 + assert captured["qkv"][0].shape == (1, 8, 1, 16) + assert captured["qkv_layout"] is QKVLayout.BSHD_BSHD_BSHD + assert captured["score_mod"] is _identity_score_mod + + +def test_multi_head_attention_plumbs_score_mod_to_dot_product_attention(monkeypatch): + captured = _install_fake_flax_fused_attn(monkeypatch) + + class FakeLayerNormDenseGeneral: + def __init__(self, *, features, return_layernorm_output=False, **_kwargs): + self.features = features + self.return_layernorm_output = return_layernorm_output + + def __call__(self, inputs): + features = self.features if isinstance(self.features, tuple) else (self.features,) + output = jnp.ones((*inputs.shape[:-1], *features), dtype=inputs.dtype) + ln_out = inputs if self.return_layernorm_output else None + return output, ln_out + + class FakeDenseGeneral: + def __init__(self, *, features, **_kwargs): + self.features = features + + def __call__(self, inputs): + features = self.features if isinstance(self.features, tuple) else (self.features,) + return jnp.ones((*inputs.shape[:-1], *features), dtype=inputs.dtype) + + monkeypatch.setattr(flax_transformer, "LayerNormDenseGeneral", FakeLayerNormDenseGeneral) + monkeypatch.setattr(flax_transformer, "DenseGeneral", FakeDenseGeneral) + + inputs = jnp.ones((1, 8, 16), dtype=jnp.float16) + aux = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + mha = flax_transformer.MultiHeadAttention( + head_dim=16, + num_attention_heads=1, + num_gqa_groups=1, + input_layernorm=False, + attention_dropout=0.0, + attn_mask_type="no_mask", + fuse_qkv_params=True, + transpose_batch_sequence=False, + score_mod=_identity_score_mod, + ) + + variables = mha.init( + jax.random.key(0), + inputs, + inputs, + deterministic=True, + score_mod_tensors={"aux": aux}, + ) + out, ln_out = mha.apply( + variables, + inputs, + inputs, + deterministic=True, + score_mod_tensors={"aux": aux}, + ) + + assert out.shape == inputs.shape + assert ln_out is None + assert len(captured["qkv"]) == 3 + assert captured["qkv_layout"] is QKVLayout.BSHD_BSHD_BSHD + assert captured["score_mod"] is _identity_score_mod + assert captured["score_mod_tensors"]["aux"].shape == aux.shape + + +def test_fused_attn_score_mod_validation_rejects_masks_without_cudnn_frontend(): + q = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + k = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + v = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + + with pytest.raises(ValueError, match="mutually exclusive with attention masks"): + fused_attn( + (q, k, v), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.CAUSAL_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + 1.0, + 0.0, + True, + score_mod=_identity_score_mod, + ) + + +def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): + tensor = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + config, tensor_operands, bprop_tensor_operands = make_fused_attn_score_mod_config( + _identity_score_mod, + None, + {"tensor": tensor, "neg_inf": -1e9}, + None, + 0.125, + True, + ) + + assert config.score_mod_tensor_names == ("tensor",) + assert len(tensor_operands) == 1 + assert tensor_operands[0].shape == (1, 1, 1, 1) + assert len(bprop_tensor_operands) == 0 + assert len(config.score_mod_scalars) == 1 + assert config.score_mod_scalars[0].name == "neg_inf" + assert config.score_mod_scalars[0].dtype == "float32" + assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize + + +def test_fused_attn_score_mod_cudnn_frontend_version_check(monkeypatch): + class FakeCudnn: + __version__ = "1.22.0" + + monkeypatch.setattr( + tex_attention.transformer_engine_jax, + "get_cudnn_frontend_version", + lambda: 12200, + ) + assert tex_attention._check_cudnn_frontend_version_match(FakeCudnn) == 12200 + + monkeypatch.setattr( + tex_attention.transformer_engine_jax, + "get_cudnn_frontend_version", + lambda: 12100, + ) + with pytest.raises(RuntimeError, match="Python/C\\+\\+ version mismatch"): + tex_attention._check_cudnn_frontend_version_match(FakeCudnn) + + +def test_fused_attn_score_mod_config_stabilizes_bound_method_cache_keys(): + softcap_score_mod = _ScoreModSoftcap() + first_forward = softcap_score_mod.forward + second_forward = softcap_score_mod.forward + first_backward = softcap_score_mod.backward + second_backward = softcap_score_mod.backward + + assert first_forward is not second_forward + assert first_backward is not second_backward + + config_1, _, _ = make_fused_attn_score_mod_config( + first_forward, + first_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + config_2, _, _ = make_fused_attn_score_mod_config( + second_forward, + second_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + other_softcap_score_mod = _ScoreModSoftcap() + config_3, _, _ = make_fused_attn_score_mod_config( + other_softcap_score_mod.forward, + other_softcap_score_mod.backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + + assert config_1 == config_2 + assert hash(config_1) == hash(config_2) + assert config_1 == config_3 + + +def test_fused_attn_score_mod_config_leaves_unkeyed_bound_methods_uncached(): + class UnkeyedScoreMod: + def forward(self, _graph, score, _tensors): + return score + + score_mod = UnkeyedScoreMod() + config_1, _, _ = make_fused_attn_score_mod_config( + score_mod.forward, None, None, None, 0.125, True + ) + config_2, _, _ = make_fused_attn_score_mod_config( + score_mod.forward, None, None, None, 0.125, True + ) + + assert config_1 != config_2 + assert tex_attention._graph_cache_key("fwd", config_1, ()) is None + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_post_scale_bias_optional_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(0) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_post_scale_bias, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, post_scale_bias=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_causal_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(1) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_causal, + score_mod_bprop=_score_mod_causal_bprop, + score_mod_tensors={"neg_inf": -1e9}, + score_mod_bprop_tensors={"zero": 0.0}, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, causal=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_softcap_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(2) + q_key, k_key, v_key, d_out_key = jax.random.split(key, 4) + q = jax.random.normal(q_key, (1, 16, 2, 64), dtype=jnp.float16) + k = jax.random.normal(k_key, (1, 16, 2, 64), dtype=jnp.float16) + v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype(jnp.float16) + d_out = jax.random.normal(d_out_key, (1, 16, 2, 64), dtype=jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + softcap = 0.8 + softcap_score_mod = _ScoreModSoftcap() + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap}, + score_mod_bprop_tensors={"softcap": softcap}, + ) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, softcap=softcap) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) + + assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) + assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index f54a043fd2..dbee272046 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import Enum from functools import partial -from typing import Optional, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Tuple, Union import warnings from jax.ad_checkpoint import checkpoint_name @@ -1391,10 +1391,127 @@ def _fused_attn_bwd_rule( _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +def _fused_attn_score_mod( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + score_mod_tensors: Tuple[jnp.ndarray, ...], + score_mod_bprop_tensors: Tuple[jnp.ndarray, ...], + config, + context_checkpoint_name: str, +): + output, _ = _fused_attn_score_mod_fwd_rule( + qkv, + score_mod_tensors, + score_mod_bprop_tensors, + config, + context_checkpoint_name, + ) + return output + + +def _fused_attn_score_mod_fwd_rule( + qkv, + score_mod_tensors, + score_mod_bprop_tensors, + config, + context_checkpoint_name, +): + output, softmax_stats = tex.fused_attn_score_mod_fwd(qkv, score_mod_tensors, config) + output = checkpoint_name(output, context_checkpoint_name) + softmax_stats = checkpoint_name(softmax_stats, context_checkpoint_name) + return output, (qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats) + + +def _fused_attn_score_mod_bwd_rule(config, context_checkpoint_name, ctx, dz): + del context_checkpoint_name + qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats = ctx + grad_qkv = tex.fused_attn_score_mod_bwd( + qkv, + output, + dz, + softmax_stats, + score_mod_tensors, + score_mod_bprop_tensors, + config, + ) + return ( + grad_qkv, + tuple(None for _ in score_mod_tensors), + tuple(None for _ in score_mod_bprop_tensors), + ) + + +_fused_attn_score_mod.defvjp(_fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule) + + +def _validate_fused_attn_score_mod( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: Optional[SequenceDescriptor], + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + softmax_type: AttnSoftmaxType, + dropout_probability: float, + max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]], + context_parallel_strategy: CPStrategy, + context_parallel_causal_load_balanced: bool, + context_parallel_axis: str, + softmax_offset: Optional[jnp.ndarray], + stripe_size: int | None, +): + """Validate arguments for the cuDNN frontend score_mod path.""" + header = "score_mod fused_attn" + if qkv_layout is not QKVLayout.BSHD_BSHD_BSHD: + raise ValueError(f"{header} currently only supports QKVLayout.BSHD_BSHD_BSHD.") + if len(qkv) != 3: + raise ValueError(f"{header} requires separate query, key and value tensors.") + if any(tensor.ndim != 4 for tensor in qkv): + raise ValueError(f"{header} requires rank-4 BSHD query/key/value tensors.") + q, k, v = qkv + if q.dtype != k.dtype or q.dtype != v.dtype: + raise ValueError(f"{header} requires query, key and value to have the same dtype.") + if q.dtype not in (jnp.float16, jnp.bfloat16): + raise ValueError(f"{header} only supports FP16/BF16 query, key and value tensors.") + if q.shape[0] != k.shape[0] or q.shape[0] != v.shape[0]: + raise ValueError(f"{header} requires matching batch dimensions.") + if k.shape[1] != v.shape[1]: + raise ValueError(f"{header} requires key and value sequence lengths to match.") + if k.shape[2] != v.shape[2]: + raise ValueError(f"{header} requires key and value head counts to match.") + if q.shape[3] != k.shape[3]: + raise ValueError(f"{header} requires query/key head dimensions to match.") + + if bias is not None or attn_bias_type is not AttnBiasType.NO_BIAS: + raise ValueError(f"{header} is mutually exclusive with attention bias.") + if sequence_descriptor is not None: + raise ValueError(f"{header} is mutually exclusive with padding/sequence descriptors.") + if seed is not None: + raise ValueError(f"{header} is mutually exclusive with dropout seed.") + if attn_mask_type is not AttnMaskType.NO_MASK: + raise ValueError(f"{header} is mutually exclusive with attention masks.") + if softmax_type is not AttnSoftmaxType.VANILLA_SOFTMAX or softmax_offset is not None: + raise ValueError(f"{header} only supports vanilla softmax without softmax_offset.") + if dropout_probability != 0.0: + raise ValueError(f"{header} is mutually exclusive with dropout.") + if max_segments_per_seq != 1: + raise ValueError(f"{header} is mutually exclusive with packed/ragged sequence metadata.") + if window_size not in (None, (-1, -1)): + raise ValueError(f"{header} is mutually exclusive with sliding-window attention.") + if context_parallel_strategy is not CPStrategy.DEFAULT: + raise ValueError(f"{header} is mutually exclusive with context parallelism.") + if context_parallel_causal_load_balanced or context_parallel_axis: + raise ValueError(f"{header} is mutually exclusive with context parallelism.") + if stripe_size is not None: + raise ValueError(f"{header} is mutually exclusive with striped context parallelism.") + + def fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - sequence_descriptor: SequenceDescriptor, + sequence_descriptor: Optional[SequenceDescriptor], seed: Optional[jnp.ndarray], attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, @@ -1411,6 +1528,10 @@ def fused_attn( context_checkpoint_name: str = "context", softmax_offset: Optional[jnp.ndarray] = None, stripe_size: int | None = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ): """ Perform cuDNN fused attention. @@ -1453,6 +1574,20 @@ def fused_attn( Currently, a stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas a stripe_size=1 is supported for both, CP + THD + Striped + AG and CP + THD + Striped + P2P(Ring) None indicates no striping strategy + score_mod (Optional[Callable]): Experimental cuDNN frontend score modification callback. + The callback is called as `score_mod(graph, score, tensors)` while building a + cuDNN frontend graph. When provided, this path only supports BSHD_BSHD_BSHD + layout and is mutually exclusive with masks, padding, bias, dropout, context + parallelism, sliding windows, and non-vanilla softmax. + score_mod_bprop (Optional[Callable]): Optional score modification backward callback, + called as `score_mod_bprop(graph, dscore, tensors)`. If omitted, cuDNN uses the + default backward behavior for the forward score modification graph. + score_mod_tensors (Optional[Mapping[str, Any]]): Additional tensors or Python/NumPy + scalars made available to `score_mod` through its `tensors` dictionary. Scalars + are represented as cuDNN pass-by-value tensors. Tensor entries are treated as + non-differentiable auxiliary inputs. + score_mod_bprop_tensors (Optional[Mapping[str, Any]]): Additional tensors or + Python/NumPy scalars made available to `score_mod_bprop`. Returns: (jnp.ndarray): The output tensor from the fused attention. @@ -1485,6 +1620,48 @@ def fused_attn( AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, QKVLayout.T3HD, 0.125, 0, True, 3) """ + if score_mod is None: + if score_mod_bprop is not None: + raise ValueError("score_mod_bprop requires score_mod to be provided.") + if score_mod_tensors is not None: + raise ValueError("score_mod_tensors requires score_mod to be provided.") + if score_mod_bprop_tensors is not None: + raise ValueError("score_mod_bprop_tensors requires score_mod to be provided.") + else: + _validate_fused_attn_score_mod( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type, + attn_mask_type, + qkv_layout, + softmax_type, + dropout_probability, + max_segments_per_seq, + window_size, + context_parallel_strategy, + context_parallel_causal_load_balanced, + context_parallel_axis, + softmax_offset, + stripe_size, + ) + config, tensor_operands, bprop_tensor_operands = tex.make_fused_attn_score_mod_config( + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + scaling_factor, + is_training, + ) + return _fused_attn_score_mod( + qkv, + tensor_operands, + bprop_tensor_operands, + config, + context_checkpoint_name, + ) + if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray): warnings.warn( "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. " diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 489bfde997..654d3dbf73 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2,15 +2,21 @@ # # See LICENSE for license information. """JAX/TE custom ops for attention""" +import hashlib +import importlib +import inspect import operator import os +from pathlib import Path +import sys import warnings from dataclasses import dataclass, replace from functools import partial, reduce -from typing import Optional, Tuple +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple import jax import jax.numpy as jnp +import numpy as np from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding from jax.experimental.custom_partitioning import SdyShardingRule @@ -54,8 +60,15 @@ "FusedAttnHelper", "fused_attn_fwd", "fused_attn_bwd", + "make_fused_attn_score_mod_config", + "fused_attn_score_mod_fwd", + "fused_attn_score_mod_bwd", ] +_CUDNN_FRONTEND_PYTHON_PATH = ( + Path(__file__).resolve().parents[3] / "3rdparty" / "cudnn-frontend" / "python" +) + @partial( jax.tree_util.register_dataclass, @@ -267,6 +280,811 @@ def check_seed(self, seed, dropout_probability, is_training): return seed +@dataclass(frozen=True) +class _ScoreModScalarSpec: + """Static pass-by-value scalar used when building a cuDNN frontend graph.""" + + name: str + dtype: str + value: bytes + dim: Tuple[int, ...] = (1, 1, 1, 1) + stride: Tuple[int, ...] = (1, 1, 1, 1) + + +class _UncacheableScoreModKey: + """Unique static key for callbacks that must not share compiled score_mod graphs.""" + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self is other + + +def _score_mod_key_is_uncacheable(key: Any) -> bool: + return isinstance(key, _UncacheableScoreModKey) + + +def _freeze_score_mod_cache_key(value: Any) -> Any: + """Convert a user-provided score_mod graph key into a hashable structure.""" + if _is_array_operand(value): + raise TypeError( + "score_mod_graph_cache_key() must not include tensors. Pass runtime tensors " + "through score_mod_tensors or score_mod_bprop_tensors instead." + ) + if isinstance(value, Mapping): + items = ( + ( + _freeze_score_mod_cache_key(key), + _freeze_score_mod_cache_key(val), + ) + for key, val in value.items() + ) + return tuple(sorted(items, key=repr)) + if isinstance(value, (list, tuple)): + return tuple(_freeze_score_mod_cache_key(item) for item in value) + if isinstance(value, (set, frozenset)): + items = (_freeze_score_mod_cache_key(item) for item in value) + return tuple(sorted(items, key=repr)) + try: + hash(value) + except TypeError as exc: + raise TypeError( + "score_mod_graph_cache_key() must return a hashable value or a nested " + "combination of mapping/list/tuple/set values." + ) from exc + return value + + +def _score_mod_explicit_cache_key(callback_owner: Any) -> Optional[Any]: + """Return a user-provided structural graph key for a score_mod callback.""" + explicit_key = getattr(callback_owner, "score_mod_graph_cache_key", None) + if explicit_key is None: + return None + explicit_key = explicit_key() if callable(explicit_key) else explicit_key + return _freeze_score_mod_cache_key(explicit_key) + + +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Any: + """Create a stable graph cache key for a score_mod callable. + + Module-level functions are assumed to have stable topology. Stateful bound methods and + callable instances need an explicit score_mod_graph_cache_key(); otherwise their graphs + are left uncached to avoid reusing stale graphs after Python object address reuse. + """ + if callback is None: + return None + self_obj = getattr(callback, "__self__", None) + func_obj = getattr(callback, "__func__", None) + if self_obj is not None and func_obj is not None: + explicit_key = _score_mod_explicit_cache_key(self_obj) + if explicit_key is None: + return _UncacheableScoreModKey() + return ( + "bound_method", + type(self_obj), + func_obj.__module__, + func_obj.__qualname__, + explicit_key, + ) + + explicit_key = _score_mod_explicit_cache_key(callback) + if explicit_key is not None: + return ( + "callable", + type(callback), + getattr(callback, "__module__", None), + getattr(callback, "__qualname__", None), + explicit_key, + ) + + if ( + inspect.isfunction(callback) + and callback.__closure__ is None + and "" not in callback.__qualname__ + ): + return ("function", callback.__module__, callback.__qualname__) + + return _UncacheableScoreModKey() + + +@dataclass(frozen=True) +class _FusedAttnScoreModConfig: + """Static configuration for cuDNN frontend score_mod SDPA graphs.""" + + score_mod: Callable + score_mod_bprop: Optional[Callable] + score_mod_key: Any + score_mod_bprop_key: Any + score_mod_tensor_names: Tuple[str, ...] + score_mod_bprop_tensor_names: Tuple[str, ...] + score_mod_scalars: Tuple[_ScoreModScalarSpec, ...] + score_mod_bprop_scalars: Tuple[_ScoreModScalarSpec, ...] + scaling_factor: float + is_training: bool + deterministic: bool + + def __hash__(self): + return hash( + ( + self.score_mod_key, + self.score_mod_bprop_key, + self.score_mod_tensor_names, + self.score_mod_bprop_tensor_names, + self.score_mod_scalars, + self.score_mod_bprop_scalars, + self.scaling_factor, + self.is_training, + self.deterministic, + ) + ) + + def __eq__(self, other): + if not isinstance(other, _FusedAttnScoreModConfig): + return False + return ( + self.score_mod_key == other.score_mod_key + and self.score_mod_bprop_key == other.score_mod_bprop_key + and self.score_mod_tensor_names == other.score_mod_tensor_names + and self.score_mod_bprop_tensor_names == other.score_mod_bprop_tensor_names + and self.score_mod_scalars == other.score_mod_scalars + and self.score_mod_bprop_scalars == other.score_mod_bprop_scalars + and self.scaling_factor == other.scaling_factor + and self.is_training == other.is_training + and self.deterministic == other.deterministic + ) + + +@dataclass(frozen=True) +class _SerializedScoreModGraph: + """Serialized cuDNN frontend graph and static metadata for C++ execution.""" + + serialized_graph: bytes + graph_hash: Tuple[int, int] + cudnn_frontend_version: int + workspace_size: int + input_uids: np.ndarray + output_uids: np.ndarray + scalar_uids: np.ndarray + scalar_sizes: np.ndarray + scalar_values: np.ndarray + + +# cuDNN frontend tensor UIDs are arbitrary, but assigning stable values makes serialized +# graphs deterministic and simplifies debugging across the Python graph builder and C++ executor. +_SCORE_MOD_UID_Q = 1 +_SCORE_MOD_UID_K = 2 +_SCORE_MOD_UID_V = 3 +_SCORE_MOD_UID_O = 4 +_SCORE_MOD_UID_STATS = 5 +_SCORE_MOD_UID_DO = 6 +_SCORE_MOD_UID_DQ = 7 +_SCORE_MOD_UID_DK = 8 +_SCORE_MOD_UID_DV = 9 +_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000 +_SCORE_MOD_BPROP_TENSOR_UID_BASE = 2000 +_SCORE_MOD_FWD_SCALAR_UID_BASE = 3000 +_SCORE_MOD_BPROP_SCALAR_UID_BASE = 4000 + +_score_mod_graph_cache: Dict[Tuple[Any, ...], _SerializedScoreModGraph] = {} + + +def _row_major_stride(shape: Sequence[int]) -> Tuple[int, ...]: + stride = [] + running = 1 + for dim in reversed(tuple(shape)): + stride.append(running) + running *= dim + return tuple(reversed(stride)) + + +def _bshd_as_bhsd_dim_stride(shape: Sequence[int]) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + if len(shape) != 4: + raise ValueError(f"score_mod requires rank-4 BSHD tensors, got shape={shape}.") + batch, seqlen, heads, head_dim = tuple(shape) + return ( + (batch, heads, seqlen, head_dim), + (seqlen * heads * head_dim, head_dim, heads * head_dim, 1), + ) + + +def _dtype_name(dtype) -> str: + return str(jnp.dtype(dtype)) + + +def _is_array_operand(value: Any) -> bool: + return ( + hasattr(value, "shape") + and hasattr(value, "dtype") + and not isinstance(value, (bool, int, float, complex, np.generic)) + ) + + +def _scalar_to_spec(name: str, value: Any) -> _ScoreModScalarSpec: + if isinstance(value, bool): + dtype = np.bool_ + elif isinstance(value, int): + dtype = np.int32 + elif isinstance(value, float): + dtype = np.float32 + elif isinstance(value, np.generic): + dtype = value.dtype + else: + scalar = np.asarray(value) + if scalar.shape != (): + raise ValueError( + f"score_mod tensor '{name}' is neither a JAX array nor a scalar pass-by-value." + ) + dtype = scalar.dtype + + scalar = np.full((1, 1, 1, 1), value, dtype=dtype) + return _ScoreModScalarSpec(name=name, dtype=str(scalar.dtype), value=scalar.tobytes()) + + +def _split_score_mod_tensors( + tensors: Optional[Mapping[str, Any]], *, argument_name: str +) -> Tuple[Tuple[str, ...], Tuple[jnp.ndarray, ...], Tuple[_ScoreModScalarSpec, ...]]: + if tensors is None: + return (), (), () + if not isinstance(tensors, Mapping): + raise TypeError(f"{argument_name} must be a mapping from string names to tensors/scalars.") + + names = [] + operands = [] + scalars = [] + for name, value in tensors.items(): + if not isinstance(name, str): + raise TypeError(f"{argument_name} keys must be strings, got {type(name).__name__}.") + if _is_array_operand(value): + if len(value.shape) == 0: + raise ValueError( + f"{argument_name}['{name}'] is a rank-0 array. Use a Python/NumPy scalar " + "for cuDNN pass-by-value scalars, or reshape it to a tensor." + ) + names.append(name) + operands.append(jnp.asarray(value)) + else: + scalars.append(_scalar_to_spec(name, value)) + return tuple(names), tuple(operands), tuple(scalars) + + +def make_fused_attn_score_mod_config( + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Mapping[str, Any]], + score_mod_bprop_tensors: Optional[Mapping[str, Any]], + scaling_factor: float, + is_training: bool, +) -> Tuple[_FusedAttnScoreModConfig, Tuple[jnp.ndarray, ...], Tuple[jnp.ndarray, ...]]: + """Normalize score_mod operands and create a static graph-build config.""" + if not callable(score_mod): + raise TypeError("score_mod must be callable.") + if score_mod_bprop is not None and not callable(score_mod_bprop): + raise TypeError("score_mod_bprop must be callable when provided.") + if score_mod_bprop is None and score_mod_bprop_tensors: + raise ValueError("score_mod_bprop_tensors requires score_mod_bprop to be provided.") + + tensor_names, tensor_operands, scalars = _split_score_mod_tensors( + score_mod_tensors, argument_name="score_mod_tensors" + ) + bprop_tensor_names, bprop_tensor_operands, bprop_scalars = _split_score_mod_tensors( + score_mod_bprop_tensors, argument_name="score_mod_bprop_tensors" + ) + config = _FusedAttnScoreModConfig( + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_key=_score_mod_callback_cache_key(score_mod), + score_mod_bprop_key=_score_mod_callback_cache_key(score_mod_bprop), + score_mod_tensor_names=tensor_names, + score_mod_bprop_tensor_names=bprop_tensor_names, + score_mod_scalars=scalars, + score_mod_bprop_scalars=bprop_scalars, + scaling_factor=float(scaling_factor), + is_training=bool(is_training), + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + ) + return config, tensor_operands, bprop_tensor_operands + + +def _cudnn_data_type(cudnn, dtype): + dtype = jnp.dtype(dtype) + if dtype == jnp.float16: + return cudnn.data_type.HALF + if dtype == jnp.bfloat16: + return cudnn.data_type.BFLOAT16 + if dtype == jnp.float32: + return cudnn.data_type.FLOAT + if dtype == jnp.float64: + return cudnn.data_type.DOUBLE + if dtype == jnp.int32: + return cudnn.data_type.INT32 + if dtype == jnp.int64: + return cudnn.data_type.INT64 + if dtype == jnp.uint8: + return cudnn.data_type.UINT8 + if dtype == jnp.bool_: + return cudnn.data_type.BOOLEAN + raise ValueError(f"Unsupported score_mod tensor dtype: {dtype}.") + + +def _cudnn_data_type_from_name(cudnn, dtype_name: str): + if dtype_name == "bfloat16": + return cudnn.data_type.BFLOAT16 + return _cudnn_data_type(cudnn, np.dtype(dtype_name)) + + +def _graph_tensor_from_aval(cudnn, graph, name: str, aval, uid: int): + shape = tuple(int(dim) for dim in aval.shape) + return graph.tensor( + name=name, + dim=shape, + stride=_row_major_stride(shape), + data_type=_cudnn_data_type(cudnn, aval.dtype), + uid=uid, + ) + + +def _score_mod_graph_tensors( + cudnn, + graph, + names: Tuple[str, ...], + avals: Sequence[Any], + scalars: Tuple[_ScoreModScalarSpec, ...], + tensor_uid_base: int, + scalar_uid_base: int, +): + graph_tensors = {} + tensor_uids = [] + for index, (name, aval) in enumerate(zip(names, avals)): + uid = tensor_uid_base + index + graph_tensors[name] = _graph_tensor_from_aval(cudnn, graph, name, aval, uid) + tensor_uids.append(uid) + + scalar_uids = [] + scalar_values = [] + for index, scalar in enumerate(scalars): + uid = scalar_uid_base + index + graph_tensors[scalar.name] = graph.tensor( + name=scalar.name, + dim=scalar.dim, + stride=scalar.stride, + is_pass_by_value=True, + data_type=_cudnn_data_type_from_name(cudnn, scalar.dtype), + uid=uid, + ) + scalar_uids.append(uid) + scalar_values.append(scalar.value) + + return graph_tensors, tuple(tensor_uids), tuple(scalar_uids), tuple(scalar_values) + + +def _encode_cudnn_frontend_version(version: str) -> int: + public_version = version.split("+", 1)[0].split("-", 1)[0] + parts = public_version.split(".") + if len(parts) < 3: + raise RuntimeError(f"Could not parse cuDNN frontend Python version: {version!r}.") + major, minor, patch = (int(part) for part in parts[:3]) + return major * 10000 + minor * 100 + patch + + +def _check_cudnn_frontend_version_match(cudnn) -> int: + python_version_string = getattr(cudnn, "__version__", None) + if python_version_string is None: + raise RuntimeError("cuDNN frontend Python package does not expose __version__.") + python_version = _encode_cudnn_frontend_version(python_version_string) + cpp_version = int(transformer_engine_jax.get_cudnn_frontend_version()) + if python_version != cpp_version: + raise RuntimeError( + "cuDNN frontend Python/C++ version mismatch for score_mod graph serialization: " + f"Python cudnn.__version__={python_version_string!r} encodes to {python_version}, " + f"but Transformer Engine C++ was built with CUDNN_FRONTEND_VERSION={cpp_version}. " + "Use matching cuDNN frontend Python package and C++ headers." + ) + return python_version + + +def _score_mod_graph_hash(serialized_graph: bytes) -> Tuple[int, int]: + digest = hashlib.sha256(serialized_graph).digest() + return ( + int.from_bytes(digest[0:8], byteorder="little", signed=True), + int.from_bytes(digest[8:16], byteorder="little", signed=True), + ) + + +def _pack_score_mod_scalar_values( + scalar_values: Sequence[bytes], +) -> Tuple[np.ndarray, np.ndarray]: + scalar_sizes = np.asarray([len(value) for value in scalar_values], dtype=np.int64) + packed_values = np.zeros((len(scalar_values), 16), dtype=np.uint8) + for index, value in enumerate(scalar_values): + if len(value) > 16: + raise ValueError("score_mod pass-by-value scalars must be at most 16 bytes.") + packed_values[index, : len(value)] = np.frombuffer(value, dtype=np.uint8) + return scalar_sizes, packed_values.reshape(-1) + + +def _serialized_score_mod_graph( + *, + serialized_graph: bytes, + cudnn_frontend_version: int, + workspace_size: int, + input_uids: Sequence[int], + output_uids: Sequence[int], + scalar_uids: Sequence[int], + scalar_values: Sequence[bytes], +) -> _SerializedScoreModGraph: + scalar_sizes, packed_scalar_values = _pack_score_mod_scalar_values(scalar_values) + return _SerializedScoreModGraph( + serialized_graph=serialized_graph, + graph_hash=_score_mod_graph_hash(serialized_graph), + cudnn_frontend_version=int(cudnn_frontend_version), + workspace_size=int(workspace_size), + input_uids=np.asarray(input_uids, dtype=np.int64), + output_uids=np.asarray(output_uids, dtype=np.int64), + scalar_uids=np.asarray(scalar_uids, dtype=np.int64), + scalar_sizes=scalar_sizes, + scalar_values=packed_scalar_values, + ) + + +def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): + if score_mod is None: + return None + + def wrapped_score_mod(sdpa_graph, score_tensor): + return score_mod(sdpa_graph, score_tensor, graph_tensors) + + return wrapped_score_mod + + +def _finalize_score_mod_graph(cudnn, graph) -> Tuple[int, bytes, int]: + graph.validate() + graph.build_operation_graph() + try: + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + except cudnn.cudnnGraphNotSupportedError as exc: + raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc + graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + serialized_graph = bytes(graph.serialize()) + return ( + max(int(graph.get_workspace_size()), 1), + serialized_graph, + _check_cudnn_frontend_version_match(cudnn), + ) + + +def _graph_cache_key( + direction: str, + config: _FusedAttnScoreModConfig, + avals: Sequence[Any], +) -> Optional[Tuple[Any, ...]]: + if _score_mod_key_is_uncacheable(config.score_mod_key) or _score_mod_key_is_uncacheable( + config.score_mod_bprop_key + ): + return None + return ( + direction, + config, + tuple((tuple(aval.shape), _dtype_name(aval.dtype)) for aval in avals), + ) + + +def _shape_dtype(value) -> jax.ShapeDtypeStruct: + return jax.ShapeDtypeStruct(tuple(value.shape), value.dtype) + + +def _import_cudnn_for_score_mod(): + cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) + cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" + if ( + any(cudnn_frontend_package.glob("_compiled_module*")) + and cudnn_frontend_path not in sys.path + ): + sys.path.insert(0, cudnn_frontend_path) + try: + cudnn = importlib.import_module("cudnn") + except ImportError as exc: + raise ImportError( + "score_mod fused_attn requires the cuDNN frontend Python package (`cudnn`)." + ) from exc + _check_cudnn_frontend_version_match(cudnn) + return cudnn + + +def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): + cudnn = _import_cudnn_for_score_mod() + + io_data_type = _cudnn_data_type(cudnn, q_aval.dtype) + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) + k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) + v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) + q = graph.tensor( + name="q", dim=q_dim, stride=q_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_Q + ) + k = graph.tensor( + name="k", dim=k_dim, stride=k_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_K + ) + v = graph.tensor( + name="v", dim=v_dim, stride=v_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_V + ) + + score_mod_graph_tensors, tensor_uids, scalar_uids, scalar_values = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_tensor_names, + score_mod_avals, + config.score_mod_scalars, + _SCORE_MOD_FWD_TENSOR_UID_BASE, + _SCORE_MOD_FWD_SCALAR_UID_BASE, + ) + + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=config.is_training, + attn_scale=config.scaling_factor, + use_causal_mask=False, + score_mod=_wrap_score_mod(config.score_mod, score_mod_graph_tensors), + ) + + batch, q_seqlen, q_heads, _ = q_aval.shape + _, _, _, v_head_dim = v_aval.shape + output_dim, output_stride = _bshd_as_bhsd_dim_stride((batch, q_seqlen, q_heads, v_head_dim)) + output.set_output(True).set_uid(_SCORE_MOD_UID_O).set_dim(output_dim).set_stride(output_stride) + output.set_data_type(io_data_type) + + output_uids = [_SCORE_MOD_UID_O] + if config.is_training: + stats_shape = (batch, q_heads, q_seqlen, 1) + stats.set_output(True).set_uid(_SCORE_MOD_UID_STATS).set_dim(stats_shape).set_stride( + _row_major_stride(stats_shape) + ) + stats.set_data_type(cudnn.data_type.FLOAT) + output_uids.append(_SCORE_MOD_UID_STATS) + + workspace_size, serialized_graph, frontend_version = _finalize_score_mod_graph(cudnn, graph) + return _serialized_score_mod_graph( + serialized_graph=serialized_graph, + cudnn_frontend_version=frontend_version, + workspace_size=workspace_size, + input_uids=[_SCORE_MOD_UID_Q, _SCORE_MOD_UID_K, _SCORE_MOD_UID_V, *tensor_uids], + output_uids=output_uids, + scalar_uids=scalar_uids, + scalar_values=scalar_values, + ) + + +def _build_score_mod_bwd_graph( + q_aval, + k_aval, + v_aval, + output_aval, + doutput_aval, + stats_aval, + score_mod_avals, + score_mod_bprop_avals, + config, +): + cudnn = _import_cudnn_for_score_mod() + + io_data_type = _cudnn_data_type(cudnn, q_aval.dtype) + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) + k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) + v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) + o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape) + do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape) + q = graph.tensor( + name="q", dim=q_dim, stride=q_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_Q + ) + k = graph.tensor( + name="k", dim=k_dim, stride=k_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_K + ) + v = graph.tensor( + name="v", dim=v_dim, stride=v_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_V + ) + output = graph.tensor( + name="o", dim=o_dim, stride=o_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_O + ) + doutput = graph.tensor( + name="dO", dim=do_dim, stride=do_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_DO + ) + stats = graph.tensor( + name="stats", + dim=tuple(int(dim) for dim in stats_aval.shape), + stride=_row_major_stride(stats_aval.shape), + data_type=cudnn.data_type.FLOAT, + uid=_SCORE_MOD_UID_STATS, + ) + + score_mod_graph_tensors, tensor_uids, scalar_uids, scalar_values = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_tensor_names, + score_mod_avals, + config.score_mod_scalars, + _SCORE_MOD_FWD_TENSOR_UID_BASE, + _SCORE_MOD_FWD_SCALAR_UID_BASE, + ) + ( + score_mod_bprop_graph_tensors, + bprop_tensor_uids, + bprop_scalar_uids, + bprop_scalar_values, + ) = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_bprop_tensor_names, + score_mod_bprop_avals, + config.score_mod_bprop_scalars, + _SCORE_MOD_BPROP_TENSOR_UID_BASE, + _SCORE_MOD_BPROP_SCALAR_UID_BASE, + ) + + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=doutput, + stats=stats, + attn_scale=config.scaling_factor, + use_causal_mask=False, + score_mod=_wrap_score_mod(config.score_mod, score_mod_graph_tensors), + score_mod_bprop=_wrap_score_mod(config.score_mod_bprop, score_mod_bprop_graph_tensors), + use_deterministic_algorithm=config.deterministic, + ) + + dq.set_output(True).set_uid(_SCORE_MOD_UID_DQ).set_dim(q_dim).set_stride(q_stride) + dk.set_output(True).set_uid(_SCORE_MOD_UID_DK).set_dim(k_dim).set_stride(k_stride) + dv.set_output(True).set_uid(_SCORE_MOD_UID_DV).set_dim(v_dim).set_stride(v_stride) + + workspace_size, serialized_graph, frontend_version = _finalize_score_mod_graph(cudnn, graph) + return _serialized_score_mod_graph( + serialized_graph=serialized_graph, + cudnn_frontend_version=frontend_version, + workspace_size=workspace_size, + input_uids=[ + _SCORE_MOD_UID_Q, + _SCORE_MOD_UID_K, + _SCORE_MOD_UID_V, + _SCORE_MOD_UID_O, + _SCORE_MOD_UID_DO, + _SCORE_MOD_UID_STATS, + *tensor_uids, + *bprop_tensor_uids, + ], + output_uids=[_SCORE_MOD_UID_DQ, _SCORE_MOD_UID_DK, _SCORE_MOD_UID_DV], + scalar_uids=[*scalar_uids, *bprop_scalar_uids], + scalar_values=[*scalar_values, *bprop_scalar_values], + ) + + +def fused_attn_score_mod_fwd( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + score_mod_tensors: Tuple[jnp.ndarray, ...], + config: _FusedAttnScoreModConfig, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Run cuDNN frontend SDPA forward with a score_mod callback.""" + q, k, v = qkv + q_aval, k_aval, v_aval = map(_shape_dtype, (q, k, v)) + score_mod_avals = tuple(_shape_dtype(arg) for arg in score_mod_tensors) + key = _graph_cache_key("fwd", config, (q_aval, k_aval, v_aval, *score_mod_avals)) + if key is None: + graph = _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config) + else: + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_fwd_graph( + q_aval, k_aval, v_aval, score_mod_avals, config + ) + graph = _score_mod_graph_cache[key] + + batch, q_seqlen, q_heads, _ = q.shape + _, _, _, v_head_dim = v.shape + output_shape = jax.ShapeDtypeStruct((batch, q_seqlen, q_heads, v_head_dim), q.dtype) + stats_shape = (batch, q_heads, q_seqlen, 1) if config.is_training else (0,) + stats = jax.ShapeDtypeStruct(stats_shape, jnp.float32) + workspace = jax.ShapeDtypeStruct((graph.workspace_size,), jnp.uint8) + output, softmax_stats, _ = ffi.ffi_call( + "te_fused_attn_score_mod_forward_ffi", + (output_shape, stats, workspace), + )( + q, + k, + v, + *score_mod_tensors, + serialized_graph=graph.serialized_graph, + graph_hash0=graph.graph_hash[0], + graph_hash1=graph.graph_hash[1], + cudnn_frontend_version=graph.cudnn_frontend_version, + input_uids=graph.input_uids, + output_uids=graph.output_uids, + scalar_uids=graph.scalar_uids, + scalar_sizes=graph.scalar_sizes, + scalar_values=graph.scalar_values, + ) + return output, softmax_stats + + +def fused_attn_score_mod_bwd( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + output: jnp.ndarray, + doutput: jnp.ndarray, + softmax_stats: jnp.ndarray, + score_mod_tensors: Tuple[jnp.ndarray, ...], + score_mod_bprop_tensors: Tuple[jnp.ndarray, ...], + config: _FusedAttnScoreModConfig, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Run cuDNN frontend SDPA backward with score_mod callbacks.""" + if not config.is_training: + raise RuntimeError("score_mod backward requires fused_attn(..., is_training=True).") + + q, k, v = qkv + all_inputs = (q, k, v, output, doutput, softmax_stats, *score_mod_tensors) + all_inputs = (*all_inputs, *score_mod_bprop_tensors) + avals = tuple(_shape_dtype(arg) for arg in all_inputs) + key = _graph_cache_key("bwd", config, avals) + if key is None: + graph = _build_score_mod_bwd_graph( + *avals[:6], + avals[6 : 6 + len(score_mod_tensors)], + avals[6 + len(score_mod_tensors) :], + config, + ) + else: + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_bwd_graph( + *avals[:6], + avals[6 : 6 + len(score_mod_tensors)], + avals[6 + len(score_mod_tensors) :], + config, + ) + graph = _score_mod_graph_cache[key] + + dq = jax.ShapeDtypeStruct(q.shape, q.dtype) + dk = jax.ShapeDtypeStruct(k.shape, k.dtype) + dv = jax.ShapeDtypeStruct(v.shape, v.dtype) + workspace = jax.ShapeDtypeStruct((graph.workspace_size,), jnp.uint8) + dq, dk, dv, _ = ffi.ffi_call( + "te_fused_attn_score_mod_backward_ffi", + (dq, dk, dv, workspace), + )( + q, + k, + v, + output, + doutput, + softmax_stats, + *score_mod_tensors, + *score_mod_bprop_tensors, + serialized_graph=graph.serialized_graph, + graph_hash0=graph.graph_hash[0], + graph_hash1=graph.graph_hash[1], + cudnn_frontend_version=graph.cudnn_frontend_version, + input_uids=graph.input_uids, + output_uids=graph.output_uids, + scalar_uids=graph.scalar_uids, + scalar_sizes=graph.scalar_sizes, + scalar_values=graph.scalar_values, + ) + return dq, dk, dv + + def generate_cu_seqlen(actual_seqlen): """ Generating cumsum seqlen for a batch diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..f886c4ed4d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -146,6 +146,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler); + NVTE_Fused_Attn_Backend GetFusedAttnBackend( bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index ed136d7b9e..e74b78df17 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -4,6 +4,19 @@ * See LICENSE for license information. ************************************************************************/ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + #include "../extensions.h" #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -689,5 +702,242 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Attrs(), FFI_CudaGraph_Traits); +namespace { + +struct ScoreModScalarStorage { + alignas(16) std::array data{}; + size_t size = 0; +}; + +struct ScoreModGraphCacheKey { + int device_id = 0; + int64_t hash0 = 0; + int64_t hash1 = 0; + int64_t frontend_version = 0; + + bool operator==(const ScoreModGraphCacheKey &other) const { + return device_id == other.device_id && hash0 == other.hash0 && hash1 == other.hash1 && + frontend_version == other.frontend_version; + } +}; + +struct ScoreModGraphCacheKeyHash { + size_t operator()(const ScoreModGraphCacheKey &key) const { + size_t seed = std::hash{}(key.device_id); + auto combine = [&seed](int64_t value) { + seed ^= std::hash{}(value) + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); + }; + combine(key.hash0); + combine(key.hash1); + combine(key.frontend_version); + return seed; + } +}; + +using ScoreModGraphPtr = std::shared_ptr; + +std::unordered_map & +getScoreModeGraphCache() { + static std::unordered_map + cache; + return cache; +} + +std::mutex &getScoreModGraphCacheMutex() { + static std::mutex mutex; + return mutex; +} + +struct ScoreModCudnnHandleCache { + std::unordered_map handles; + + cudnnHandle_t GetHandle() { + int device_id = 0; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + auto it = handles.find(device_id); + if (it == handles.end()) { + cudnnHandle_t handle = nullptr; + NVTE_CHECK_CUDNN(cudnnCreate(&handle)); + it = handles.emplace(device_id, handle).first; + } + return it->second; + } + + ~ScoreModCudnnHandleCache() { + for (auto &[_, handle] : handles) { + cudnnDestroy(handle); + } + } +}; + +cudnnHandle_t GetScoreModCudnnHandle() { + static thread_local ScoreModCudnnHandleCache cache; + return cache.GetHandle(); +} + +ScoreModGraphCacheKey GetScoreModGraphCacheKey(Dictionary &attrs) { + const int64_t frontend_version = get_attr_value(attrs, "cudnn_frontend_version"); + NVTE_CHECK(frontend_version == CUDNN_FRONTEND_VERSION, + "cuDNN frontend version mismatch for score_mod graph deserialization: graph was " + "serialized with Python cuDNN frontend version ", + frontend_version, + ", but Transformer Engine C++ was built with CUDNN_FRONTEND_VERSION ", + CUDNN_FRONTEND_VERSION, "."); + + int device_id = 0; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + return ScoreModGraphCacheKey{ + device_id, + get_attr_value(attrs, "graph_hash0"), + get_attr_value(attrs, "graph_hash1"), + frontend_version, + }; +} + +ScoreModGraphPtr GetScoreModGraph(cudaStream_t stream, Dictionary &attrs) { + const auto key = GetScoreModGraphCacheKey(attrs); + { + std::lock_guard lock(getScoreModGraphCacheMutex()); + auto &cache = getScoreModeGraphCache(); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + } + + const auto serialized_graph = get_attr_value(attrs, "serialized_graph"); + std::vector serialized_data(serialized_graph.begin(), serialized_graph.end()); + + auto handle = GetScoreModCudnnHandle(); + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + auto graph = std::make_shared(); + auto status = graph->deserialize(handle, serialized_data); + NVTE_CHECK(status.is_good(), + "Failed to deserialize cuDNN score_mod SDPA graph: ", status.get_message()); + + std::lock_guard lock(getScoreModGraphCacheMutex()); + auto &cache = getScoreModeGraphCache(); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + cache.emplace(key, graph); + return graph; +} + +Error_Type ExecuteScoreModGraph(cudaStream_t stream, Dictionary &attrs, + const std::vector &input_ptrs, + const std::vector &output_ptrs, void *workspace) { + auto graph = GetScoreModGraph(stream, attrs); + auto input_uids = get_attr_value>(attrs, "input_uids"); + auto output_uids = get_attr_value>(attrs, "output_uids"); + auto scalar_uids = get_attr_value>(attrs, "scalar_uids"); + auto scalar_sizes = get_attr_value>(attrs, "scalar_sizes"); + auto scalar_values = get_attr_value>(attrs, "scalar_values"); + + NVTE_CHECK(input_ptrs.size() == input_uids.size(), "cuDNN score_mod graph expected ", + input_uids.size(), " inputs but got ", input_ptrs.size()); + NVTE_CHECK(output_ptrs.size() >= output_uids.size(), "cuDNN score_mod graph expected at least ", + output_uids.size(), " outputs but got ", output_ptrs.size()); + NVTE_CHECK(scalar_uids.size() == scalar_sizes.size(), + "Mismatched score_mod scalar uid/value-size counts."); + NVTE_CHECK(scalar_values.size() == scalar_uids.size() * 16, + "Mismatched score_mod packed scalar value size."); + + std::unordered_map variant_pack; + for (size_t i = 0; i < input_uids.size(); ++i) { + variant_pack.emplace(input_uids[i], input_ptrs[i]); + } + for (size_t i = 0; i < output_uids.size(); ++i) { + variant_pack.emplace(output_uids[i], output_ptrs[i]); + } + + std::vector scalar_storage(scalar_uids.size()); + for (size_t i = 0; i < scalar_uids.size(); ++i) { + NVTE_CHECK(scalar_sizes[i] >= 0 && scalar_sizes[i] <= 16, + "score_mod pass-by-value scalars must be at most 16 bytes."); + scalar_storage[i].size = static_cast(scalar_sizes[i]); + std::copy_n(scalar_values.begin() + i * 16, 16, scalar_storage[i].data.begin()); + variant_pack.emplace(scalar_uids[i], scalar_storage[i].data.data()); + } + + auto handle = GetScoreModCudnnHandle(); + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + auto status = graph->execute(handle, variant_pack, workspace); + NVTE_CHECK(status.is_good(), + "cuDNN score_mod SDPA graph execution failed: ", status.get_message()); + return ffi_with_cuda_error_check(); +} + +void AppendRemainingBuffers(Variadic_Buffer_Type args, std::vector *ptrs) { + ptrs->reserve(ptrs->size() + args.size()); + for (size_t i = 0; i < args.size(); ++i) { + auto maybe_buf = args.get(i); + NVTE_CHECK(!maybe_buf.has_error(), "Failed to decode variadic score_mod input buffer."); + ptrs->push_back(maybe_buf.value().untyped_data()); + } +} + +} // namespace + +Error_Type FusedAttnScoreModForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Variadic_Buffer_Type score_mod_args, + Result_Type output_buf, Result_Type stats_buf, + Result_Type workspace_buf, Dictionary attrs) { + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data()}; + AppendRemainingBuffers(score_mod_args, &input_ptrs); + + std::vector output_ptrs = {output_buf->untyped_data(), stats_buf->untyped_data()}; + return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler, FusedAttnScoreModForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .RemainingArgs() // score_mod tensor operands + .Ret() // output + .Ret() // stats + .Ret() // workspace + .Attrs()); + +Error_Type FusedAttnScoreModBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type output_buf, + Buffer_Type doutput_buf, Buffer_Type stats_buf, + Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, + Result_Type dk_buf, Result_Type dv_buf, + Result_Type workspace_buf, Dictionary attrs) { + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data(), output_buf.untyped_data(), + doutput_buf.untyped_data(), stats_buf.untyped_data()}; + AppendRemainingBuffers(score_mod_args, &input_ptrs); + + std::vector output_ptrs = {dq_buf->untyped_data(), dk_buf->untyped_data(), + dv_buf->untyped_data()}; + return ExecuteScoreModGraph(stream, attrs, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler, FusedAttnScoreModBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // output + .Arg() // doutput + .Arg() // stats + .RemainingArgs() // score_mod tensor operands + .Ret() // dq + .Ret() // dk + .Ret() // dv + .Ret() // workspace + .Attrs()); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..4f6cd385be 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -64,6 +64,12 @@ pybind11::dict Registrations() { dict["te_fused_attn_backward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + dict["te_fused_attn_score_mod_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnScoreModForwardHandler)); + dict["te_fused_attn_score_mod_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnScoreModBackwardHandler)); // GEMM dict["te_gemm_ffi"] = @@ -112,6 +118,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); + m.def("get_cudnn_frontend_version", &GetCudnnFrontendVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_num_compute_streams", &nvte_get_num_compute_streams); m.def("get_cublasLt_version", &cublasLtGetVersion); diff --git a/transformer_engine/jax/csrc/extensions/utils.cpp b/transformer_engine/jax/csrc/extensions/utils.cpp index 52ab2edf0f..d8fe1762dd 100644 --- a/transformer_engine/jax/csrc/extensions/utils.cpp +++ b/transformer_engine/jax/csrc/extensions/utils.cpp @@ -6,6 +6,7 @@ #include "utils.h" #include +#include #include @@ -22,6 +23,8 @@ int GetCudaRuntimeVersion() { size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); } +size_t GetCudnnFrontendVersion() { return CUDNN_FRONTEND_VERSION; } + int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/utils.h b/transformer_engine/jax/csrc/extensions/utils.h index c55c8d86ce..939b0a881e 100644 --- a/transformer_engine/jax/csrc/extensions/utils.h +++ b/transformer_engine/jax/csrc/extensions/utils.h @@ -20,6 +20,7 @@ namespace jax { int GetCudaRuntimeVersion(); size_t GetCudnnRuntimeVersion(); +size_t GetCudnnFrontendVersion(); int GetDeviceComputeCapability(int gpu_id); class cudaDevicePropertiesManager { diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index a2e7920843..63b84288e8 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -8,7 +8,7 @@ from enum import Enum from math import sqrt import os -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union import warnings import jax @@ -305,6 +305,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT context_checkpoint_name: str = "context" softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None @nn.compact def __call__( @@ -317,6 +319,8 @@ def __call__( *, dropout_rng: Optional[PRNGKey] = None, deterministic: bool = False, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ) -> Array: seed = None @@ -340,6 +344,32 @@ def __call__( jnp.float32, ) + score_mod_requested = ( + self.score_mod is not None + or self.score_mod_bprop is not None + or score_mod_tensors is not None + or score_mod_bprop_tensors is not None + ) + fused_attn_kwargs = { + "attn_mask_type": self.attn_mask_type, + "attn_bias_type": self.attn_bias_type, + "softmax_type": self.softmax_type, + "scaling_factor": scale_factor, + "dropout_probability": self.attention_dropout, + "is_training": not deterministic, + "window_size": self.window_size, + "max_segments_per_seq": self.max_segments_per_seq, + "context_parallel_causal_load_balanced": self.context_parallel_causal_load_balanced, + "context_parallel_axis": self.context_parallel_axis, + "context_parallel_strategy": self.context_parallel_strategy, + "context_checkpoint_name": self.context_checkpoint_name, + "softmax_offset": softmax_offset, + "score_mod": self.score_mod, + "score_mod_bprop": self.score_mod_bprop, + "score_mod_tensors": score_mod_tensors, + "score_mod_bprop_tensors": score_mod_bprop_tensors, + } + if self.qkv_layout.is_qkvpacked(): """qkvpacked format, treat query: qkvpacked tensor, shape = [..., 3, h, d] @@ -349,25 +379,23 @@ def __call__( qkv_packed = query if self.transpose_batch_sequence: qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) + if score_mod_requested: + query, key, value = jnp.split(qkv_packed, [1, 2], axis=-3) + query, key, value = map( + functools.partial(jnp.squeeze, axis=-3), [query, key, value] + ) + qkv_layout = self.qkv_layout.to_separate() + qkv_args = (query, key, value) + else: + qkv_layout = self.qkv_layout + qkv_args = (qkv_packed,) x = fused_attn( - (qkv_packed,), + qkv_args, bias, sequence_descriptor, seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - qkv_layout=self.qkv_layout, - softmax_type=self.softmax_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic, - window_size=self.window_size, - max_segments_per_seq=self.max_segments_per_seq, - context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, - context_parallel_axis=self.context_parallel_axis, - context_parallel_strategy=self.context_parallel_strategy, - context_checkpoint_name=self.context_checkpoint_name, - softmax_offset=softmax_offset, + qkv_layout=qkv_layout, + **fused_attn_kwargs, ) elif self.qkv_layout.is_kvpacked(): """kvpacked format, treat @@ -379,25 +407,21 @@ def __call__( if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) + if score_mod_requested: + key, value = jnp.split(kv_packed, [1], axis=-3) + key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) + qkv_layout = self.qkv_layout.to_separate() + qkv_args = (query, key, value) + else: + qkv_layout = self.qkv_layout + qkv_args = (query, kv_packed) x = fused_attn( - (query, kv_packed), + qkv_args, bias, sequence_descriptor, seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - qkv_layout=self.qkv_layout, - softmax_type=self.softmax_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic, - window_size=self.window_size, - max_segments_per_seq=self.max_segments_per_seq, - context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, - context_parallel_axis=self.context_parallel_axis, - context_parallel_strategy=self.context_parallel_strategy, - context_checkpoint_name=self.context_checkpoint_name, - softmax_offset=softmax_offset, + qkv_layout=qkv_layout, + **fused_attn_kwargs, ) elif self.qkv_layout.is_separate(): if self.transpose_batch_sequence: @@ -409,20 +433,8 @@ def __call__( bias, sequence_descriptor, seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, qkv_layout=self.qkv_layout, - softmax_type=self.softmax_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic, - window_size=self.window_size, - max_segments_per_seq=self.max_segments_per_seq, - context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, - context_parallel_axis=self.context_parallel_axis, - context_parallel_strategy=self.context_parallel_strategy, - context_checkpoint_name=self.context_checkpoint_name, - softmax_offset=softmax_offset, + **fused_attn_kwargs, ) else: raise ValueError(f"Unsupported {self.qkv_layout=}.") @@ -602,6 +614,16 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention (``'zero sink'`` and ``'learnable sink'``). + score_mod: Optional[Callable], default = None + Experimental cuDNN frontend score modification callback. When set, fused attention is + required and the same score_mod limitations as :func:`fused_attn` apply. + score_mod_bprop: Optional[Callable], default = None + Optional score modification backward callback. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod``. This is a call-time + argument to keep tensor operands as normal JAX inputs. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod_bprop``. Optimization parameters ----------------------- @@ -628,6 +650,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods context_parallel_strategy: str = "DEFAULT" context_checkpoint_name: str = "context" softmax_type: str = "vanilla" + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None def __post_init__(self): # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12 @@ -669,6 +693,8 @@ def __call__( *, deterministic: bool = False, mask: Optional[Union[SequenceDescriptor, Array]] = None, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ) -> Array: """ Parameters @@ -689,6 +715,10 @@ def __call__( Below parameters are keyword only deterministic: bool, default = False Disable dropout layers if set to True. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod``. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod_bprop``. Returns ------- @@ -715,6 +745,12 @@ def __call__( qkv_layout = QKVLayout[self.qkv_layout.upper()] softmax_type = AttnSoftmaxType.from_str(self.softmax_type) del self.attn_bias_type, self.attn_mask_type, self.qkv_layout + score_mod_requested = ( + self.score_mod is not None + or self.score_mod_bprop is not None + or score_mod_tensors is not None + or score_mod_bprop_tensors is not None + ) if attn_bias_type == AttnBiasType.NO_BIAS: assert ( @@ -760,25 +796,30 @@ def __call__( head_dim_qk = self.head_dim head_dim_v = self.head_dim - has_fused_attn_kernel = is_fused_attn_kernel_available( - # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. - not deterministic, - input_dtype, - # self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient - input_dtype, - qkv_layout, - attn_bias_type, - attn_mask_type, - softmax_type, - self.attention_dropout, - self.num_attention_heads, - self.num_gqa_groups, - seqlen_q, - seqlen_kv, - head_dim_qk, - head_dim_v, - self.window_size, - ) + if score_mod_requested: + if not enable_fused_attn: + raise ValueError("score_mod requires fused attention, but NVTE_FUSED_ATTN=0.") + has_fused_attn_kernel = True + else: + has_fused_attn_kernel = is_fused_attn_kernel_available( + # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. + not deterministic, + input_dtype, + # self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient + input_dtype, + qkv_layout, + attn_bias_type, + attn_mask_type, + softmax_type, + self.attention_dropout, + self.num_attention_heads, + self.num_gqa_groups, + seqlen_q, + seqlen_kv, + head_dim_qk, + head_dim_v, + self.window_size, + ) use_fused_attn = enable_fused_attn and has_fused_attn_kernel @@ -872,6 +913,8 @@ def __call__( context_parallel_strategy=context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, softmax_type=softmax_type, + score_mod=self.score_mod, + score_mod_bprop=self.score_mod_bprop, )( query, key, @@ -880,6 +923,8 @@ def __call__( bias, dropout_rng=dropout_rng, deterministic=deterministic, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, ) assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}" return x @@ -1178,6 +1223,16 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention (``'zero sink'`` and ``'learnable sink'``). + score_mod: Optional[Callable], default = None + Experimental cuDNN frontend score modification callback. When set, fused attention is + required and the same score_mod limitations as :func:`fused_attn` apply. + score_mod_bprop: Optional[Callable], default = None + Optional score modification backward callback. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod``. This is a call-time + argument to keep tensor operands as normal JAX inputs. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod_bprop``. """ head_dim: int @@ -1210,6 +1265,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods float32_logits: bool = False window_size: Optional[Tuple[int, int]] = None softmax_type: str = "vanilla" + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None # Deprecated parameters num_heads: Optional[int] = None @@ -1278,6 +1335,8 @@ def __call__( *, decode: bool = False, deterministic: bool = False, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ) -> Array: """ MultiHeadAttention Layer: @@ -1300,6 +1359,10 @@ def __call__( Indicate whether to prepare and use an autoregressive cache. deterministic: bool, default = False Disable dropout layers if set to True. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod``. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to ``score_mod_bprop``. Returns ------- @@ -1631,7 +1694,16 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): transpose_batch_sequence=self.transpose_batch_sequence, window_size=self.window_size, softmax_type=self.softmax_type, - )(*dpa_args, mask, bias, deterministic=deterministic) + score_mod=self.score_mod, + score_mod_bprop=self.score_mod_bprop, + )( + *dpa_args, + mask, + bias, + deterministic=deterministic, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, + ) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES) @@ -1939,6 +2011,17 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention (``'zero sink'`` and ``'learnable sink'``). Only supported for fused attention backend. + score_mod: Optional[Callable], default = None + Experimental cuDNN frontend score modification callback for the self-attention block. + When set, fused attention is required and the same score_mod limitations as + :func:`fused_attn` apply. + score_mod_bprop: Optional[Callable], default = None + Optional score modification backward callback for the self-attention block. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod``. This is a call-time + argument to keep tensor operands as normal JAX inputs. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensors or pass-by-value scalars for ``score_mod_bprop``. Optimization parameters ----------------------- @@ -2005,6 +2088,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods scaled_query_init: bool = True window_size: Optional[Tuple[int, int]] = None softmax_type: str = "vanilla" + score_mod: Optional[Callable] = None + score_mod_bprop: Optional[Callable] = None def __post_init__(self): if self.mha_kernel_init is None: @@ -2029,6 +2114,8 @@ def __call__( deterministic: bool = False, decode: bool = False, max_decode_length: bool = None, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ): """ Transformer Layer: attention block and a feedforward network (MLP) @@ -2057,6 +2144,10 @@ def __call__( The maximum length to generate relative embedding biases when :attr:`layer_type=TransformerLayerType.DECODER` and :attr:`enable_relative_embedding=True`. + score_mod_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to the self-attention ``score_mod``. + score_mod_bprop_tensors: Optional[Mapping[str, Any]], default = None + Additional tensor/scalar operands passed to the self-attention ``score_mod_bprop``. Returns ------- @@ -2168,7 +2259,18 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): name=mha_name, window_size=self.window_size, softmax_type=self.softmax_type, - )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) + score_mod=self.score_mod, + score_mod_bprop=self.score_mod_bprop, + )( + inputs, + inputs, + attention_mask, + attn_bias, + deterministic=deterministic, + decode=decode, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, + ) def hidden_dropout(x, deterministic): assert isinstance(