
[JAX] Support for cuDNN-backed flex attention #2985

Open
vcherepanov-nv wants to merge 10 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-jax

Conversation

@vcherepanov-nv
Collaborator

Description

This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • A new code path for the FusedAttention backend, taken when score_mod (or any of its related parameters) is specified
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR introduces a cuDNN-frontend-backed score_mod code path for JAX's FusedAttention. Users supply Python callbacks that are invoked during JAX tracing to build, serialize, and cache a cuDNN pygraph; the serialized graph is then passed as an FFI attribute to new C++ handlers that deserialize and execute it at runtime.

  • New fused_attn parameters: score_mod, score_mod_bprop, score_mod_tensors, and score_mod_bprop_tensors trigger a separate validation and dispatch path with strict mutual-exclusion checks against masks, bias, dropout, and context parallelism.
  • Python-side graph lifecycle: _FusedAttnScoreModConfig classifies callbacks into cacheable (module-level functions, keyed bound methods) and uncacheable buckets; the Python _score_mod_graph_cache and C++ SHA256-keyed deserialization cache together avoid repeated graph builds for stable shapes and topologies.
  • C++ execution path: Thread-local cuDNN handles, double-checked-locking graph deserialization, and variadic FFI buffers for auxiliary score-mod tensor operands.
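
The content-keyed, double-checked-locking lookup described in the bullets above can be sketched in Python. This is an illustration of the pattern only, not the PR's C++ code: the names `get_or_deserialize` and `_cache_key`, and the exact key layout, are assumptions made for the example.

```python
import hashlib
import threading

_graph_cache = {}                      # key -> deserialized graph object
_graph_cache_lock = threading.Lock()


def _cache_key(device_id, serialized_graph, frontend_version):
    """Content-addressed key: identical serialized bytes map to one entry."""
    digest = hashlib.sha256(serialized_graph).hexdigest()
    return (device_id, digest, frontend_version)


def get_or_deserialize(device_id, serialized_graph, frontend_version, deserialize):
    """Return the cached graph, calling `deserialize` at most once per key."""
    key = _cache_key(device_id, serialized_graph, frontend_version)
    graph = _graph_cache.get(key)      # first check, without the lock
    if graph is not None:
        return graph
    with _graph_cache_lock:
        graph = _graph_cache.get(key)  # second check, under the lock
        if graph is None:
            graph = deserialize(serialized_graph)
            _graph_cache[key] = graph
        return graph


# Usage with a stand-in deserializer that records how often it runs.
calls = []


def fake_deserialize(blob):
    calls.append(blob)
    return {"size": len(blob)}


g1 = get_or_deserialize(0, b"graph-bytes", "x.y.z", fake_deserialize)
g2 = get_or_deserialize(0, b"graph-bytes", "x.y.z", fake_deserialize)
```

The second call returns the same object without deserializing again. Holding the lock during deserialization keeps the sketch simple at the cost of serializing concurrent cache misses.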

Confidence Score: 4/5

The core score_mod dispatch, VJP wiring, C++ deserialization, and graph execution are all correct; the only new finding is a silent performance trap when lambdas or closures are used as score_mod callbacks — no wrong results, no crashes.

The new dispatch path has clean separation of concerns, thorough input validation, and correct JAX custom_vjp plumbing. The C++ double-checked-locking cache and thread-local handles are sound. No data-corruption or execution-correctness bugs were found beyond what is already under discussion in existing threads.

transformer_engine/jax/cpp_extensions/attention.py — the uncacheable-key path in _score_mod_callback_cache_key silently skips graph caching for lambdas and functools.partial without any diagnostic; transformer_engine/jax/csrc/extensions/attention.cpp — minor typo in getScoreModeGraphCache accessor name.
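
One way to surface the silent rebuilds mentioned above is to warn when a callback has no stable key. The sketch below is hypothetical: `callback_cache_key`, `classify`, and the `cache_key` attribute convention are invented for illustration and are not taken from the PR.

```python
import functools
import inspect
import warnings


def callback_cache_key(fn):
    """Return a stable hashable key for `fn`, or None if it is uncacheable."""
    if isinstance(fn, functools.partial):
        return None
    if inspect.isfunction(fn):
        qualname = fn.__qualname__
        if "<lambda>" in qualname or "<locals>" in qualname:
            return None  # lambda or closure: a fresh object on each creation
        return ("function", fn.__module__, qualname)
    if inspect.ismethod(fn):
        # Hypothetical convention: the instance opts in via a `cache_key` attribute.
        key = getattr(fn.__self__, "cache_key", None)
        if key is not None:
            return ("method", type(fn.__self__).__qualname__, fn.__name__, key)
    return None


def classify(fn):
    """Like callback_cache_key, but emits a diagnostic for uncacheable callbacks."""
    key = callback_cache_key(fn)
    if key is None:
        warnings.warn(
            f"score_mod callback {fn!r} has no stable cache key; "
            "the graph will be rebuilt on every trace",
            stacklevel=2,
        )
    return key
```

With this shape, a lambda or a `functools.partial` still works but produces a visible warning, while a module-level function or a keyed bound method is cached quietly.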

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/jax/cpp_extensions/attention.py | New 800-line score_mod path: graph building, caching, FFI dispatch, and callback key classification all look structurally correct; minor concerns around the uncacheable-key path and silent graph rebuilds for common callable types. |
| transformer_engine/jax/csrc/extensions/attention.cpp | Adds C++ FFI handlers for score_mod forward/backward with a SHA256-keyed deserialization cache, double-checked locking, and thread-local cuDNN handles; contains a typo in the cache accessor name (getScoreModeGraphCache instead of getScoreModGraphCache). |
| transformer_engine/jax/attention.py | Adds score_mod dispatch branch in fused_attn with thorough validation, custom_vjp wiring, and correct nondiff_argnums usage; early return prevents any interaction with the existing code path. |
| tests/jax/test_fused_attn.py | New score_mod test suite with unit tests for validation, config splitting, cache key stability, and numerical accuracy; _ScoreModSoftcap stores inter-call state on self, which is safe for single sequential traces but fragile if the same instance is reused across multiple fused_attn calls in one trace. |
| tests/jax/test_distributed_fused_attn.py | Adds distributed score_mod test covering DP/TP sharding and backward correctness; imports and structure are consistent with existing distributed test patterns. |
| build_tools/jax.py | Adds a two-phase search for the cuDNN frontend include directory, appended to include_dirs only when found; gracefully silent when not present. |
| transformer_engine/jax/csrc/extensions/utils.cpp | Adds GetCudnnFrontendVersion() exposing CUDNN_FRONTEND_VERSION macro; straightforward and correct. |

Sequence Diagram

sequenceDiagram
    participant User
    participant fused_attn
    participant make_config as make_fused_attn_score_mod_config
    participant fwd_rule as _fused_attn_score_mod_fwd_rule
    participant py_cache as _score_mod_graph_cache (Python)
    participant build_graph as _build_score_mod_fwd_graph
    participant cudnn as cudnn.pygraph
    participant ffi as ffi.ffi_call
    participant cpp as C++ FFI Handler
    participant cpp_cache as C++ Graph Cache

    User->>fused_attn: "fused_attn(..., score_mod=fn, ...)"
    fused_attn->>make_config: classify callback key, split tensors/scalars
    make_config-->>fused_attn: _FusedAttnScoreModConfig
    fused_attn->>fwd_rule: _fused_attn_score_mod(qkv, tensors, config)

    Note over fwd_rule,py_cache: JAX trace time
    fwd_rule->>py_cache: lookup (direction, config, avals)
    alt cache miss
        py_cache->>build_graph: _build_score_mod_fwd_graph(avals, config)
        build_graph->>cudnn: "pygraph.sdpa(score_mod=wrapped_fn)"
        cudnn-->>build_graph: serialized graph + workspace_size
        build_graph-->>py_cache: _SerializedScoreModGraph
    end
    py_cache-->>fwd_rule: serialized graph

    fwd_rule->>ffi: ffi_call(te_fused_attn_score_mod_forward_ffi, serialized_graph as attr)

    Note over ffi,cpp_cache: Runtime execution
    ffi->>cpp: FusedAttnScoreModForwardFFI(stream, q,k,v, variadic_tensors, attrs)
    cpp->>cpp_cache: lookup by (device_id, hash0, hash1, frontend_version)
    alt C++ cache miss
        cpp->>cpp: deserialize graph via cudnn_frontend
        cpp->>cpp_cache: store shared_ptr Graph
    end
    cpp->>cpp: "graph->execute(handle, variant_pack, workspace)"
    cpp-->>ffi: output, stats

    fwd_rule-->>User: output (+ residuals for VJP)

Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +706 to +713
struct ScoreModGraphEntry {
  PyObject *py_graph = nullptr;
  std::vector<int64_t> user_uids;
  std::vector<int64_t> input_uids;
  std::vector<int64_t> output_uids;
  std::vector<int64_t> scalar_uids;
  std::vector<ScoreModScalarStorage> scalar_values;
};
Contributor

P1 Python reference leak: Py_INCREF without a matching Py_DECREF

ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.

Collaborator

@vcherepanov-nv This seems like a valid comment from greptile about leaking pygraphs. But I'm also not sure if that is the intended design to prevent GC freeing up a graph too early by mistake before we use it in the XLA C++ FFI. I'm not sure what the best option is here

Collaborator Author

Not sure if I'd call it a leak, but yes, currently the cache is process-lifetime. If we ever encounter an issue with its growth, then we'll need to implement some kind of eviction policy. But that is out of scope for this PR.

Comment on lines +684 to +692
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)

q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape)
k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape)
v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape)
o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape)
do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape)
Contributor

P2 id()-based cache keys can produce false cache hits after GC

_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
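
The weakref-anchored key suggested above might look like the following. This is a minimal sketch under the assumption that the callback object supports weak references; `CallbackRef` is a hypothetical name and not part of the PR.

```python
import weakref


class CallbackRef:
    """Cache key that only equals keys referring to the same *live* object."""

    def __init__(self, obj):
        self._ref = weakref.ref(obj)
        self._id = id(obj)  # used for hashing only, never for equality

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        if not isinstance(other, CallbackRef):
            return NotImplemented
        mine, theirs = self._ref(), other._ref()
        # A dead referent never compares equal, so a recycled address cannot match.
        return mine is not None and mine is theirs


class ScoreMod:
    """Stand-in for a user callback object."""


first = ScoreMod()
cache = {CallbackRef(first): "graph built for first"}
hit = cache.get(CallbackRef(first))

del first            # the original callback is gone; its address may be reused
second = ScoreMod()
miss = cache.get(CallbackRef(second))
```

Here `hit` is the cached entry and `miss` is `None`, even if `second` happens to land at the address `first` occupied. Two caveats: entries whose referent has died stay in the dict until removed, and a bound method would need `weakref.WeakMethod` because a plain `weakref.ref` to a bound method dies immediately.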

Comment on lines +765 to +807
Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id,
                                const std::vector<void *> &input_ptrs,
                                const std::vector<void *> &output_ptrs, void *workspace) {
  auto entry = GetScoreModGraphEntry(graph_id);
  NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ",
             entry->input_uids.size(), " inputs but got ", input_ptrs.size());
  NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(),
             "cuDNN score_mod graph expected at least ", entry->output_uids.size(),
             " outputs but got ", output_ptrs.size());

  std::unordered_map<int64_t, void *> variant_pack;
  for (size_t i = 0; i < entry->input_uids.size(); ++i) {
    variant_pack.emplace(entry->input_uids[i], input_ptrs[i]);
  }
  for (size_t i = 0; i < entry->output_uids.size(); ++i) {
    variant_pack.emplace(entry->output_uids[i], output_ptrs[i]);
  }
  for (size_t i = 0; i < entry->scalar_uids.size(); ++i) {
    variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data());
  }

  std::vector<std::intptr_t> user_ptrs;
  user_ptrs.reserve(entry->user_uids.size());
  for (const auto uid : entry->user_uids) {
    auto it = variant_pack.find(uid);
    NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid);
    user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second));
  }

  auto handle = GetScoreModCudnnHandle();
  NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
  {
    pybind11::gil_scoped_acquire gil;
    try {
      auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph);
      graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace),
                                       reinterpret_cast<std::intptr_t>(handle));
    } catch (const pybind11::error_already_set &exc) {
      NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what());
    }
  }
  return ffi_with_cuda_error_check();
}
Contributor

P2 GIL held across a CUDA FFI call boundary

ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.

_SCORE_MOD_UID_DQ = 7
_SCORE_MOD_UID_DK = 8
_SCORE_MOD_UID_DV = 9
_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000
Contributor

P2 _score_mod_graph_cache and C++ registry grow without bound

_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
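
A bounded LRU cache of the kind suggested above could be sketched as follows. `BoundedGraphCache` and its `on_evict` hook are hypothetical names; the sketch does not address when it is safe to release a graph that a compiled executable may still reference, which is the concern raised elsewhere in this thread.

```python
from collections import OrderedDict


class BoundedGraphCache:
    """LRU cache: holds at most `max_entries` graphs, evicting the least recently used."""

    def __init__(self, max_entries, on_evict=None):
        if max_entries < 1:
            raise ValueError("max_entries must be at least 1")
        self._max_entries = max_entries
        self._on_evict = on_evict          # e.g. a hook that releases the native graph
        self._entries = OrderedDict()

    def get_or_build(self, key, build):
        if key in self._entries:
            self._entries.move_to_end(key)  # mark as most recently used
            return self._entries[key]
        value = build()
        self._entries[key] = value
        if len(self._entries) > self._max_entries:
            old_key, old_value = self._entries.popitem(last=False)
            if self._on_evict is not None:
                self._on_evict(old_key, old_value)
        return value

    def __len__(self):
        return len(self._entries)


# Usage: a capacity of two, with evictions recorded for inspection.
evicted = []
cache = BoundedGraphCache(2, on_evict=lambda k, v: evicted.append((k, v)))
cache.get_or_build("shape-a", lambda: 1)
cache.get_or_build("shape-b", lambda: 2)
cache.get_or_build("shape-a", lambda: 1)   # refreshes "shape-a"
cache.get_or_build("shape-c", lambda: 3)   # evicts "shape-b"
```

After these calls the cache holds "shape-a" and "shape-c", and `evicted` contains the single entry for "shape-b".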

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +143 to +195

    def forward(self, graph, score, tensors):
        import cudnn  # pylint: disable=import-outside-toplevel

        self.before_tanh_activation = graph.div(
            a=score,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )
        self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
        tanh_out = graph.tanh(input=self.before_tanh_activation)
        tanh_out.set_data_type(cudnn.data_type.FLOAT)
        return graph.mul(
            a=tanh_out,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )

    def backward(self, graph, dscore, tensors):
        import cudnn  # pylint: disable=import-outside-toplevel

        d_tanh_out = graph.mul(
            a=dscore,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )
        d_tanh_out.set_data_type(cudnn.data_type.FLOAT)
        d_before_tanh_activation = graph.tanh_backward(
            loss=d_tanh_out,
            input=self.before_tanh_activation,
            compute_data_type=cudnn.data_type.FLOAT,
        )
        d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
        return graph.div(
            a=d_before_tanh_activation,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )


def _reference_attention(
    query, key, value, scale, *, causal=False, relative_position=False, softcap=None
):
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
    if causal:
        q_pos = jnp.arange(query.shape[1])[:, None]
        kv_pos = jnp.arange(key.shape[1])[None, :]
        scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
    if relative_position:
        q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
        kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
        scores = scores + q_pos - kv_pos
    if softcap is not None:
Contributor

P1 _ScoreModSoftcap.backward relies on undocumented cuDNN callback ordering

backward reads self.before_tanh_activation, which is written by forward during sdpa_backward graph construction. This is only safe if cuDNN's sdpa_backward guarantees it calls score_mod (the forward callback) before score_mod_bprop (the backward callback) within the same graph-build invocation. If that order is ever reversed, self.before_tanh_activation is None at the time backward runs, and graph.tanh_backward(input=None, ...) will fail silently or crash at execution time rather than at graph-build time.
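
To see why backward needs the pre-tanh value saved by forward, the softcap math can be checked in scalar form. The sketch below mirrors the order of operations in the quoted callbacks using plain floats; it assumes that `tanh_backward` computes `loss * (1 - tanh(input)**2)`, which is an assumption about the cuDNN operation and is not confirmed in this thread.

```python
import math


def softcap_forward(score, cap):
    """cap * tanh(score / cap), as in the quoted forward callback."""
    return cap * math.tanh(score / cap)


def softcap_backward(dscore, score, cap):
    """Gradient with respect to `score`, following the quoted backward callback."""
    d_tanh_out = dscore * cap                    # graph.mul(dscore, softcap)
    before_tanh = score / cap                    # the value forward stored on self
    d_before_tanh = d_tanh_out * (1.0 - math.tanh(before_tanh) ** 2)  # assumed tanh_backward
    return d_before_tanh / cap                   # graph.div(..., softcap)


# Compare against a central finite difference at one sample point.
score, cap, eps = 1.3, 2.0, 1e-6
numeric = (softcap_forward(score + eps, cap) - softcap_forward(score - eps, cap)) / (2 * eps)
analytic = softcap_backward(1.0, score, cap)
```

The two factors of `cap` cancel, so the gradient reduces to `dscore * (1 - tanh(score / cap)**2)`. That expression depends on `score / cap`, which is exactly the intermediate the backward callback reads from `self`.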

vcherepanov-nv and others added 2 commits May 15, 2026 03:35
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
_SCORE_MOD_UID_K = 2
_SCORE_MOD_UID_V = 3
_SCORE_MOD_UID_O = 4
_SCORE_MOD_UID_STATS = 5
Collaborator

Where do these _SCORE_MOD_UID_XXXXX come from? Is it a C/C++ enum? If so, we should make this a Python Enum that derives its values from the C/C++ enum exposed via pybind

See this enum for reference:

NO_SCALING = JAXX_Scaling_Mode.NO_SCALING

Collaborator Author

These are just arbitrary numbers, really. In fact, assigning UIDs is completely optional, cuDNN can auto-assign. UIDs are added here just for determinism / to make future troubleshooting easier, e.g. so that we know that 4 is the output tensor.
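
If the UIDs stay Python-only, grouping them in an `IntEnum` would still give the troubleshooting benefit described above. The sketch below uses only the values quoted in this thread; UIDs not shown here are omitted rather than guessed, and the `describe_uid` helper and the "base plus index" reading of `_SCORE_MOD_FWD_TENSOR_UID_BASE` are assumptions for illustration.

```python
from enum import IntEnum


class ScoreModUid(IntEnum):
    """Fixed tensor UIDs quoted in this thread (other UIDs are not shown and omitted)."""

    K = 2
    V = 3
    O = 4
    STATS = 5
    DQ = 7
    DK = 8
    DV = 9


FWD_TENSOR_UID_BASE = 1000  # quoted as _SCORE_MOD_FWD_TENSOR_UID_BASE


def describe_uid(uid):
    """Turn a raw UID from an error message into a readable name."""
    try:
        return ScoreModUid(uid).name
    except ValueError:
        # Assumption: auxiliary forward tensors are numbered from the base upward.
        if uid >= FWD_TENSOR_UID_BASE:
            return f"FWD_TENSOR[{uid - FWD_TENSOR_UID_BASE}]"
        return f"UNKNOWN({uid})"
```

For example, `describe_uid(4)` yields `"O"`, which matches the author's point that a fixed UID makes it obvious that 4 is the output tensor.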

Comment thread transformer_engine/jax/cpp_extensions/attention.py
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/attention.py
vcherepanov-nv and others added 2 commits May 18, 2026 23:54
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
status.get_message());

std::lock_guard<std::mutex> lock(ScoreModGraphCacheMutex());
auto &cache = ScoreModGraphCache();
Collaborator

SR: Can we name ScoreModGraphCache() something like getScoreModeGraphCache()? On my first read-through I read ScoreModeGraphCache as constructing a new object and thought this was always using a fresh cache.

NVTE_CHECK(status.is_good(), "Failed to deserialize cuDNN score_mod SDPA graph: ",
status.get_message());

std::lock_guard<std::mutex> lock(ScoreModGraphCacheMutex());
Collaborator

Same here about ScoreModGraphCacheMutex() -> getScoreModGraphCacheMutex()

vcherepanov-nv and others added 2 commits May 19, 2026 00:32
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
2 participants