Skip to content

Add guard at lowest JAX version that still supports triton kernel calling#2741

Open
tdophung wants to merge 7 commits intoNVIDIA:mainfrom
tdophung:triton_jax_bwd_compat
Open

Add guard at lowest JAX version that still supports triton kernel calling#2741
tdophung wants to merge 7 commits intoNVIDIA:mainfrom
tdophung:triton_jax_bwd_compat

Conversation

@tdophung
Copy link
Collaborator

@tdophung tdophung commented Mar 6, 2026

Description

To provide backward compatibility with older jax versions, we need to have a safeguard in place for jax versions too old to work with triton kernel calling. Using Claude Code to automate bisecting through JAX toolbox nightly containers between Sep 1, 2025 and Oct 1, 2025 (*), I have found that the first passing version of the container starts on Sep 24th, 2025, corresponding to jax 0.8.0.dev20250924 hence the guard is put there.

(*) the date range is determined by having a data point that the officially released jax toolbox (nvcr.io/nvidia/jax:25.10-py3 fails while the nightly jax container on Oct 1st passed.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Handles jax < 0.8.0.dev20250924 segfault error when calling triton kernels frfom JAX side

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Collaborator Author

tdophung commented Mar 6, 2026

/te-ci jax

@tdophung tdophung changed the title add guard at bisected jax version where lower is segfault Add guard at lowest JAX version that still supports triton kernel calling Mar 6, 2026
Copy link
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, LGTM pending CI, thanks!

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 6, 2026

Greptile Summary

This PR successfully introduces a minimum JAX version guard (>= 0.8.0) to prevent segfaults when dispatching Triton kernels from JAX on older jaxlib versions. The implementation is robust:

Key changes:

  • transformer_engine/jax/version_utils.py (new): Centralizes _jax_version_meet_requirement, TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0", and the public is_triton_extension_supported() helper, reused across quantize and triton_extensions modules.
  • triton_extensions/utils.py: Version guard placed before the gpu_triton import block as a belt-and-suspenders measure with clear error messaging.
  • quantize/helper.py: Refactored to import _jax_version_meet_requirement from the new shared module.
  • Test files: Module-level require_triton() guards prevent triton imports on old JAX; triton-marked tests are batch-skipped via conftest on old JAX versions.
  • Version threshold: Deliberately set to official release "0.8.0" rather than nightly builds for stability.

All critical issues (missing module, guard placement, pytest flags, hardcoded versions) have been properly addressed. The PR is safe to merge.

Confidence Score: 5/5

  • The PR introduces robust safeguards for old JAX versions with clear error messages and proper test skipping; no regressions or breaking changes.
  • All critical issues from prior reviews have been resolved. The version guard is correctly positioned before GPU triton imports. Test skip mechanisms use proper pytest flags. Module imports are properly structured. The shared version utilities are correctly implemented and integrated across all affected modules. The code is production-ready.
  • No files require special attention.

Last reviewed commit: 93f342e

# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution
# (bisected: last known segfault = jax-2025-09-23, first known pass = jax-2025-09-24).
_JAXLIB_MIN_VERSION = "0.8.0.dev20250924"
import jaxlib # noqa: E402
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import placed mid-module when jax is already available

jax is already imported at line 42, and jax.__version__ always matches jaxlib.__version__ (they are co-released together). The mid-file import jaxlib requires a # noqa: E402 suppression and adds an extra package import for information that's already available via the existing jax import.

Alternatively, if you specifically want jaxlib.__version__ for precision, this import should be hoisted to the top-of-file import block alongside the other jax/jaxlib imports rather than sitting between executable module-level code.

Suggested change
import jaxlib # noqa: E402
import jaxlib # jaxlib version guard below

(move this line to the top of the file alongside the other imports at lines 41–43)

Or, to avoid the extra import entirely:

if version.parse(jax.__version__) < version.parse(_JAXLIB_MIN_VERSION):

since jax.__version__ == jaxlib.__version__ for any properly-installed JAX environment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also about to comment about jax vs jaxlib.

Besides, we have an existing helper here as well https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/quantize/helper.py#L72-L78, perhaps we should find a way to repurpose this helper, especially since we will need to have jax guard all over places in the future.

On the side note, should we check for a release version, i.e., >= 0.8.0 instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great points! Yes, reusing _jax_version_meet_requirement would be much cleaner. You could move it to a shared utility module if multiple files need it.

Regarding dev vs release version:

  • >= 0.8.0.dev20250924 allows: 0.8.0.dev20250924, later dev versions, and all 0.8.x releases
  • >= 0.8.0 only allows official releases (excludes all dev versions)

Since dev versions sort before releases in PEP 440, using >= 0.8.0 would actually block 0.8.0.dev20250924 even though your bisection shows it works.

Recommendation: If most users will use official releases (not nightlies), >= 0.8.0 is cleaner and safer. Dev/nightly users typically expect to handle version quirks. But if you want to explicitly support those working dev versions, keep >= 0.8.0.dev20250924.

For the code:

# At top of file with other imports
from .quantize.helper import _jax_version_meet_requirement

# Later in the file
_JAXLIB_MIN_VERSION = "0.8.0"  # or keep dev version if needed
if not _jax_version_meet_requirement(_JAXLIB_MIN_VERSION):
    raise RuntimeError(f"JAX >= {_JAXLIB_MIN_VERSION} required for Triton kernel support")

This eliminates the mid-file import and reuses existing infrastructure.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phu0ngng I think greptile has analyzed it pretty thoroughly on checking for > dev_version or >release_version. So the decision is between supporting dev versions after 0924 or not. I think we should support these dev versions, just in case other teams have not updated their pipeline to use a newer/release version. The ugly code to have the dev version instead of release is not that ugly and does not need any maintenance anyways

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

regarding the other poitns (jax vs jaxlib) and using the _jax_version_meets_requirement : I agree and will fix

Copy link
Collaborator

@phu0ngng phu0ngng Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, for the older version, I would prefer to guard against release versions for the two reasons. First, it is very unlikely that users will update their JAX version to another old dev version, but rather another release version. Second, I would prefer not to have to keep track of dev versions in the support matrix.

On the other side, when we want a guard but need the latest unreleased features, we could add a guard against a dev version, but move to a release version eventually when it is released.

@jberchtold-nvidia for additional opinion.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with versioning against an official release version. The official release versions are tested more thoroughly than a nightly dev version, so I'd rather have a user be overly cautious and upgrade JAX slightly further than have them only upgrade the minimal amount to a dev version and potentially run into other issues on the particular nightly dev version.

With versioning to an official release version, we will be missing out on a certain period between [jax_dev_release_where_it_is_fixed, jax_official_release_when_it_is_fixed), but I think that is an okay tradeoff for the additional stability I'd expect in an official release compared to a nightly dev release.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed check to release version 0.8.0

) from e

