-
Notifications
You must be signed in to change notification settings - Fork 723
[JAX] Support for cuDNN-backed flex attention #2985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f967a26
6b05328
1a96352
3bf9e97
29bbac7
c597af5
f8bd844
2c01c5e
ba6a1a7
deebf8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,14 +6,25 @@ | |
| import pytest | ||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| from jax import random | ||
| from jax.sharding import Mesh, NamedSharding, PartitionSpec | ||
| from distributed_test_base import ( | ||
| generate_configs, | ||
| generate_context_parallel_configs_for_attn, | ||
| generate_collectives_count, | ||
| ) | ||
| from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat | ||
| from utils import pytest_parametrize_wrapper | ||
| from test_fused_attn import ( | ||
| FusedAttnRunner, | ||
| BiasShape, | ||
| SeqDescFormat, | ||
| _ScoreModSoftcap, | ||
| _has_cudnn_frontend_python, | ||
| _reference_attention, | ||
| _require_cudnn_frontend_score_mod, | ||
| ) | ||
| from utils import assert_allclose, pytest_parametrize_wrapper | ||
| from transformer_engine.jax import autocast | ||
| from transformer_engine.jax.attention import ( | ||
| is_fused_attn_kernel_available, | ||
| AttnBiasType, | ||
|
|
@@ -25,6 +36,7 @@ | |
| inverse_reorder_causal_load_balancing, | ||
| CPStrategy, | ||
| ReorderStrategy, | ||
| fused_attn, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -272,6 +284,114 @@ def test_cross_attn( | |
| runner.test_backward() | ||
|
|
||
|
|
||
| DISTRIBUTED_SCORE_MOD_DATA_SHAPES = { | ||
| "L0": [], | ||
| "L1": [(4, 16, 4, 64)], | ||
| "L2": [(4, 16, 4, 64)], | ||
| } | ||
|
Comment on lines
+289
to
+291
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we running tests with the same shape? |
||
|
|
||
|
|
||
| @pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") | ||
| class TestDistributedScoreModSelfAttn: | ||
| @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) | ||
| @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES) | ||
| @pytest.mark.parametrize("dtype", DTYPES) | ||
| def test_softcap_score_mod_with_aux_params_backward( | ||
| self, | ||
| device_count, | ||
| mesh_shape, | ||
| mesh_axes, | ||
| mesh_resource, | ||
| data_shape, | ||
| dtype, | ||
| ): | ||
| _require_cudnn_frontend_score_mod() | ||
| batch, seqlen, num_heads, head_dim = data_shape | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There seems to be a lot of input setup here. All tests in attention end up using it, so it would make sense to maintain uniformity if possible. Additionally, all tests end up using the FusedAttnRunner() - are you not able to use it for the flex tests with minor changes? Using the runner also means that you do not need to import fused_attn in the distributed tests. The above two might help reduce duplication of code and maintain uniformity in the test inputs generated. Moreover, we would not want to add setup and teardown for any new attn types we add, especially when we are using the same fused_attn API for all. |
||
| dp_axis = mesh_resource.dp_resource | ||
| tp_axis = mesh_resource.tpsp_resource | ||
|
|
||
| if dp_axis is not None: | ||
| dp_size = mesh_shape[mesh_axes.index(dp_axis)] | ||
| if batch % dp_size != 0: | ||
| pytest.skip(f"{batch=} must be divisible by {dp_size=}") | ||
| if tp_axis is not None: | ||
| tp_size = mesh_shape[mesh_axes.index(tp_axis)] | ||
| if num_heads % tp_size != 0: | ||
| pytest.skip(f"{num_heads=} must be divisible by {tp_size=}") | ||
|
|
||
| devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) | ||
| mesh = Mesh(devices, mesh_axes) | ||
| qkv_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, tp_axis, None)) | ||
|
|
||
| key = random.PRNGKey(2025) | ||
| q_key, k_key, v_key, dout_key = random.split(key, 4) | ||
| query = (0.125 * random.normal(q_key, data_shape, dtype=dtype)).astype(dtype) | ||
| key_tensor = (0.125 * random.normal(k_key, data_shape, dtype=dtype)).astype(dtype) | ||
| value = (0.125 * random.normal(v_key, data_shape, dtype=dtype)).astype(dtype) | ||
| doutput = random.normal(dout_key, data_shape, dtype=dtype) | ||
|
|
||
| scaling_factor = head_dim**-0.5 | ||
| softcap = 0.8 | ||
| softcap_score_mod = _ScoreModSoftcap() | ||
|
|
||
| def score_mod_loss(q, k, v, dout): | ||
| out = fused_attn( | ||
| (q, k, v), | ||
| None, | ||
| None, | ||
| None, | ||
| AttnBiasType.NO_BIAS, | ||
| AttnMaskType.NO_MASK, | ||
| QKVLayout.BSHD_BSHD_BSHD, | ||
| AttnSoftmaxType.VANILLA_SOFTMAX, | ||
| scaling_factor, | ||
| 0.0, | ||
| True, | ||
| score_mod=softcap_score_mod.forward, | ||
| score_mod_bprop=softcap_score_mod.backward, | ||
| score_mod_tensors={"softcap": softcap}, | ||
| score_mod_bprop_tensors={"softcap": softcap}, | ||
| ) | ||
| loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) | ||
| return loss, out | ||
|
|
||
| def ref_loss(q, k, v, dout): | ||
| out = _reference_attention(q, k, v, scaling_factor, softcap=softcap) | ||
| loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) | ||
| return loss, out | ||
|
|
||
| jitted_score_mod = jax.jit( | ||
| jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True), | ||
| in_shardings=( | ||
| qkv_sharding, | ||
| qkv_sharding, | ||
| qkv_sharding, | ||
| qkv_sharding, | ||
| ), | ||
| out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)), | ||
| ) | ||
| jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)) | ||
|
|
||
| sharded_args = ( | ||
| jax.device_put(query, qkv_sharding), | ||
| jax.device_put(key_tensor, qkv_sharding), | ||
| jax.device_put(value, qkv_sharding), | ||
| jax.device_put(doutput, qkv_sharding), | ||
| ) | ||
| with mesh, autocast(mesh_resource=mesh_resource): | ||
| (score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args) | ||
| (ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput) | ||
|
|
||
| assert score_mod_out.sharding == qkv_sharding | ||
| for grad in score_mod_grads: | ||
| assert grad.sharding == qkv_sharding | ||
|
|
||
| assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) | ||
| assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) | ||
| for grad, ref_grad in zip(score_mod_grads, ref_grads): | ||
| assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) | ||
|
|
||
|
|
||
| DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [ | ||
| pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), | ||
| pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see below comment
Would be good to avoid this import here if possible
The CP tests also do not end up importing it but use it via the runner customcall_fused_dpa
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A lot of the other imports, like assert_allclose, can also be avoided by using the FusedAttnRunner and setup_inputs in the fused_attn tests