[JAX] Support for cuDNN-backed flex attention#2985
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary
This PR introduces a cuDNN-frontend-backed
Confidence Score: 4/5
The core score_mod dispatch, VJP wiring, C++ deserialization, and graph execution are all correct; the only new finding is a silent performance trap when lambdas or closures are used as score_mod callbacks (no wrong results, no crashes).
The new dispatch path has clean separation of concerns, thorough input validation, and correct JAX custom_vjp plumbing. The C++ double-checked-locking cache and thread-local handles are sound. No data-corruption or execution-correctness bugs were found beyond what is already under discussion in existing threads.
transformer_engine/jax/cpp_extensions/attention.py: the uncacheable-key path in _score_mod_callback_cache_key silently skips graph caching for lambdas and functools.partial without any diagnostic.
transformer_engine/jax/csrc/extensions/attention.cpp: minor typo in getScoreModeGraphCache accessor name.
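One way to make the uncacheable-callback case visible is to emit a warning when no stable key can be derived. The following is a minimal sketch under stated assumptions: `callback_cache_key` is a hypothetical helper, not the PR's `_score_mod_callback_cache_key`, and treating every closure as uncacheable is an assumption about the intended policy.

```python
import functools
import warnings


def callback_cache_key(callback):
    """Return a hashable key for `callback`, or None if it cannot be cached.

    Hypothetical helper for illustration; the PR's real key function may
    classify callbacks differently.
    """
    is_partial = isinstance(callback, functools.partial)
    is_lambda = getattr(callback, "__name__", "") == "<lambda>"
    is_closure = getattr(callback, "__closure__", None) is not None
    if is_partial or is_lambda or is_closure:
        # Tell the user why graph caching is skipped instead of failing silently.
        warnings.warn(
            f"score_mod callback {callback!r} has no stable cache key; "
            "graph caching is skipped for it. "
            "A module-level function would allow caching.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None
    # Module-level functions are identified by where they are defined.
    return (callback.__module__, callback.__qualname__)
```

A warning keeps the current behavior (results stay correct) while pointing at the cause of repeated graph builds.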
Sequence Diagram
sequenceDiagram
participant User
participant fused_attn
participant make_config as make_fused_attn_score_mod_config
participant fwd_rule as _fused_attn_score_mod_fwd_rule
participant py_cache as _score_mod_graph_cache (Python)
participant build_graph as _build_score_mod_fwd_graph
participant cudnn as cudnn.pygraph
participant ffi as ffi.ffi_call
participant cpp as C++ FFI Handler
participant cpp_cache as C++ Graph Cache
User->>fused_attn: "fused_attn(..., score_mod=fn, ...)"
fused_attn->>make_config: classify callback key, split tensors/scalars
make_config-->>fused_attn: _FusedAttnScoreModConfig
fused_attn->>fwd_rule: _fused_attn_score_mod(qkv, tensors, config)
Note over fwd_rule,py_cache: JAX trace time
fwd_rule->>py_cache: lookup (direction, config, avals)
alt cache miss
py_cache->>build_graph: _build_score_mod_fwd_graph(avals, config)
build_graph->>cudnn: "pygraph.sdpa(score_mod=wrapped_fn)"
cudnn-->>build_graph: serialized graph + workspace_size
build_graph-->>py_cache: _SerializedScoreModGraph
end
py_cache-->>fwd_rule: serialized graph
fwd_rule->>ffi: ffi_call(te_fused_attn_score_mod_forward_ffi, serialized_graph as attr)
Note over ffi,cpp_cache: Runtime execution
ffi->>cpp: FusedAttnScoreModForwardFFI(stream, q,k,v, variadic_tensors, attrs)
cpp->>cpp_cache: lookup by (device_id, hash0, hash1, frontend_version)
alt C++ cache miss
cpp->>cpp: deserialize graph via cudnn_frontend
cpp->>cpp_cache: store shared_ptr Graph
end
cpp->>cpp: "graph->execute(handle, variant_pack, workspace)"
cpp-->>ffi: output, stats
fwd_rule-->>User: output (+ residuals for VJP)
struct ScoreModGraphEntry {
  PyObject *py_graph = nullptr;
  std::vector<int64_t> user_uids;
  std::vector<int64_t> input_uids;
  std::vector<int64_t> output_uids;
  std::vector<int64_t> scalar_uids;
  std::vector<ScoreModScalarStorage> scalar_values;
};
Python reference leak:
Py_INCREF without a matching Py_DECREF
ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.
@vcherepanov-nv This seems like a valid comment from greptile about leaking pygraphs. But I'm also not sure if that is the intended design to prevent GC freeing up a graph too early by mistake before we use it in the XLA C++ FFI. I'm not sure what the best option is here
Not sure if I'd call it a leak, but yes, currently the cache is process-lifetime. If we ever encounter an issue with its growth, then we'll need to implement some kind of eviction policy. But that is out of scope for this PR.
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape)
    k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape)
    v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape)
    o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape)
    do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape)
id()-based cache keys can produce false cache hits after GC
_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
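The weakref-plus-id idea mentioned above can be sketched as follows. This is an illustration only: `CallbackKey` is a hypothetical class, not part of the PR, and it assumes the callback object supports weak references (plain functions and ordinary class instances do; a bound method would need `weakref.WeakMethod` instead).

```python
import weakref


class CallbackKey:
    """Cache key that only matches while the original callback is alive.

    Hypothetical sketch; names are not taken from the PR.
    """

    def __init__(self, callback):
        self._ref = weakref.ref(callback)
        self._id = id(callback)

    def __hash__(self):
        # The id is still a fine hash: collisions are resolved by __eq__.
        return hash(self._id)

    def __eq__(self, other):
        if not isinstance(other, CallbackKey):
            return NotImplemented
        mine, theirs = self._ref(), other._ref()
        # A dead referent never matches, even if its address was recycled
        # by a newly allocated object.
        return mine is not None and mine is theirs
```

With this key, a recycled address can at worst cause a hash collision, never a false cache hit. Entries whose referent has died stay in the cache until something removes them, so this addresses correctness, not growth.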
Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id,
                                const std::vector<void *> &input_ptrs,
                                const std::vector<void *> &output_ptrs, void *workspace) {
  auto entry = GetScoreModGraphEntry(graph_id);
  NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ",
             entry->input_uids.size(), " inputs but got ", input_ptrs.size());
  NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(),
             "cuDNN score_mod graph expected at least ", entry->output_uids.size(),
             " outputs but got ", output_ptrs.size());

  std::unordered_map<int64_t, void *> variant_pack;
  for (size_t i = 0; i < entry->input_uids.size(); ++i) {
    variant_pack.emplace(entry->input_uids[i], input_ptrs[i]);
  }
  for (size_t i = 0; i < entry->output_uids.size(); ++i) {
    variant_pack.emplace(entry->output_uids[i], output_ptrs[i]);
  }
  for (size_t i = 0; i < entry->scalar_uids.size(); ++i) {
    variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data());
  }

  std::vector<std::intptr_t> user_ptrs;
  user_ptrs.reserve(entry->user_uids.size());
  for (const auto uid : entry->user_uids) {
    auto it = variant_pack.find(uid);
    NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid);
    user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second));
  }

  auto handle = GetScoreModCudnnHandle();
  NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
  {
    pybind11::gil_scoped_acquire gil;
    try {
      auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph);
      graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace),
                                       reinterpret_cast<std::intptr_t>(handle));
    } catch (const pybind11::error_already_set &exc) {
      NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what());
    }
  }
  return ffi_with_cuda_error_check();
}
GIL held across a CUDA FFI call boundary
ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.
_SCORE_MOD_UID_DQ = 7
_SCORE_MOD_UID_DK = 8
_SCORE_MOD_UID_DV = 9
_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000
_score_mod_graph_cache and C++ registry grow without bound
_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
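The LRU strategy suggested above could look roughly like this. It is a sketch, not the PR's implementation: `LRUGraphCache`, `get_or_build`, and the `on_evict` hook are hypothetical names, and the hook stands in for whatever call would release the matching entry in the C++ registry.

```python
from collections import OrderedDict


class LRUGraphCache:
    """Bounded mapping from a graph key to a built graph (hypothetical sketch)."""

    def __init__(self, max_entries, on_evict=None):
        if max_entries < 1:
            raise ValueError("max_entries must be positive")
        self._max_entries = max_entries
        # Optional hook, e.g. to release the corresponding C++ registry entry.
        self._on_evict = on_evict
        self._entries = OrderedDict()

    def get_or_build(self, key, build):
        """Return the cached value for `key`, calling `build()` on a miss."""
        if key in self._entries:
            self._entries.move_to_end(key)  # mark as most recently used
            return self._entries[key]
        value = build()
        self._entries[key] = value
        # Drop least recently used entries until the bound holds again.
        while len(self._entries) > self._max_entries:
            old_key, old_value = self._entries.popitem(last=False)
            if self._on_evict is not None:
                self._on_evict(old_key, old_value)
        return value

    def __len__(self):
        return len(self._entries)
```

Any real eviction scheme would also have to ensure that no already-compiled executable still refers to an evicted graph, which this sketch does not attempt.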
    def forward(self, graph, score, tensors):
        import cudnn  # pylint: disable=import-outside-toplevel

        self.before_tanh_activation = graph.div(
            a=score,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )
        self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
        tanh_out = graph.tanh(input=self.before_tanh_activation)
        tanh_out.set_data_type(cudnn.data_type.FLOAT)
        return graph.mul(
            a=tanh_out,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )

    def backward(self, graph, dscore, tensors):
        import cudnn  # pylint: disable=import-outside-toplevel

        d_tanh_out = graph.mul(
            a=dscore,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )
        d_tanh_out.set_data_type(cudnn.data_type.FLOAT)
        d_before_tanh_activation = graph.tanh_backward(
            loss=d_tanh_out,
            input=self.before_tanh_activation,
            compute_data_type=cudnn.data_type.FLOAT,
        )
        d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
        return graph.div(
            a=d_before_tanh_activation,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )


def _reference_attention(
    query, key, value, scale, *, causal=False, relative_position=False, softcap=None
):
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
    if causal:
        q_pos = jnp.arange(query.shape[1])[:, None]
        kv_pos = jnp.arange(key.shape[1])[None, :]
        scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
    if relative_position:
        q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
        kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
        scores = scores + q_pos - kv_pos
    if softcap is not None:
_ScoreModSoftcap.backward relies on undocumented cuDNN callback ordering
backward reads self.before_tanh_activation, which is written by forward during sdpa_backward graph construction. This is only safe if cuDNN's sdpa_backward guarantees it calls score_mod (the forward callback) before score_mod_bprop (the backward callback) within the same graph-build invocation. If that order is ever reversed, self.before_tanh_activation is None at the time backward runs, and graph.tanh_backward(input=None, ...) will fail silently or crash at execution time rather than at graph-build time.
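An explicit check in `backward` would turn the ordering assumption into a graph-build-time error. The sketch below shows the idea with a stand-in graph that works on Python floats; `FloatGraph` and `SoftcapScoreMod` are hypothetical names, not `cudnn.pygraph` or the PR's `_ScoreModSoftcap`, and the stand-in assumes `tanh_backward(loss, input)` computes `loss * (1 - tanh(input)**2)`.

```python
import math


class FloatGraph:
    """Stand-in for a graph builder, operating eagerly on floats (assumption)."""

    def div(self, a, b):
        return a / b

    def mul(self, a, b):
        return a * b

    def tanh(self, input):
        return math.tanh(input)

    def tanh_backward(self, loss, input):
        # Assumed semantics: gradient of tanh evaluated at the forward input.
        return loss * (1.0 - math.tanh(input) ** 2)


class SoftcapScoreMod:
    """Softcap callback pair with an explicit ordering check (hypothetical)."""

    def __init__(self):
        self.before_tanh_activation = None

    def forward(self, graph, score, tensors):
        self.before_tanh_activation = graph.div(a=score, b=tensors["softcap"])
        tanh_out = graph.tanh(input=self.before_tanh_activation)
        return graph.mul(a=tanh_out, b=tensors["softcap"])

    def backward(self, graph, dscore, tensors):
        if self.before_tanh_activation is None:
            # Fail while building the graph, not later at execution time.
            raise RuntimeError("backward() was called before forward()")
        d_tanh_out = graph.mul(a=dscore, b=tensors["softcap"])
        d_in = graph.tanh_backward(loss=d_tanh_out, input=self.before_tanh_activation)
        return graph.div(a=d_in, b=tensors["softcap"])
```

With floats the pair reduces to `softcap * tanh(score / softcap)` and a gradient of `dscore * (1 - tanh(score / softcap)**2)`, which is a quick sanity check on the forward/backward pairing.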
_SCORE_MOD_UID_K = 2
_SCORE_MOD_UID_V = 3
_SCORE_MOD_UID_O = 4
_SCORE_MOD_UID_STATS = 5
Where do these _SCORE_MOD_UID_XXXXX come from? Is it a C/C++ enum? If so, we should make this a Python Enum that derives its values from the C/C++ enum exposed via pybind
See this enum for reference:
These are just arbitrary numbers, really. In fact, assigning UIDs is completely optional; cuDNN can auto-assign them. The UIDs are added here just for determinism and to make future troubleshooting easier, e.g. so that we know that 4 is the output tensor.
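Even with arbitrary values, grouping the fixed UIDs in a pure-Python `IntEnum` would give them names in error messages without tying them to a C/C++ enum. The sketch below lists only the values visible in this thread; `ScoreModUid` and `describe_uid` are hypothetical, and treating every UID at or above the base as a user tensor is an assumption.

```python
from enum import IntEnum


class ScoreModUid(IntEnum):
    """Fixed tensor UIDs (only the values quoted in this thread)."""

    K = 2
    V = 3
    O = 4
    STATS = 5
    DQ = 7
    DK = 8
    DV = 9


# Same value as _SCORE_MOD_FWD_TENSOR_UID_BASE in the quoted diff.
FWD_TENSOR_UID_BASE = 1000


def describe_uid(uid):
    """Map a UID to a readable name for troubleshooting output."""
    try:
        return ScoreModUid(uid).name
    except ValueError:
        # Assumption: UIDs at or above the base belong to user-supplied tensors.
        if uid >= FWD_TENSOR_UID_BASE:
            return f"user_tensor[{uid - FWD_TENSOR_UID_BASE}]"
        return f"unknown({uid})"
```

For example, `describe_uid(4)` yields the name of the output tensor, which matches the troubleshooting use case described above.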
             status.get_message());

  std::lock_guard<std::mutex> lock(ScoreModGraphCacheMutex());
  auto &cache = ScoreModGraphCache();
SR: Can we name ScoreModGraphCache() something like getScoreModeGraphCache()? On my first read-through I read ScoreModGraphCache() as constructing a new object and thought this was always using a fresh cache.
  NVTE_CHECK(status.is_good(), "Failed to deserialize cuDNN score_mod SDPA graph: ",
             status.get_message());

  std::lock_guard<std::mutex> lock(ScoreModGraphCacheMutex());
Same here about ScoreModGraphCacheMutex() -> getScoreModGraphCacheMutex()
Description
This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.