# Minimum jaxlib version required for Triton kernel dispatch to work correctly.
# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! I think we should expose a function here, something like is_triton_extension_supported() so us or users can check as well. Then in our Triton tests we can add something like the following block at the top:

def test_some_triton_extension(...):
    if not is_triton_extension_supported():
        pytest.skip("... same info message about triton jax compatibility in older versions ...")
       return
    main test code

Similar to how we guard tests based on compute arch

Without this change, I think the tests on old containers would still fail, though the new error message is much more informative then before!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking more about this and I realize that we could also use pytest markers to mark certain tests as triton kernel tests where we can check for the cached jaxlib version (with that exported is_triton_extension_supported() function) without having to add the same boilerplate to each test.

Another approach is to do it for each test file (by doing conftest.py pytest hook, or 1 call to the skip function with is_triton_extension_supported at the beginning). Currently there are only 4 files: test_permutation.py, test_fused_router.py, anfd the distributed versions of each. However, I think we might want to use either NCCL deepEP or cuTile in replacement of these triton kernels in the future, or keep a combination of them, so this whole file approach might not be ideal, and I would rather do the marker approach above.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point about switching backends, the pytest marker approach sounds good to me!

@jberchtold-nvidia jberchtold-nvidia self-requested a review March 6, 2026 16:01
…lper.py

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung and others added 2 commits March 9, 2026 16:16
- Add version_utils.py with is_triton_extension_supported() checking JAX >= 0.8.0
  (release version, not dev snapshot) and TRITON_EXTENSION_MIN_JAX_VERSION constant
- Add pytest.mark.triton marker and conftest hook to skip marked tests on old JAX
- Add require_triton() for module-level skipping in test files
- Rewrite triton_extensions to use is_triton_extension_supported() instead of
  direct jaxlib dev-version comparison

Signed-off-by: tdophung <tdophung@nvidia.com>
…d re-export, revert test.sh

- require_triton(): add allow_module_level=True to pytest.skip() so module-level
  calls on old JAX produce a proper skip instead of a collection failure
- Remove is_triton_extension_supported from triton_extensions/utils.py __all__:
  importing triton_extensions on JAX < 0.8.0 raises immediately, so re-exporting
  the check from there defeats its purpose; callers should import directly from
  transformer_engine.jax.version_utils
- Revert qa/L0_jax_lint/test.sh TE_PATH to /opt/transformerengine (local dev
  path was accidentally committed; pass TE_PATH= at invocation time instead)

Signed-off-by: tdophung <tdophung@nvidia.com>
…l__ and hardcoded version

- Move is_triton_extension_supported() guard before the gpu_triton import block
  with a comment clarifying the segfault is at dispatch time, not import time
- Remove _jax_version_meet_requirement from version_utils __all__ (private helper,
  not a public API; callers import it explicitly as needed)
- Use TRITON_EXTENSION_MIN_JAX_VERSION constant in conftest marker description
  instead of hardcoded '0.8.0'

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Collaborator Author

/te-ci jax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants