Skip to content
Open
14 changes: 14 additions & 0 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ def setup_jax_extension(

# Header files
include_dirs = get_cuda_include_dirs()
cudnn_frontend_include_dir = None
for base_path in (Path(common_header_files), *Path(common_header_files).parents):
candidate = base_path / "3rdparty" / "cudnn-frontend" / "include"
if candidate.exists():
cudnn_frontend_include_dir = candidate
break
if cudnn_frontend_include_dir is None:
for base_path in Path(__file__).resolve().parents:
candidate = base_path / "3rdparty" / "cudnn-frontend" / "include"
if candidate.exists():
cudnn_frontend_include_dir = candidate
break
if cudnn_frontend_include_dir is not None:
include_dirs.append(cudnn_frontend_include_dir)
include_dirs.extend(
[
common_header_files,
Expand Down
124 changes: 122 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs_for_attn,
generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper
from test_fused_attn import (
FusedAttnRunner,
BiasShape,
SeqDescFormat,
_ScoreModSoftcap,
_has_cudnn_frontend_python,
_reference_attention,
_require_cudnn_frontend_score_mod,
)
from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.jax import autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
Expand All @@ -25,6 +36,7 @@
inverse_reorder_causal_load_balancing,
CPStrategy,
ReorderStrategy,
fused_attn,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see the comment below.
It would be good to avoid this import here if possible.
The CP tests also do not end up importing it, but use it via the runner customcall_fused_dpa.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of the other imports, like assert_allclose, can also be avoided by using the FusedAttnRunner and setup_inputs in the fused_attn tests.

)


Expand Down Expand Up @@ -272,6 +284,114 @@ def test_cross_attn(
runner.test_backward()


# Input shapes for the distributed score_mod self-attention test, keyed by
# test level. Each shape is unpacked by the test as
# (batch, seqlen, num_heads, head_dim).
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = dict(
    L0=[],
    L1=[(4, 16, 4, 64)],
    L2=[(4, 16, 4, 64)],
)
Comment on lines +289 to +291
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we running the tests with the same shape at both levels?
If only one shape is needed, L1 alone should suffice - there is no need for L2 then.



@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required")
class TestDistributedScoreModSelfAttn:
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_softcap_score_mod_with_aux_params_backward(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
):
_require_cudnn_frontend_score_mod()
batch, seqlen, num_heads, head_dim = data_shape
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be a lot of input setup here.
Can we not use _setup_inputs() with minor modifications?
You could also use _setup_inputs() and "undo" some of the setup in the test if you do not wish to modify _setup_inputs() - that is fine too.

All tests in attention end up using it so would make sense to maintain uniformity if possible.
However, if it involves too much branching and changes for flex attn, I understand keeping it separate.

Additionally, all tests end up using the FusedAttnRunner() - are you not able to use it for the flex tests with minor changes? Using the runner also means that you do not need to import fused_attn in the distributed tests.
The current setup only imports fused_attn in the test_fused_attn.py tests, and the distributed tests set it up and use it directly from there.

The above two changes might help reduce code duplication and maintain uniformity in the test inputs generated. Moreover, we would not want to add separate setup and teardown for every new attention type we add, especially when we are using the same fused_attn API for all of them.

dp_axis = mesh_resource.dp_resource
tp_axis = mesh_resource.tpsp_resource

if dp_axis is not None:
dp_size = mesh_shape[mesh_axes.index(dp_axis)]
if batch % dp_size != 0:
pytest.skip(f"{batch=} must be divisible by {dp_size=}")
if tp_axis is not None:
tp_size = mesh_shape[mesh_axes.index(tp_axis)]
if num_heads % tp_size != 0:
pytest.skip(f"{num_heads=} must be divisible by {tp_size=}")

devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
qkv_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, tp_axis, None))

key = random.PRNGKey(2025)
q_key, k_key, v_key, dout_key = random.split(key, 4)
query = (0.125 * random.normal(q_key, data_shape, dtype=dtype)).astype(dtype)
key_tensor = (0.125 * random.normal(k_key, data_shape, dtype=dtype)).astype(dtype)
value = (0.125 * random.normal(v_key, data_shape, dtype=dtype)).astype(dtype)
doutput = random.normal(dout_key, data_shape, dtype=dtype)

scaling_factor = head_dim**-0.5
softcap = 0.8
softcap_score_mod = _ScoreModSoftcap()

def score_mod_loss(q, k, v, dout):
    """Run fused attention with the softcap score_mod and reduce it to a scalar.

    Returns ``(loss, out)``: ``loss`` is the float32 sum of ``out * dout`` and
    ``out`` is the raw attention output, returned alongside so the caller can
    treat it as the auxiliary value of ``jax.value_and_grad(..., has_aux=True)``.
    """
    # Separate dict objects are passed for the forward and bprop tensors,
    # both carrying the same softcap value.
    score_mod_kwargs = dict(
        score_mod=softcap_score_mod.forward,
        score_mod_bprop=softcap_score_mod.backward,
        score_mod_tensors={"softcap": softcap},
        score_mod_bprop_tensors={"softcap": softcap},
    )
    # NOTE(review): the three None arguments and the trailing 0.0 / True are
    # passed positionally; their meaning depends on the fused_attn signature,
    # which is not visible here -- confirm before reordering.
    attn_out = fused_attn(
        (q, k, v),
        None,
        None,
        None,
        AttnBiasType.NO_BIAS,
        AttnMaskType.NO_MASK,
        QKVLayout.BSHD_BSHD_BSHD,
        AttnSoftmaxType.VANILLA_SOFTMAX,
        scaling_factor,
        0.0,
        True,
        **score_mod_kwargs,
    )
    # Accumulate in float32 regardless of the input dtype.
    weighted = attn_out.astype(jnp.float32) * dout.astype(jnp.float32)
    return jnp.sum(weighted), attn_out

def ref_loss(q, k, v, dout):
    """Compute the same scalar loss using the reference attention path.

    Returns ``(loss, out)`` where ``out`` comes from ``_reference_attention``
    with the same scaling factor and softcap, and ``loss`` is the float32 sum
    of ``out * dout``.
    """
    ref_out = _reference_attention(q, k, v, scaling_factor, softcap=softcap)
    # Accumulate in float32 regardless of the input dtype.
    weighted = ref_out.astype(jnp.float32) * dout.astype(jnp.float32)
    return jnp.sum(weighted), ref_out

jitted_score_mod = jax.jit(
jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True),
in_shardings=(
qkv_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
),
out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)),
)
jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True))

sharded_args = (
jax.device_put(query, qkv_sharding),
jax.device_put(key_tensor, qkv_sharding),
jax.device_put(value, qkv_sharding),
jax.device_put(doutput, qkv_sharding),
)
with mesh, autocast(mesh_resource=mesh_resource):
(score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args)
(ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput)

assert score_mod_out.sharding == qkv_sharding
for grad in score_mod_grads:
assert grad.sharding == qkv_sharding

assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2)
assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2)
for grad, ref_grad in zip(score_mod_grads, ref_grads):
assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2)


DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
Expand Down
Loading
Loading